diff --git a/meson.build b/meson.build index ca28c0a1..a10598f9 100644 --- a/meson.build +++ b/meson.build @@ -1,6 +1,7 @@ project('x86-simd-sort', 'cpp', - version : '1.0.0', - license : 'BSD 3-clause') + version : '2.0.0', + license : 'BSD 3-clause', + default_options : ['cpp_std=c++17']) cpp = meson.get_compiler('cpp') src = include_directories('src') bench = include_directories('benchmarks') diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index b5202f46..2cdb45e7 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -377,7 +377,8 @@ bool comparison_func>(const uint16_t &a, const uint16_t &b) //return npy_half_to_float(a) < npy_half_to_float(b); } -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr, +template <> +int64_t replace_nan_with_inf>(uint16_t *arr, int64_t arrsize) { int64_t nan_count = 0; @@ -396,77 +397,66 @@ X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr, return nan_count; } -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count) -{ - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = 0xFFFF; - nan_count -= 1; - } -} - template <> bool is_a_nan(uint16_t elem) { return (elem & 0x7c00) == 0x7c00; } +/* Specialized template function for 16-bit qsort_ funcs*/ template <> -void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void qsort_>(int16_t *arr, + int64_t left, + int64_t right, + int64_t maxiters) { - if (arrsize > 1) { - qselect_16bit_, int16_t>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qsort_16bit_>(arr, left, right, maxiters); } template <> -void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void qsort_>(uint16_t *arr, + int64_t left, + int64_t right, + int64_t maxiters) { - if (arrsize > 1) { - qselect_16bit_, uint16_t>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qsort_16bit_>(arr, left, right, maxiters); } -void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize) { - int64_t indx_last_elem = arrsize - 1; - if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); - } - if (indx_last_elem >= k) { - qselect_16bit_, uint16_t>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf, uint16_t>( + arr, arrsize); + qsort_16bit_, uint16_t>( + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); } } +/* Specialized template function for 16-bit qselect_ funcs*/ template <> -void avx512_qsort(int16_t *arr, int64_t arrsize) +void qselect_>( + int16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) { - if (arrsize > 1) { - qsort_16bit_, int16_t>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qselect_16bit_>(arr, k, left, right, maxiters); } template <> -void avx512_qsort(uint16_t *arr, int64_t arrsize) +void qselect_>( + uint16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) { - if (arrsize > 1) { - qsort_16bit_, uint16_t>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qselect_16bit_>(arr, k, left, right, maxiters); } -void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize) +void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); - qsort_16bit_, uint16_t>( - arr, 0, arrsize - 1, 2 * 
(int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + int64_t indx_last_elem = arrsize - 1; + if (UNLIKELY(hasnan)) { + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + } + if (indx_last_elem >= k) { + qselect_16bit_, uint16_t>( + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } - #endif // AVX512_QSORT_16BIT diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index a0dd7f7e..054e4b26 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -256,6 +256,15 @@ struct zmm_vector { { return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); } + static opmask_t get_partial_loadmask(int size) + { + return (0x0001 << size) - 0x0001; + } + template + static opmask_t fpclass(zmm_t x) + { + return _mm512_fpclass_ps_mask(x, type); + } template static ymm_t i64gather(__m512i index, void const *base) { @@ -279,6 +288,10 @@ struct zmm_vector { { return _mm512_mask_compressstoreu_ps(mem, mask, x); } + static zmm_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm512_maskz_loadu_ps(mask, mem); + } static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_ps(x, mask, mem); @@ -689,95 +702,53 @@ static void qselect_32bit_(type_t *arr, qselect_32bit_(arr, pos, pivot_index, right, max_iters - 1); } -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask16 loadmask = 0xFFFF; - while (arrsize > 0) { - if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; } - __m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr); - __mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT); - arr += 16; - arrsize -= 16; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) -{ - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = std::nanf("1"); - nan_count -= 1; - } -} - +/* Specialized template function for 32-bit qselect_ funcs*/ template <> -void avx512_qselect(int32_t *arr, - int64_t k, - int64_t arrsize, - bool hasnan) +void qselect_>( + int32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) { - if (arrsize > 1) { - qselect_32bit_, int32_t>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qselect_32bit_>(arr, k, left, right, maxiters); } template <> -void avx512_qselect(uint32_t *arr, - int64_t k, - int64_t arrsize, - bool hasnan) +void qselect_>( + uint32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) { - if (arrsize > 1) { - qselect_32bit_, uint32_t>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qselect_32bit_>(arr, k, left, right, maxiters); } template <> -void avx512_qselect(float *arr, int64_t k, int64_t arrsize, bool hasnan) +void qselect_>( + float *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) { - int64_t indx_last_elem = arrsize - 1; - if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); - } - if (indx_last_elem >= k) { - qselect_32bit_, float>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); - } + qselect_32bit_>(arr, k, left, right, maxiters); } +/* Specialized template function for 32-bit qsort_ funcs*/ template <> -void avx512_qsort(int32_t *arr, int64_t arrsize) +void qsort_>(int32_t *arr, + int64_t left, + int64_t right, + int64_t maxiters) { - if (arrsize > 1) { - qsort_32bit_, int32_t>( - arr, 0, 
arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qsort_32bit_>(arr, left, right, maxiters); } template <> -void avx512_qsort(uint32_t *arr, int64_t arrsize) +void qsort_>(uint32_t *arr, + int64_t left, + int64_t right, + int64_t maxiters) { - if (arrsize > 1) { - qsort_32bit_, uint32_t>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qsort_32bit_>(arr, left, right, maxiters); } template <> -void avx512_qsort(float *arr, int64_t arrsize) +void qsort_>(float *arr, + int64_t left, + int64_t right, + int64_t maxiters) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); - qsort_32bit_, float>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); - } + qsort_32bit_>(arr, left, right, maxiters); } - #endif //AVX512_QSORT_32BIT diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 3626ab63..c6499a6c 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -344,89 +344,25 @@ static void argselect_64bit_(type_t *arr, arr, arg, pos, pivot_index, right, max_iters - 1); } -template -bool has_nan(type_t *arr, int64_t arrsize) -{ - using opmask_t = typename vtype::opmask_t; - using zmm_t = typename vtype::zmm_t; - bool found_nan = false; - opmask_t loadmask = 0xFF; - zmm_t in; - while (arrsize > 0) { - if (arrsize < vtype::numlanes) { - loadmask = (0x01 << arrsize) - 0x01; - in = vtype::maskz_loadu(loadmask, arr); - } - else { - in = vtype::loadu(arr); - } - opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); - arr += vtype::numlanes; - arrsize -= vtype::numlanes; - if (nanmask != 0x00) { - found_nan = true; - break; - } - } - return found_nan; -} - /* argsort methods for 32-bit and 64-bit dtypes */ template void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize) { + using vectype = typename std::conditional, + zmm_vector>::type; if (arrsize > 1) { - argsort_64bit_>( - arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } -} - -template <> -void avx512_argsort(double *arr, int64_t *arg, int64_t arrsize) -{ - if (arrsize > 1) { - if (has_nan>(arr, arrsize)) { - std_argsort_withnan(arr, arg, 0, arrsize); - } - else { - argsort_64bit_>( - arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + if constexpr (std::is_floating_point_v) { + if (has_nan(arr, arrsize)) { + std_argsort_withnan(arr, arg, 0, arrsize); + return; + } } - } -} - -template <> -void avx512_argsort(int32_t *arr, int64_t *arg, int64_t arrsize) -{ - if (arrsize > 1) { - argsort_64bit_>( + argsort_64bit_( arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } -template <> -void avx512_argsort(uint32_t *arr, int64_t *arg, int64_t arrsize) -{ - if (arrsize > 1) { - argsort_64bit_>( - arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } -} - -template <> -void avx512_argsort(float *arr, int64_t *arg, int64_t arrsize) -{ - if (arrsize > 1) { - if (has_nan>(arr, arrsize)) { - std_argsort_withnan(arr, arg, 0, arrsize); - } - else { - argsort_64bit_>( - arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } - } -} - template std::vector avx512_argsort(T *arr, int64_t arrsize) { @@ -440,58 +376,22 @@ std::vector avx512_argsort(T *arr, int64_t arrsize) template void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize) { - if (arrsize > 1) { - argselect_64bit_>( - arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } -} + using vectype = typename std::conditional, + zmm_vector>::type; -template <> -void avx512_argselect(double *arr, int64_t *arg, int64_t k, int64_t 
arrsize) -{ if (arrsize > 1) { - if (has_nan>(arr, arrsize)) { - std_argselect_withnan(arr, arg, k, 0, arrsize); + if constexpr (std::is_floating_point_v) { + if (has_nan(arr, arrsize)) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + return; + } } - else { - argselect_64bit_>( - arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } - } -} - -template <> -void avx512_argselect(int32_t *arr, int64_t *arg, int64_t k, int64_t arrsize) -{ - if (arrsize > 1) { - argselect_64bit_>( - arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } -} - -template <> -void avx512_argselect(uint32_t *arr, int64_t *arg, int64_t k, int64_t arrsize) -{ - if (arrsize > 1) { - argselect_64bit_>( + argselect_64bit_( arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } -template <> -void avx512_argselect(float *arr, int64_t *arg, int64_t k, int64_t arrsize) -{ - if (arrsize > 1) { - if (has_nan>(arr, arrsize)) { - std_argselect_withnan(arr, arg, k, 0, arrsize); - } - else { - argselect_64bit_>( - arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } - } -} - template std::vector avx512_argselect(T *arr, int64_t k, int64_t arrsize) { diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index d12684c1..fbd4a88f 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -40,14 +40,8 @@ struct ymm_vector { return _mm256_set1_ps(type_max()); } - static zmmi_t seti(int v1, - int v2, - int v3, - int v4, - int v5, - int v6, - int v7, - int v8) + static zmmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); } @@ -71,6 +65,10 @@ struct ymm_vector { { return _mm256_cmp_ps_mask(x, y, _CMP_EQ_OQ); } + static opmask_t get_partial_loadmask(int size) + { + return (0x01 << size) - 0x01; + } template static opmask_t fpclass(zmm_t x) { @@ -89,7 +87,7 @@ struct ymm_vector { } static zmm_t loadu(void const *mem) { - return _mm256_loadu_ps((float*) mem); + return _mm256_loadu_ps((float *)mem); } static zmm_t max(zmm_t x, zmm_t y) { @@ -125,16 +123,22 @@ struct ymm_vector { } static type_t reducemax(zmm_t v) { - __m128 v128 = _mm_max_ps(_mm256_castps256_ps128(v), _mm256_extractf32x4_ps (v, 1)); - __m128 v64 = _mm_max_ps(v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(1, 0, 3, 2))); - __m128 v32 = _mm_max_ps(v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1))); + __m128 v128 = _mm_max_ps(_mm256_castps256_ps128(v), + _mm256_extractf32x4_ps(v, 1)); + __m128 v64 = _mm_max_ps( + v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128 v32 = _mm_max_ps( + v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1))); return _mm_cvtss_f32(v32); } static type_t reducemin(zmm_t v) { - __m128 v128 = _mm_min_ps(_mm256_castps256_ps128(v), _mm256_extractf32x4_ps(v, 1)); - __m128 v64 = _mm_min_ps(v128, _mm_shuffle_ps(v128, v128,_MM_SHUFFLE(1, 0, 3, 2))); - __m128 v32 = _mm_min_ps(v64, _mm_shuffle_ps(v64, v64,_MM_SHUFFLE(0, 0, 0, 1))); + __m128 v128 = _mm_min_ps(_mm256_castps256_ps128(v), + _mm256_extractf32x4_ps(v, 1)); + __m128 v64 = _mm_min_ps( + v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128 v32 = _mm_min_ps( + v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1))); return _mm_cvtss_f32(v32); } static zmm_t set1(type_t v) @@ -156,7 +160,7 @@ struct ymm_vector { } static void storeu(void *mem, zmm_t x) { - _mm256_storeu_ps((float*)mem, x); + _mm256_storeu_ps((float *)mem, x); } }; template <> @@ -180,14 +184,8 @@ struct ymm_vector { return _mm256_set1_epi32(type_max()); } - 
static zmmi_t seti(int v1, - int v2, - int v3, - int v4, - int v5, - int v6, - int v7, - int v8) + static zmmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); } @@ -224,7 +222,7 @@ struct ymm_vector { } static zmm_t loadu(void const *mem) { - return _mm256_loadu_si256((__m256i*) mem); + return _mm256_loadu_si256((__m256i *)mem); } static zmm_t max(zmm_t x, zmm_t y) { @@ -260,16 +258,22 @@ struct ymm_vector { } static type_t reducemax(zmm_t v) { - __m128i v128 = _mm_max_epu32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); - __m128i v64 = _mm_max_epu32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); - __m128i v32 = _mm_max_epu32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); + __m128i v128 = _mm_max_epu32(_mm256_castsi256_si128(v), + _mm256_extracti128_si256(v, 1)); + __m128i v64 = _mm_max_epu32( + v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128i v32 = _mm_max_epu32( + v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); return (type_t)_mm_cvtsi128_si32(v32); } static type_t reducemin(zmm_t v) { - __m128i v128 = _mm_min_epu32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); - __m128i v64 = _mm_min_epu32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); - __m128i v32 = _mm_min_epu32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); + __m128i v128 = _mm_min_epu32(_mm256_castsi256_si128(v), + _mm256_extracti128_si256(v, 1)); + __m128i v64 = _mm_min_epu32( + v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128i v32 = _mm_min_epu32( + v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); return (type_t)_mm_cvtsi128_si32(v32); } static zmm_t set1(type_t v) @@ -285,7 +289,7 @@ struct ymm_vector { } static void storeu(void *mem, zmm_t x) { - _mm256_storeu_si256((__m256i*) mem, x); + _mm256_storeu_si256((__m256i *)mem, x); } }; template <> @@ -309,14 +313,8 @@ struct ymm_vector { return _mm256_set1_epi32(type_max()); } // TODO: this should broadcast bits as is? 
- static zmmi_t seti(int v1, - int v2, - int v3, - int v4, - int v5, - int v6, - int v7, - int v8) + static zmmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); } @@ -353,7 +351,7 @@ struct ymm_vector { } static zmm_t loadu(void const *mem) { - return _mm256_loadu_si256((__m256i*) mem); + return _mm256_loadu_si256((__m256i *)mem); } static zmm_t max(zmm_t x, zmm_t y) { @@ -389,16 +387,22 @@ struct ymm_vector { } static type_t reducemax(zmm_t v) { - __m128i v128 = _mm_max_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); - __m128i v64 = _mm_max_epi32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); - __m128i v32 = _mm_max_epi32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); + __m128i v128 = _mm_max_epi32(_mm256_castsi256_si128(v), + _mm256_extracti128_si256(v, 1)); + __m128i v64 = _mm_max_epi32( + v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128i v32 = _mm_max_epi32( + v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); return (type_t)_mm_cvtsi128_si32(v32); } static type_t reducemin(zmm_t v) { - __m128i v128 = _mm_min_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); - __m128i v64 = _mm_min_epi32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); - __m128i v32 = _mm_min_epi32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); + __m128i v128 = _mm_min_epi32(_mm256_castsi256_si128(v), + _mm256_extracti128_si256(v, 1)); + __m128i v64 = _mm_min_epi32( + v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128i v32 = _mm_min_epi32( + v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); return (type_t)_mm_cvtsi128_si32(v32); } static zmm_t set1(type_t v) @@ -414,7 +418,7 @@ struct ymm_vector { } static void storeu(void *mem, zmm_t x) { - _mm256_storeu_si256((__m256i*) mem, x); + _mm256_storeu_si256((__m256i *)mem, x); } }; template <> @@ -439,14 +443,8 @@ struct zmm_vector { return _mm512_set1_epi64(type_max()); } // TODO: this should broadcast bits as is? 
- static zmmi_t seti(int v1, - int v2, - int v3, - int v4, - int v5, - int v6, - int v7, - int v8) + static zmmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } @@ -563,14 +561,8 @@ struct zmm_vector { return _mm512_set1_epi64(type_max()); } - static zmmi_t seti(int v1, - int v2, - int v3, - int v4, - int v5, - int v6, - int v7, - int v8) + static zmmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } @@ -675,14 +667,8 @@ struct zmm_vector { return _mm512_set1_pd(type_max()); } - static zmmi_t seti(int v1, - int v2, - int v3, - int v4, - int v5, - int v6, - int v7, - int v8) + static zmmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } @@ -703,6 +689,10 @@ struct zmm_vector { { return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); } + static opmask_t get_partial_loadmask(int size) + { + return (0x01 << size) - 0x01; + } template static opmask_t fpclass(zmm_t x) { @@ -773,30 +763,7 @@ struct zmm_vector { _mm512_storeu_pd(mem, x); } }; -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask8 loadmask = 0xFF; - while (arrsize > 0) { - if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } - __m512d in_zmm = _mm512_maskz_loadu_pd(loadmask, arr); - __mmask8 nanmask = _mm512_cmp_pd_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm512_mask_storeu_pd(arr, nanmask, ZMM_MAX_DOUBLE); - arr += 8; - arrsize -= 8; - } - return nan_count; -} -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) -{ - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = std::nan("1"); - nan_count -= 1; - } -} /* * Assumes zmm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg @@ -808,16 +775,12 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAA); zmm = cmp_merge( - zmm, - vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm), - 0xCC); + zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm), 0xCC); zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAA); zmm = cmp_merge(zmm, vtype::permutexvar(rev_index, zmm), 0xF0); zmm = cmp_merge( - zmm, - vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm), - 0xCC); + zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm), 0xCC); zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAA); return zmm; diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index f721f5c8..16f8d354 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -439,34 +439,21 @@ void qsort_64bit_(type1_t *keys, } } -template <> -void avx512_qsort_kv(int64_t *keys, uint64_t *indexes, int64_t arrsize) +template +void avx512_qsort_kv(T1 *keys, T2 *indexes, int64_t arrsize) { if (arrsize > 1) { - qsort_64bit_, zmm_vector>( - keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } -} - -template <> -void avx512_qsort_kv(uint64_t *keys, - uint64_t *indexes, - int64_t arrsize) -{ - if (arrsize > 1) { - qsort_64bit_, zmm_vector>( - keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } -} - -template <> -void avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize) -{ - if 
(arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(keys, arrsize); - qsort_64bit_, zmm_vector>( - keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(keys, arrsize, nan_count); + if constexpr (std::is_floating_point_v) { + int64_t nan_count + = replace_nan_with_inf>(keys, arrsize); + qsort_64bit_, zmm_vector>( + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(keys, arrsize, nan_count); + } + else { + qsort_64bit_, zmm_vector>( + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } } } #endif // AVX512_QSORT_64BIT_KV diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index d59a1788..626e672e 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -783,72 +783,53 @@ static void qselect_64bit_(type_t *arr, qselect_64bit_(arr, pos, pivot_index, right, max_iters - 1); } +/* Specialized template function for 64-bit qselect_ funcs*/ template <> -void avx512_qselect(int64_t *arr, - int64_t k, - int64_t arrsize, - bool hasnan) +void qselect_>( + int64_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) { - if (arrsize > 1) { - qselect_64bit_, int64_t>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qselect_64bit_>(arr, k, left, right, maxiters); } template <> -void avx512_qselect(uint64_t *arr, - int64_t k, - int64_t arrsize, - bool hasnan) +void qselect_>( + uint64_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) { - if (arrsize > 1) { - qselect_64bit_, uint64_t>( - arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qselect_64bit_>(arr, k, left, right, maxiters); } template <> -void avx512_qselect(double *arr, - int64_t k, - int64_t arrsize, - bool hasnan) +void qselect_>( + double *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) { - int64_t indx_last_elem = arrsize - 1; - if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); - } - if (indx_last_elem >= k) { - qselect_64bit_, double>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); - } + qselect_64bit_>(arr, k, left, right, maxiters); } +/* Specialized template function for 64-bit qsort_ funcs*/ template <> -void avx512_qsort(int64_t *arr, int64_t arrsize) +void qsort_>(int64_t *arr, + int64_t left, + int64_t right, + int64_t maxiters) { - if (arrsize > 1) { - qsort_64bit_, int64_t>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qsort_64bit_>(arr, left, right, maxiters); } template <> -void avx512_qsort(uint64_t *arr, int64_t arrsize) +void qsort_>(uint64_t *arr, + int64_t left, + int64_t right, + int64_t maxiters) { - if (arrsize > 1) { - qsort_64bit_, uint64_t>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } + qsort_64bit_>(arr, left, right, maxiters); } template <> -void avx512_qsort(double *arr, int64_t arrsize) +void qsort_>(double *arr, + int64_t left, + int64_t right, + int64_t maxiters) { - if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); - qsort_64bit_, double>( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); - } + qsort_64bit_>(arr, left, right, maxiters); } #endif // AVX512_QSORT_64BIT diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h index 0ae50c49..e829ab62 100644 --- a/src/avx512-common-argsort.h +++ b/src/avx512-common-argsort.h @@ -15,17 +15,6 @@ using argtype = zmm_vector; using argzmm_t = typename argtype::zmm_t; -template -void avx512_argsort(T *arr, int64_t *arg, int64_t 
arrsize); - -template -std::vector avx512_argsort(T *arr, int64_t arrsize); - -template -void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize); - -template -std::vector avx512_argselect(T *arr, int64_t k, int64_t arrsize); /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. @@ -45,7 +34,8 @@ static inline int32_t partition_vec(type_t *arg, int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); argtype::mask_compressstoreu( arg + left, vtype::knot_opmask(gt_mask), arg_vec); - argtype::mask_compressstoreu(arg + right - amount_gt_pivot, gt_mask, arg_vec); + argtype::mask_compressstoreu( + arg + right - amount_gt_pivot, gt_mask, arg_vec); *smallest_vec = vtype::min(curr_vec, *smallest_vec); *biggest_vec = vtype::max(curr_vec, *biggest_vec); return amount_gt_pivot; @@ -236,7 +226,8 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, right -= num_unroll * vtype::numlanes; #pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] = argtype::loadu(arg + right + ii * vtype::numlanes); + arg_vec[ii] + = argtype::loadu(arg + right + ii * vtype::numlanes); curr_vec[ii] = vtype::template i64gather( arg_vec[ii], arr); } diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 841b4a83..6e5cd15e 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -85,8 +85,8 @@ #define X86_SIMD_SORT_FINLINE static #endif -#define LIKELY(x) __builtin_expect((x),1) -#define UNLIKELY(x) __builtin_expect((x),0) +#define LIKELY(x) __builtin_expect((x), 1) +#define UNLIKELY(x) __builtin_expect((x), 0) template struct zmm_vector; @@ -94,35 +94,76 @@ struct zmm_vector; template struct ymm_vector; -// Regular quicksort routines: template -void avx512_qsort(T *arr, int64_t arrsize); -void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize); - -template -void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false); -void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false); - -template -inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false) +bool is_a_nan(T elem) { - avx512_qselect(arr, k - 1, arrsize, hasnan); - avx512_qsort(arr, k - 1); + return std::isnan(elem); } -inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false) + +template +int64_t replace_nan_with_inf(T *arr, int64_t arrsize) { - avx512_qselect_fp16(arr, k - 1, arrsize, hasnan); - avx512_qsort_fp16(arr, k - 1); + int64_t nan_count = 0; + using opmask_t = typename vtype::opmask_t; + using zmm_t = typename vtype::zmm_t; + opmask_t loadmask; + zmm_t in; + while (arrsize > 0) { + if (arrsize < vtype::numlanes) { + loadmask = vtype::get_partial_loadmask(arrsize); + in = vtype::maskz_loadu(loadmask, arr); + } + else { + in = vtype::loadu(arr); + } + opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); + nan_count += _mm_popcnt_u32((int32_t)nanmask); + vtype::mask_storeu(arr, nanmask, vtype::zmm_max()); + arr += vtype::numlanes; + arrsize -= vtype::numlanes; + } + return nan_count; } -// key-value sort routines -template -void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); +template +bool has_nan(type_t *arr, int64_t arrsize) +{ + using opmask_t = typename vtype::opmask_t; + using zmm_t = typename vtype::zmm_t; + bool found_nan = false; + opmask_t loadmask; + zmm_t in; + while (arrsize > 0) { + if (arrsize < vtype::numlanes) { + loadmask = 
vtype::get_partial_loadmask(arrsize); + in = vtype::maskz_loadu(loadmask, arr); + } + else { + in = vtype::loadu(arr); + } + opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); + arr += vtype::numlanes; + arrsize -= vtype::numlanes; + if (nanmask != 0x00) { + found_nan = true; + break; + } + } + return found_nan; +} -template -bool is_a_nan(T elem) +template +void replace_inf_with_nan(type_t *arr, int64_t arrsize, int64_t nan_count) { - return std::isnan(elem); + for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { + if constexpr (std::is_floating_point_v) { + arr[ii] = std::numeric_limits::quiet_NaN(); + } + else { + arr[ii] = 0xFFFF; + } + nan_count -= 1; + } } /* @@ -130,7 +171,7 @@ bool is_a_nan(T elem) * in the array which is not a nan */ template -int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize) +int64_t move_nans_to_end_of_array(T *arr, int64_t arrsize) { int64_t jj = arrsize - 1; int64_t ii = 0; @@ -145,7 +186,7 @@ int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize) ii += 1; } } - return arrsize-count-1; + return arrsize - count - 1; } template @@ -628,4 +669,73 @@ static inline int64_t partition_avx512(type_t1 *keys, *biggest = vtype1::reducemax(max_vec); return l_store; } + +template +void qsort_(type_t *arr, int64_t left, int64_t right, int64_t maxiters); + +template +void qselect_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t maxiters); + +// Regular quicksort routines: +template +void avx512_qsort(T *arr, int64_t arrsize) +{ + if (arrsize > 1) { + /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ + if constexpr (std::is_floating_point_v) { + int64_t nan_count + = replace_nan_with_inf>(arr, arrsize); + qsort_, T>( + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); + } + else { + qsort_, T>( + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } + } +} + +void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize); + +template +void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false) +{ + int64_t indx_last_elem = arrsize - 1; + /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ + if constexpr (std::is_floating_point_v) { + if (UNLIKELY(hasnan)) { + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + } + } + if (indx_last_elem >= k) { + qselect_, T>( + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + } +} + +void avx512_qselect_fp16(uint16_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan = false); + +template +inline void +avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false) +{ + avx512_qselect(arr, k - 1, arrsize, hasnan); + avx512_qsort(arr, k - 1); +} +inline void avx512_partial_qsort_fp16(uint16_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan = false) +{ + avx512_qselect_fp16(arr, k - 1, arrsize, hasnan); + avx512_qsort_fp16(arr, k - 1); +} #endif // AVX512_QSORT_COMMON diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 5bb4c6c0..505561c4 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -46,11 +46,19 @@ struct zmm_vector<_Float16> { { return _knot_mask32(x); } - static opmask_t ge(zmm_t x, zmm_t y) { return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ); } + static opmask_t get_partial_loadmask(int size) + { + return (0x00000001 << size) - 0x00000001; + } + template + static opmask_t fpclass(zmm_t x) + { + return _mm512_fpclass_ph_mask(x, type); + } static zmm_t loadu(void const *mem) { return _mm512_loadu_ph(mem); @@ 
-65,6 +73,10 @@ struct zmm_vector<_Float16> { // AVX512_VBMI2 return _mm512_mask_compressstoreu_epi16(mem, mask, temp); } + static zmm_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, mem)); + } static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { // AVX512BW @@ -114,62 +126,44 @@ struct zmm_vector<_Float16> { } }; -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr, - int64_t arrsize) +template <> +bool is_a_nan<_Float16>(_Float16 elem) { - int64_t nan_count = 0; - __mmask32 loadmask = 0xFFFFFFFF; - __m512h in_zmm; - while (arrsize > 0) { - if (arrsize < 32) { - loadmask = (0x00000001 << arrsize) - 0x00000001; - in_zmm = _mm512_castsi512_ph( - _mm512_maskz_loadu_epi16(loadmask, arr)); - } - else { - in_zmm = _mm512_loadu_ph(arr); - } - __mmask32 nanmask = _mm512_cmp_ph_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm512_mask_storeu_epi16(arr, nanmask, ZMM_MAX_HALF); - arr += 32; - arrsize -= 32; - } - return nan_count; + Fp16Bits temp; + temp.f_ = elem; + return (temp.i_ & 0x7c00) == 0x7c00; } -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count) +template <> +void replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count) { memset(arr + arrsize - nan_count, 0xFF, nan_count * 2); } template <> -bool is_a_nan<_Float16>(_Float16 elem) +void qselect_>( + _Float16 *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) { - Fp16Bits temp; - temp.f_ = elem; - return (temp.i_ & 0x7c00) == 0x7c00; + qselect_16bit_>(arr, k, left, right, maxiters); } template <> -void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan) +void qsort_>(_Float16 *arr, + int64_t left, + int64_t right, + int64_t maxiters) { - int64_t indx_last_elem = arrsize - 1; - if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); - } - if (indx_last_elem >= k) { - qselect_16bit_, _Float16>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); - } + qsort_16bit_>(arr, left, right, maxiters); } +/* Specialized template function for _Float16 qsort_*/ template <> void avx512_qsort(_Float16 *arr, int64_t arrsize) { if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t nan_count + = replace_nan_with_inf, _Float16>(arr, + arrsize); qsort_16bit_, _Float16>( arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); replace_inf_with_nan(arr, arrsize, nan_count); diff --git a/tests/test-keyvalue.cpp b/tests/test-keyvalue.cpp index a05e9528..6e75f344 100644 --- a/tests/test-keyvalue.cpp +++ b/tests/test-keyvalue.cpp @@ -54,7 +54,7 @@ TYPED_TEST_P(KeyValueSort, test_64bit_random_data) std::sort(sortedarr.begin(), sortedarr.end(), compare); - avx512_qsort_kv(keys.data(), values.data(), keys.size()); + avx512_qsort_kv(keys.data(), values.data(), keys.size()); for (size_t i = 0; i < keys.size(); i++) { ASSERT_EQ(keys[i], sortedarr[i].key); ASSERT_EQ(values[i], sortedarr[i].value);
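
Illustrative usage sketch (not part of the patch): how the refactored public entry points touched by this diff — avx512_qsort, avx512_qselect and avx512_qsort_kv, which now dispatch NaN handling internally via std::is_floating_point_v — are expected to be called. The include paths, the main() harness and the build flags (-mavx512f and friends, AVX-512 capable CPU) are assumptions for the sketch, not part of the change.

// Usage sketch under the assumptions stated above.
#include <cmath>
#include <cstdint>
#include <vector>
#include "src/avx512-32bit-qsort.hpp"        // assumed include paths
#include "src/avx512-64bit-qsort.hpp"
#include "src/avx512-64bit-keyvaluesort.hpp"

int main()
{
    std::vector<float> a = {3.0f, std::nanf("1"), 1.0f, 2.0f};
    // Floating-point types take the replace_nan_with_inf path inside the
    // templated avx512_qsort; NaNs are written back at the end of the array.
    avx512_qsort<float>(a.data(), a.size());

    std::vector<double> b = {5.0, 4.0, 3.0, 2.0, 1.0};
    // hasnan == false skips the move_nans_to_end_of_array pre-pass.
    avx512_qselect<double>(b.data(), /*k=*/2, b.size(), /*hasnan=*/false);

    std::vector<uint64_t> keys = {30, 10, 20};
    std::vector<uint64_t> vals = {3, 1, 2};
    // Key-value sort is now a single template dispatching on the key type.
    avx512_qsort_kv<uint64_t>(keys.data(), vals.data(), keys.size());
    return 0;
}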