From 720e1f7d08e9ac2536d74baf96f6b190a57590a5 Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Fri, 4 Aug 2023 14:19:45 -0700
Subject: [PATCH 1/8] Remove template specializations for quicksort methods

---
 src/avx512-16bit-qsort.hpp        |  30 +++------
 src/avx512-32bit-qsort.hpp        |  58 +++++------------
 src/avx512-64bit-common.h         |  23 -------
 src/avx512-64bit-keyvaluesort.hpp |   2 +-
 src/avx512-64bit-qsort.hpp        |  23 ++-----
 src/avx512-common-qsort.h         | 103 +++++++++++++++++++++-------
 src/avx512fp16-16bit-qsort.hpp    |  39 +----------
 7 files changed, 113 insertions(+), 165 deletions(-)

diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp
index b5202f46..b5252932 100644
--- a/src/avx512-16bit-qsort.hpp
+++ b/src/avx512-16bit-qsort.hpp
@@ -377,8 +377,9 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
     //return npy_half_to_float(a) < npy_half_to_float(b);
 }
 
-X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
-                                                  int64_t arrsize)
+template<>
+int64_t
+replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr, int64_t arrsize)
 {
     int64_t nan_count = 0;
     __mmask16 loadmask = 0xFFFF;
@@ -396,15 +397,6 @@ X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
     return nan_count;
 }
 
-X86_SIMD_SORT_INLINE void
-replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
-{
-    for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
-        arr[ii] = 0xFFFF;
-        nan_count -= 1;
-    }
-}
-
 template <>
 bool is_a_nan<uint16_t>(uint16_t elem)
 {
@@ -442,27 +434,21 @@ void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
 }
 
 template <>
-void avx512_qsort(int16_t *arr, int64_t arrsize)
+void qsort_<zmm_vector<int16_t>>(int16_t* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qsort_16bit_<zmm_vector<int16_t>, int16_t>(
-                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qsort_16bit_<zmm_vector<int16_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void avx512_qsort(uint16_t *arr, int64_t arrsize)
+void qsort_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qsort_16bit_<zmm_vector<uint16_t>, uint16_t>(
-                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qsort_16bit_<zmm_vector<uint16_t>>(arr, left, right, maxiters);
 }
 
 void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
 {
     if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
+        int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(arr, arrsize);
         qsort_16bit_<zmm_vector<float16>, uint16_t>(
                 arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
         replace_inf_with_nan(arr, arrsize, nan_count);
diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp
index a0dd7f7e..5a1571cb 100644
--- a/src/avx512-32bit-qsort.hpp
+++ b/src/avx512-32bit-qsort.hpp
@@ -256,6 +256,11 @@ struct zmm_vector<float> {
     {
         return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
     }
+    template <int type>
+    static opmask_t fpclass(zmm_t x)
+    {
+        return _mm512_fpclass_ps_mask(x, type);
+    }
     template <int scale>
     static ymm_t i64gather(__m512i index, void const *base)
     {
@@ -279,6 +284,10 @@ struct zmm_vector<float> {
     {
         return _mm512_mask_compressstoreu_ps(mem, mask, x);
     }
+    static zmm_t maskz_loadu(opmask_t mask, void const *mem)
+    {
+        return _mm512_maskz_loadu_ps(mask, mem);
+    }
     static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
     {
         return _mm512_mask_loadu_ps(x, mask, mem);
@@ -689,31 +698,6 @@ static void qselect_32bit_(type_t *arr,
         qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
 }
 
-X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
-{
-    int64_t nan_count = 0;
-    __mmask16 loadmask = 0xFFFF;
-    while (arrsize > 0) {
-        if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
-        __m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr);
-        __mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
-        nan_count += _mm_popcnt_u32((int32_t)nanmask);
-        _mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT);
-        arr += 16;
-        arrsize -= 16;
-    }
-    return nan_count;
-}
-
-X86_SIMD_SORT_INLINE void
-replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
-{
-    for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
-        arr[ii] = std::nanf("1");
-        nan_count -= 1;
-    }
-}
-
 template <>
 void avx512_qselect(int32_t *arr,
                     int64_t k,
@@ -752,32 +736,20 @@ void avx512_qselect(float *arr, int64_t k, int64_t arrsize, bool hasnan)
 }
 
 template <>
-void avx512_qsort(int32_t *arr, int64_t arrsize)
+void qsort_<zmm_vector<int32_t>>(int32_t* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qsort_32bit_<zmm_vector<int32_t>, int32_t>(
-                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qsort_32bit_<zmm_vector<int32_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void avx512_qsort(uint32_t *arr, int64_t arrsize)
+void qsort_<zmm_vector<uint32_t>>(uint32_t* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qsort_32bit_<zmm_vector<uint32_t>, uint32_t>(
-                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qsort_32bit_<zmm_vector<uint32_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void avx512_qsort(float *arr, int64_t arrsize)
+void qsort_<zmm_vector<float>>(float* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
-        qsort_32bit_<zmm_vector<float>, float>(
-                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        replace_inf_with_nan(arr, arrsize, nan_count);
-    }
+    qsort_32bit_<zmm_vector<float>>(arr, left, right, maxiters);
 }
-
 #endif //AVX512_QSORT_32BIT
diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h
index d12684c1..3b6d2a58 100644
--- a/src/avx512-64bit-common.h
+++ b/src/avx512-64bit-common.h
@@ -773,30 +773,7 @@ struct zmm_vector<double> {
         _mm512_storeu_pd(mem, x);
     }
 };
-X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize)
-{
-    int64_t nan_count = 0;
-    __mmask8 loadmask = 0xFF;
-    while (arrsize > 0) {
-        if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; }
-        __m512d in_zmm = _mm512_maskz_loadu_pd(loadmask, arr);
-        __mmask8 nanmask = _mm512_cmp_pd_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
-        nan_count += _mm_popcnt_u32((int32_t)nanmask);
-        _mm512_mask_storeu_pd(arr, nanmask, ZMM_MAX_DOUBLE);
-        arr += 8;
-        arrsize -= 8;
-    }
-    return nan_count;
-}
-
-X86_SIMD_SORT_INLINE void
-replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count)
-{
-    for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
-        arr[ii] = std::nan("1");
-        nan_count -= 1;
-    }
-}
 /*
  * Assumes zmm is random and performs a full sorting network defined in
  * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp
index f721f5c8..b39e68d4 100644
--- a/src/avx512-64bit-keyvaluesort.hpp
+++ b/src/avx512-64bit-keyvaluesort.hpp
@@ -463,7 +463,7 @@ template <>
 void avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize)
 {
     if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf(keys, arrsize);
+        int64_t nan_count = replace_nan_with_inf<zmm_vector<double>>(keys, arrsize);
         qsort_64bit_<zmm_vector<double>, zmm_vector<uint64_t>>(
                 keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
         replace_inf_with_nan(keys, arrsize, nan_count);
diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp
index d59a1788..f1e58bbc 100644
--- a/src/avx512-64bit-qsort.hpp
+++ b/src/avx512-64bit-qsort.hpp
@@ -824,31 +824,20 @@ void avx512_qselect(double *arr,
 }
 
 template <>
-void avx512_qsort(int64_t *arr, int64_t arrsize)
+void qsort_<zmm_vector<int64_t>>(int64_t* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qsort_64bit_<zmm_vector<int64_t>, int64_t>(
-                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qsort_64bit_<zmm_vector<int64_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void avx512_qsort(uint64_t *arr, int64_t arrsize)
+void qsort_<zmm_vector<uint64_t>>(uint64_t* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qsort_64bit_<zmm_vector<uint64_t>, uint64_t>(
-                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qsort_64bit_<zmm_vector<uint64_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void avx512_qsort(double *arr, int64_t arrsize)
+void qsort_<zmm_vector<double>>(double* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
-        qsort_64bit_<zmm_vector<double>, double>(
-                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        replace_inf_with_nan(arr, arrsize, nan_count);
-    }
+    qsort_64bit_<zmm_vector<double>>(arr, left, right, maxiters);
 }
 #endif // AVX512_QSORT_64BIT
diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h
index 841b4a83..2a68f7f2 100644
--- a/src/avx512-common-qsort.h
+++ b/src/avx512-common-qsort.h
@@ -94,35 +94,50 @@ struct zmm_vector;
 template <typename type_t>
 struct ymm_vector;
 
-// Regular quicksort routines:
 template <typename T>
-void avx512_qsort(T *arr, int64_t arrsize);
-void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
-
-template <typename T>
-void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
-void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
-
-template <typename T>
-inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
+bool is_a_nan(T elem)
 {
-    avx512_qselect(arr, k - 1, arrsize, hasnan);
-    avx512_qsort(arr, k - 1);
+    return std::isnan(elem);
 }
-inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
+
+template <typename vtype, typename type_t>
+int64_t replace_nan_with_inf(type_t *arr, int64_t arrsize)
 {
-    avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
-    avx512_qsort_fp16(arr, k - 1);
+    int64_t nan_count = 0;
+    using opmask_t = typename vtype::opmask_t;
+    using zmm_t = typename vtype::zmm_t;
+    bool found_nan = false;
+    opmask_t loadmask = 0xFF;
+    zmm_t in;
+    while (arrsize > 0) {
+        if (arrsize < vtype::numlanes) {
+            loadmask = (0x01 << arrsize) - 0x01;
+            in = vtype::maskz_loadu(loadmask, arr);
+        }
+        else {
+            in = vtype::loadu(arr);
+        }
+        opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in);
+        nan_count += _mm_popcnt_u32((int32_t)nanmask);
+        vtype::mask_storeu(arr, nanmask, vtype::zmm_max());
+        arr += vtype::numlanes;
+        arrsize -= vtype::numlanes;
+    }
+    return nan_count;
 }
 
-// key-value sort routines
-template <typename T>
-void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
-
-template <typename T>
-bool is_a_nan(T elem)
+template<typename type_t>
+void replace_inf_with_nan(type_t *arr, int64_t arrsize, int64_t nan_count)
 {
-    return std::isnan(elem);
+    for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
+        if constexpr (std::is_floating_point_v<type_t>) {
+            arr[ii] = std::numeric_limits<type_t>::quiet_NaN();
+        }
+        else {
+            arr[ii] = 0xFFFF;
+        }
+        nan_count -= 1;
+    }
 }
 
 /*
@@ -628,4 +643,48 @@ static inline int64_t partition_avx512(type_t1 *keys,
     *biggest = vtype1::reducemax(max_vec);
     return l_store;
 }
+
+template <typename vtype, typename type_t>
+void qsort_(type_t* arr, int64_t left, int64_t right, int64_t maxiters);
+
+// Regular quicksort routines:
+template <typename T>
+void avx512_qsort(T *arr, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        if constexpr (std::is_floating_point_v<T>) {
+            int64_t nan_count = replace_nan_with_inf<zmm_vector<T>>(arr, arrsize);
+            qsort_<zmm_vector<T>, T>(
+                    arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+            replace_inf_with_nan(arr, arrsize, nan_count);
+        }
+        else {
+            qsort_<zmm_vector<T>, T>(
+                    arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        }
+    }
+}
+
+void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
+
+template <typename T>
+void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
+void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
+
+template <typename T>
+inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
+{
+    avx512_qselect(arr, k - 1, arrsize, hasnan);
+    avx512_qsort(arr, k - 1);
+}
+inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
+{
+    avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
+    avx512_qsort_fp16(arr, k - 1);
+}
+
+// key-value sort routines
+template <typename T>
+void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
+
 #endif // AVX512_QSORT_COMMON
diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp
index 5bb4c6c0..1b59b272 100644
--- a/src/avx512fp16-16bit-qsort.hpp
+++ b/src/avx512fp16-16bit-qsort.hpp
@@ -114,36 +114,6 @@ struct zmm_vector<_Float16> {
     }
 };
 
-X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr,
-                                                  int64_t arrsize)
-{
-    int64_t nan_count = 0;
-    __mmask32 loadmask = 0xFFFFFFFF;
-    __m512h in_zmm;
-    while (arrsize > 0) {
-        if (arrsize < 32) {
-            loadmask = (0x00000001 << arrsize) - 0x00000001;
-            in_zmm = _mm512_castsi512_ph(
-                    _mm512_maskz_loadu_epi16(loadmask, arr));
-        }
-        else {
-            in_zmm = _mm512_loadu_ph(arr);
-        }
-        __mmask32 nanmask = _mm512_cmp_ph_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
-        nan_count += _mm_popcnt_u32((int32_t)nanmask);
-        _mm512_mask_storeu_epi16(arr, nanmask, ZMM_MAX_HALF);
-        arr += 32;
-        arrsize -= 32;
-    }
-    return nan_count;
-}
-
-X86_SIMD_SORT_INLINE void
-replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
-{
-    memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
-}
-
 template <>
 bool is_a_nan<_Float16>(_Float16 elem)
 {
@@ -166,13 +136,8 @@ void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
 }
 
 template <>
-void avx512_qsort(_Float16 *arr, int64_t arrsize)
+void qsort_<zmm_vector<_Float16>>(_Float16* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
-        qsort_16bit_<zmm_vector<_Float16>, _Float16>(
-                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        replace_inf_with_nan(arr, arrsize, nan_count);
-    }
+    qsort_16bit_<zmm_vector<_Float16>>(arr, left, right, maxiters);
 }
 #endif // AVX512FP16_QSORT_16BIT
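After this patch every element type funnels through the single generic avx512_qsort template added to avx512-common-qsort.h, which handles NaN substitution for floating-point types and dispatches to a per-width qsort_ specialization. A minimal caller sketch (hypothetical data; assumes the public headers are on the include path):

    #include "src/avx512-64bit-qsort.hpp"
    #include <cmath>
    #include <vector>

    int main()
    {
        std::vector<double> v = {3.5, NAN, 1.25, 2.0}; // hypothetical input
        // Generic entry point: replaces NaNs with +inf, sorts, then
        // writes NaNs back at the tail of the array.
        avx512_qsort(v.data(), (int64_t)v.size());
        return 0;
    }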
From 51fe74381da46f2dd13d677fc94ba7266ee95ecc Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Fri, 4 Aug 2023 14:31:24 -0700
Subject: [PATCH 2/8] Update meson to default to C++17

---
 meson.build | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/meson.build b/meson.build
index ca28c0a1..a10598f9 100644
--- a/meson.build
+++ b/meson.build
@@ -1,6 +1,7 @@
 project('x86-simd-sort', 'cpp',
-        version : '1.0.0',
-        license : 'BSD 3-clause')
+        version : '2.0.0',
+        license : 'BSD 3-clause',
+        default_options : ['cpp_std=c++17'])
 cpp = meson.get_compiler('cpp')
 src = include_directories('src')
 bench = include_directories('benchmarks')
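The C++17 floor matters because the previous patch relies on if constexpr and std::is_floating_point_v inside avx512_qsort, which do not compile under C++14. A standalone illustration of the same dispatch idiom (names here are illustrative, not library code):

    #include <cstdint>
    #include <type_traits>

    template <typename T>
    void sort_dispatch(T *arr, int64_t size)
    {
        if constexpr (std::is_floating_point_v<T>) {
            // floating-point path: NaN pre/post-processing (C++17 feature)
        }
        else {
            // integer path: no NaN handling needed
        }
    }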
From 92a628abd4bf9aa0b0634e7998e9b409fe15de7a Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Fri, 4 Aug 2023 10:10:03 -0700
Subject: [PATCH 3/8] Move argsort and argselect template function definitions to header

---
 src/avx512-64bit-argsort.hpp | 17 -----------------
 src/avx512-common-argsort.h  | 19 ++++++++++++++++---
 2 files changed, 16 insertions(+), 20 deletions(-)

diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp
index 3626ab63..60afdd13 100644
--- a/src/avx512-64bit-argsort.hpp
+++ b/src/avx512-64bit-argsort.hpp
@@ -427,15 +427,6 @@ void avx512_argsort(float *arr, int64_t *arg, int64_t arrsize)
     }
 }
 
-template <typename T>
-std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
-{
-    std::vector<int64_t> indices(arrsize);
-    std::iota(indices.begin(), indices.end(), 0);
-    avx512_argsort(arr, indices.data(), arrsize);
-    return indices;
-}
-
 /* argselect methods for 32-bit and 64-bit dtypes */
 template <typename T>
 void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize)
@@ -492,13 +483,5 @@ void avx512_argselect(float *arr, int64_t *arg, int64_t k, int64_t arrsize)
     }
 }
 
-template <typename T>
-std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
-{
-    std::vector<int64_t> indices(arrsize);
-    std::iota(indices.begin(), indices.end(), 0);
-    avx512_argselect(arr, indices.data(), k, arrsize);
-    return indices;
-}
 
 #endif // AVX512_ARGSORT_64BIT
diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h
index 0ae50c49..628925a2 100644
--- a/src/avx512-common-argsort.h
+++ b/src/avx512-common-argsort.h
@@ -19,13 +19,26 @@ template <typename T>
 void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize);
 
 template <typename T>
-std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize);
+void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize);
 
 template <typename T>
-void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize);
+std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
+{
+    std::vector<int64_t> indices(arrsize);
+    std::iota(indices.begin(), indices.end(), 0);
+    avx512_argsort(arr, indices.data(), arrsize);
+    return indices;
+}
 
 template <typename T>
-std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize);
+std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
+{
+    std::vector<int64_t> indices(arrsize);
+    std::iota(indices.begin(), indices.end(), 0);
+    avx512_argselect(arr, indices.data(), k, arrsize);
+    return indices;
+}
+
 /*
  * Parition one ZMM register based on the pivot and returns the index of the
  * last element that is less than equal to the pivot.
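With the std::vector-returning overloads now defined in the header, any translation unit can call them directly. A small usage sketch (assumed data):

    #include "src/avx512-64bit-argsort.hpp"
    #include <vector>

    std::vector<float> data = {2.0f, 0.5f, 1.0f}; // hypothetical input
    // Returns the permutation that sorts `data`; the indices start from
    // std::iota, so order[0] is the position of the smallest element.
    std::vector<int64_t> order = avx512_argsort(data.data(), (int64_t)data.size());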
From 3ddc914c9cc0665c26c3a7b4c4be4bfb811db52e Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Mon, 7 Aug 2023 14:09:44 -0700
Subject: [PATCH 4/8] Fix bug in avx512fp16 nan processing

---
 src/avx512-32bit-qsort.hpp     |  4 ++++
 src/avx512-64bit-argsort.hpp   | 27 ------------------------
 src/avx512-64bit-common.h      |  8 ++++++++
 src/avx512-common-qsort.h      | 36 +++++++++++++++++++++++++++++-----
 src/avx512fp16-16bit-qsort.hpp | 32 +++++++++++++++++++++++++++++-
 5 files changed, 74 insertions(+), 33 deletions(-)

diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp
index 5a1571cb..cb24b547 100644
--- a/src/avx512-32bit-qsort.hpp
+++ b/src/avx512-32bit-qsort.hpp
@@ -256,6 +256,10 @@ struct zmm_vector<float> {
     {
         return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
     }
+    static opmask_t get_partial_loadmask(int size)
+    {
+        return (0x0001 << size) - 0x0001;
+    }
     template <int type>
     static opmask_t fpclass(zmm_t x)
     {
diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp
index 60afdd13..8d00449c 100644
--- a/src/avx512-64bit-argsort.hpp
+++ b/src/avx512-64bit-argsort.hpp
@@ -344,33 +344,6 @@ static void argselect_64bit_(type_t *arr,
             arr, arg, pos, pivot_index, right, max_iters - 1);
 }
 
-template <typename vtype, typename type_t>
-bool has_nan(type_t *arr, int64_t arrsize)
-{
-    using opmask_t = typename vtype::opmask_t;
-    using zmm_t = typename vtype::zmm_t;
-    bool found_nan = false;
-    opmask_t loadmask = 0xFF;
-    zmm_t in;
-    while (arrsize > 0) {
-        if (arrsize < vtype::numlanes) {
-            loadmask = (0x01 << arrsize) - 0x01;
-            in = vtype::maskz_loadu(loadmask, arr);
-        }
-        else {
-            in = vtype::loadu(arr);
-        }
-        opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in);
-        arr += vtype::numlanes;
-        arrsize -= vtype::numlanes;
-        if (nanmask != 0x00) {
-            found_nan = true;
-            break;
-        }
-    }
-    return found_nan;
-}
-
 /* argsort methods for 32-bit and 64-bit dtypes */
 template <typename T>
 void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize)
diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h
index 3b6d2a58..edaf6b84 100644
--- a/src/avx512-64bit-common.h
+++ b/src/avx512-64bit-common.h
@@ -71,6 +71,10 @@ struct ymm_vector<float> {
     {
         return _mm256_cmp_ps_mask(x, y, _CMP_EQ_OQ);
     }
+    static opmask_t get_partial_loadmask(int size)
+    {
+        return (0x01 << size) - 0x01;
+    }
     template <int type>
     static opmask_t fpclass(zmm_t x)
     {
@@ -703,6 +707,10 @@ struct zmm_vector<double> {
     {
         return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ);
     }
+    static opmask_t get_partial_loadmask(int size)
+    {
+        return (0x01 << size) - 0x01;
+    }
     template <int type>
     static opmask_t fpclass(zmm_t x)
     {
diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h
index 2a68f7f2..f836a41a 100644
--- a/src/avx512-common-qsort.h
+++ b/src/avx512-common-qsort.h
@@ -100,18 +100,17 @@ bool is_a_nan(T elem)
     return std::isnan(elem);
 }
 
-template <typename vtype, typename type_t>
-int64_t replace_nan_with_inf(type_t *arr, int64_t arrsize)
+template <typename vtype, typename T>
+int64_t replace_nan_with_inf(T *arr, int64_t arrsize)
 {
     int64_t nan_count = 0;
     using opmask_t = typename vtype::opmask_t;
     using zmm_t = typename vtype::zmm_t;
-    bool found_nan = false;
-    opmask_t loadmask = 0xFF;
+    opmask_t loadmask;
     zmm_t in;
     while (arrsize > 0) {
         if (arrsize < vtype::numlanes) {
-            loadmask = (0x01 << arrsize) - 0x01;
+            loadmask = vtype::get_partial_loadmask(arrsize);
             in = vtype::maskz_loadu(loadmask, arr);
         }
         else {
@@ -126,6 +125,33 @@ int64_t replace_nan_with_inf(type_t *arr, int64_t arrsize)
     return nan_count;
 }
 
+template <typename vtype, typename type_t>
+bool has_nan(type_t *arr, int64_t arrsize)
+{
+    using opmask_t = typename vtype::opmask_t;
+    using zmm_t = typename vtype::zmm_t;
+    bool found_nan = false;
+    opmask_t loadmask;
+    zmm_t in;
+    while (arrsize > 0) {
+        if (arrsize < vtype::numlanes) {
+            loadmask = vtype::get_partial_loadmask(arrsize);
+            in = vtype::maskz_loadu(loadmask, arr);
+        }
+        else {
+            in = vtype::loadu(arr);
+        }
+        opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in);
+        arr += vtype::numlanes;
+        arrsize -= vtype::numlanes;
+        if (nanmask != 0x00) {
+            found_nan = true;
+            break;
+        }
+    }
+    return found_nan;
+}
+
 template<typename type_t>
 void replace_inf_with_nan(type_t *arr, int64_t arrsize, int64_t nan_count)
 {
diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp
index 1b59b272..53f617d7 100644
--- a/src/avx512fp16-16bit-qsort.hpp
+++ b/src/avx512fp16-16bit-qsort.hpp
@@ -46,11 +46,19 @@ struct zmm_vector<_Float16> {
     {
         return _knot_mask32(x);
     }
-
     static opmask_t ge(zmm_t x, zmm_t y)
     {
         return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ);
     }
+    static opmask_t get_partial_loadmask(int size)
+    {
+        return (0x00000001 << size) - 0x00000001;
+    }
+    template <int type>
+    static opmask_t fpclass(zmm_t x)
+    {
+        return _mm512_fpclass_ph_mask(x, type);
+    }
     static zmm_t loadu(void const *mem)
     {
         return _mm512_loadu_ph(mem);
@@ -65,6 +73,11 @@ struct zmm_vector<_Float16> {
         // AVX512_VBMI2
         return _mm512_mask_compressstoreu_epi16(mem, mask, temp);
     }
+    static zmm_t maskz_loadu(opmask_t mask, void const *mem)
+    {
+        return _mm512_castsi512_ph(
+                _mm512_maskz_loadu_epi16(mask, mem));
+    }
     static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
     {
         // AVX512BW
@@ -140,4 +153,21 @@ void qsort_<zmm_vector<_Float16>>(_Float16* arr, int64_t left, int64_t right, in
 {
     qsort_16bit_<zmm_vector<_Float16>>(arr, left, right, maxiters);
 }
+
+template<>
+void replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
+{
+    memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
+}
+
+template<>
+void avx512_qsort(_Float16 *arr, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        int64_t nan_count = replace_nan_with_inf<zmm_vector<_Float16>, _Float16>(arr, arrsize);
+        qsort_16bit_<zmm_vector<_Float16>, _Float16>(
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        replace_inf_with_nan(arr, arrsize, nan_count);
+    }
+}
 #endif // AVX512FP16_QSORT_16BIT
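The fix routes partial loads through a per-vector get_partial_loadmask instead of a hard-coded 8-lane (0x01 << n) - 0x01, which was wrong for the 16- and 32-lane vectors; fpclass<0x01 | 0x80> then flags the quiet-NaN (0x01) and signaling-NaN (0x80) categories of the _mm512_fpclass_*_mask encoding. A scalar model of the mask math for a 32-lane _Float16 vector (the helper name is illustrative, not library code):

    #include <immintrin.h>

    // Low `tail` bits set: a masked load then touches exactly the
    // `tail` leftover elements and leaves the other lanes zeroed.
    __mmask32 partial_mask_fp16(int tail)
    {
        return (0x00000001 << tail) - 0x00000001;
    }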
From 472c7d00fd44b802cb366a1fd601ff999eeceab3 Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Tue, 8 Aug 2023 09:59:33 -0700
Subject: [PATCH 5/8] minimize template specialization for avx512_qselect

---
 src/avx512-16bit-qsort.hpp     | 53 +++++++++++++++-------------------
 src/avx512-32bit-qsort.hpp     | 33 +++++-----------------
 src/avx512-64bit-qsort.hpp     | 36 +++++-------------------
 src/avx512-common-qsort.h      | 20 ++++++++++++-
 src/avx512fp16-16bit-qsort.hpp | 25 +++++++---------
 5 files changed, 69 insertions(+), 98 deletions(-)

diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp
index b5252932..d7121f2d 100644
--- a/src/avx512-16bit-qsort.hpp
+++ b/src/avx512-16bit-qsort.hpp
@@ -403,56 +403,51 @@ bool is_a_nan<uint16_t>(uint16_t elem)
     return (elem & 0x7c00) == 0x7c00;
 }
 
+/* Specialized template function for 16-bit qsort_ funcs*/
 template <>
-void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
+void qsort_<zmm_vector<int16_t>>(int16_t* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qselect_16bit_<zmm_vector<int16_t>, int16_t>(
-                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qsort_16bit_<zmm_vector<int16_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
+void qsort_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qselect_16bit_<zmm_vector<uint16_t>, uint16_t>(
-                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qsort_16bit_<zmm_vector<uint16_t>>(arr, left, right, maxiters);
 }
 
-void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
+void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
 {
-    int64_t indx_last_elem = arrsize - 1;
-    if (UNLIKELY(hasnan)) {
-        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
-    }
-    if (indx_last_elem >= k) {
-        qselect_16bit_<zmm_vector<float16>, uint16_t>(
-                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
+    if (arrsize > 1) {
+        int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(arr, arrsize);
+        qsort_16bit_<zmm_vector<float16>, uint16_t>(
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        replace_inf_with_nan(arr, arrsize, nan_count);
     }
 }
 
+/* Specialized template function for 16-bit qselect_ funcs*/
 template <>
-void qsort_<zmm_vector<int16_t>>(int16_t* arr, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<int16_t>>(int16_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
-    qsort_16bit_<zmm_vector<int16_t>>(arr, left, right, maxiters);
+    qselect_16bit_<zmm_vector<int16_t>>(arr, k, left, right, maxiters);
 }
 
 template <>
-void qsort_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
-    qsort_16bit_<zmm_vector<uint16_t>>(arr, left, right, maxiters);
+    qselect_16bit_<zmm_vector<uint16_t>>(arr, k, left, right, maxiters);
 }
 
-void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
+void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
-    if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(arr, arrsize);
-        qsort_16bit_<zmm_vector<float16>, uint16_t>(
-                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        replace_inf_with_nan(arr, arrsize, nan_count);
+    int64_t indx_last_elem = arrsize - 1;
+    if (UNLIKELY(hasnan)) {
+        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
+    }
+    if (indx_last_elem >= k) {
+        qselect_16bit_<zmm_vector<float16>, uint16_t>(
+                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
     }
 }
-
 #endif // AVX512_QSORT_16BIT
diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp
index cb24b547..79c49f32 100644
--- a/src/avx512-32bit-qsort.hpp
+++ b/src/avx512-32bit-qsort.hpp
@@ -702,43 +702,26 @@ static void qselect_32bit_(type_t *arr,
         qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
 }
 
+/* Specialized template function for 32-bit qselect_ funcs*/
 template <>
-void avx512_qselect(int32_t *arr,
-                    int64_t k,
-                    int64_t arrsize,
-                    bool hasnan)
+void qselect_<zmm_vector<int32_t>>(int32_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qselect_32bit_<zmm_vector<int32_t>, int32_t>(
-                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qselect_32bit_<zmm_vector<int32_t>>(arr, k, left, right, maxiters);
 }
 
 template <>
-void avx512_qselect(uint32_t *arr,
-                    int64_t k,
-                    int64_t arrsize,
-                    bool hasnan)
+void qselect_<zmm_vector<uint32_t>>(uint32_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qselect_32bit_<zmm_vector<uint32_t>, uint32_t>(
-                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qselect_32bit_<zmm_vector<uint32_t>>(arr, k, left, right, maxiters);
 }
 
 template <>
-void avx512_qselect(float *arr, int64_t k, int64_t arrsize, bool hasnan)
+void qselect_<zmm_vector<float>>(float* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
-    int64_t indx_last_elem = arrsize - 1;
-    if (UNLIKELY(hasnan)) {
-        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
-    }
-    if (indx_last_elem >= k) {
-        qselect_32bit_<zmm_vector<float>, float>(
-                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
-    }
+    qselect_32bit_<zmm_vector<float>>(arr, k, left, right, maxiters);
 }
 
+/* Specialized template function for 32-bit qsort_ funcs*/
 template <>
 void qsort_<zmm_vector<int32_t>>(int32_t* arr, int64_t left, int64_t right, int64_t maxiters)
 {
diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp
index f1e58bbc..d9e6aff9 100644
--- a/src/avx512-64bit-qsort.hpp
+++ b/src/avx512-64bit-qsort.hpp
@@ -783,46 +783,26 @@ static void qselect_64bit_(type_t *arr,
         qselect_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
 }
 
+/* Specialized template function for 64-bit qselect_ funcs*/
 template <>
-void avx512_qselect(int64_t *arr,
-                    int64_t k,
-                    int64_t arrsize,
-                    bool hasnan)
+void qselect_<zmm_vector<int64_t>>(int64_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qselect_64bit_<zmm_vector<int64_t>, int64_t>(
-                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qselect_64bit_<zmm_vector<int64_t>>(arr, k, left, right, maxiters);
 }
 
 template <>
-void avx512_qselect(uint64_t *arr,
-                    int64_t k,
-                    int64_t arrsize,
-                    bool hasnan)
+void qselect_<zmm_vector<uint64_t>>(uint64_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
-    if (arrsize > 1) {
-        qselect_64bit_<zmm_vector<uint64_t>, uint64_t>(
-                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
+    qselect_64bit_<zmm_vector<uint64_t>>(arr, k, left, right, maxiters);
 }
 
 template <>
-void avx512_qselect(double *arr,
-                    int64_t k,
-                    int64_t arrsize,
-                    bool hasnan)
+void qselect_<zmm_vector<double>>(double* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
-    int64_t indx_last_elem = arrsize - 1;
-    if (UNLIKELY(hasnan)) {
-        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
-    }
-    if (indx_last_elem >= k) {
-        qselect_64bit_<zmm_vector<double>, double>(
-                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
-    }
+    qselect_64bit_<zmm_vector<double>>(arr, k, left, right, maxiters);
 }
 
+/* Specialized template function for 64-bit qsort_ funcs*/
 template <>
 void qsort_<zmm_vector<int64_t>>(int64_t* arr, int64_t left, int64_t right, int64_t maxiters)
 {
diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h
index f836a41a..be3b9f0a 100644
--- a/src/avx512-common-qsort.h
+++ b/src/avx512-common-qsort.h
@@ -673,11 +673,15 @@ static inline int64_t partition_avx512(type_t1 *keys,
 template <typename vtype, typename type_t>
 void qsort_(type_t* arr, int64_t left, int64_t right, int64_t maxiters);
 
+template <typename vtype, typename type_t>
+void qselect_(type_t* arr, int64_t pos, int64_t left, int64_t right, int64_t maxiters);
+
 // Regular quicksort routines:
 template <typename T>
 void avx512_qsort(T *arr, int64_t arrsize)
 {
     if (arrsize > 1) {
+        /* std::is_floating_point_v<_Float16> == False, unless c++-23*/
         if constexpr (std::is_floating_point_v<T>) {
             int64_t nan_count = replace_nan_with_inf<zmm_vector<T>>(arr, arrsize);
             qsort_<zmm_vector<T>, T>(
                     arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
             replace_inf_with_nan(arr, arrsize, nan_count);
         }
@@ -694,7 +698,21 @@ void avx512_qsort(T *arr, int64_t arrsize)
 void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
 
 template <typename T>
-void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
+void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
+{
+    int64_t indx_last_elem = arrsize - 1;
+    /* std::is_floating_point_v<_Float16> == False, unless c++-23*/
+    if constexpr (std::is_floating_point_v<T>) {
+        if (UNLIKELY(hasnan)) {
+            indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
+        }
+    }
+    if (indx_last_elem >= k) {
+        qselect_<zmm_vector<T>, T>(
+                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
+    }
+}
+
 void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
 
 template <typename T>
diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp
index 53f617d7..1876be4f 100644
--- a/src/avx512fp16-16bit-qsort.hpp
+++ b/src/avx512fp16-16bit-qsort.hpp
@@ -135,31 +135,26 @@ bool is_a_nan<_Float16>(_Float16 elem)
     return (temp.i_ & 0x7c00) == 0x7c00;
 }
 
-template <>
-void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
+template<>
+void replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
 {
-    int64_t indx_last_elem = arrsize - 1;
-    if (UNLIKELY(hasnan)) {
-        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
-    }
-    if (indx_last_elem >= k) {
-        qselect_16bit_<zmm_vector<_Float16>, _Float16>(
-                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
-    }
+    memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
 }
 
 template <>
-void qsort_<zmm_vector<_Float16>>(_Float16* arr, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<_Float16>>(_Float16* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
-    qsort_16bit_<zmm_vector<_Float16>>(arr, left, right, maxiters);
+    qselect_16bit_<zmm_vector<_Float16>>(arr, k, left, right, maxiters);
 }
 
-template<>
-void replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
+
+template <>
+void qsort_<zmm_vector<_Float16>>(_Float16* arr, int64_t left, int64_t right, int64_t maxiters)
 {
-    memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
+    qsort_16bit_<zmm_vector<_Float16>>(arr, left, right, maxiters);
 }
 
+/* Specialized template function for _Float16 qsort_*/
 template<>
 void avx512_qsort(_Float16 *arr, int64_t arrsize)
 {
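avx512_qselect now mirrors avx512_qsort: one generic body in the common header that moves NaNs to the back (only when the caller passes hasnan=true) plus a thin per-type qselect_ specialization. A call-site sketch under assumed inputs:

    // Assumes: float *arr with n elements and 0 <= k < n.
    void select_smallest(float *arr, int64_t k, int64_t n)
    {
        avx512_qselect(arr, k, n, /*hasnan=*/true); // arr[k] ends in its sorted spot
        // arr[0..k) now hold (unordered) values <= arr[k];
        // avx512_partial_qsort(arr, k, n) would sort that prefix as well.
    }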
From 34c2798067832789a8d3ddbb7c7725c40dcd7be5 Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Tue, 8 Aug 2023 10:14:34 -0700
Subject: [PATCH 6/8] generalize key-value sort for all 64-bit combinations

---
 src/avx512-64bit-keyvaluesort.hpp | 38 ++++++++++---------------------
 src/avx512-common-qsort.h         |  5 ----
 tests/test-keyvalue.cpp           |  2 +-
 3 files changed, 13 insertions(+), 32 deletions(-)

diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp
index b39e68d4..31b6aaba 100644
--- a/src/avx512-64bit-keyvaluesort.hpp
+++ b/src/avx512-64bit-keyvaluesort.hpp
@@ -439,34 +439,20 @@ void qsort_64bit_(type1_t *keys,
     }
 }
 
-template <>
-void avx512_qsort_kv(int64_t *keys, uint64_t *indexes, int64_t arrsize)
+template <typename T1, typename T2>
+void avx512_qsort_kv(T1 *keys, T2 *indexes, int64_t arrsize)
 {
     if (arrsize > 1) {
-        qsort_64bit_<zmm_vector<int64_t>, zmm_vector<uint64_t>>(
-                keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
-}
-
-template <>
-void avx512_qsort_kv(uint64_t *keys,
-                     uint64_t *indexes,
-                     int64_t arrsize)
-{
-    if (arrsize > 1) {
-        qsort_64bit_<zmm_vector<uint64_t>, zmm_vector<uint64_t>>(
-                keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
-}
-
-template <>
-void avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize)
-{
-    if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf<zmm_vector<double>>(keys, arrsize);
-        qsort_64bit_<zmm_vector<double>, zmm_vector<uint64_t>>(
-                keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        replace_inf_with_nan(keys, arrsize, nan_count);
+        if constexpr (std::is_floating_point_v<T1>) {
+            int64_t nan_count = replace_nan_with_inf<zmm_vector<T1>>(keys, arrsize);
+            qsort_64bit_<zmm_vector<T1>, zmm_vector<T2>>(
+                    keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+            replace_inf_with_nan(keys, arrsize, nan_count);
+        }
+        else {
+            qsort_64bit_<zmm_vector<T1>, zmm_vector<T2>>(
+                    keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        }
     }
 }
 #endif // AVX512_QSORT_64BIT_KV
diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h
index be3b9f0a..e98ba1cf 100644
--- a/src/avx512-common-qsort.h
+++ b/src/avx512-common-qsort.h
@@ -726,9 +726,4 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize,
     avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
     avx512_qsort_fp16(arr, k - 1);
 }
-
-// key-value sort routines
-template <typename T>
-void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
-
 #endif // AVX512_QSORT_COMMON
diff --git a/tests/test-keyvalue.cpp b/tests/test-keyvalue.cpp
index a05e9528..6e75f344 100644
--- a/tests/test-keyvalue.cpp
+++ b/tests/test-keyvalue.cpp
@@ -54,7 +54,7 @@ TYPED_TEST_P(KeyValueSort, test_64bit_random_data)
 
     std::sort(sortedarr.begin(), sortedarr.end(), compare<TypeParam, uint64_t>);
 
-    avx512_qsort_kv(keys.data(), values.data(), keys.size());
+    avx512_qsort_kv<TypeParam, uint64_t>(keys.data(), values.data(), keys.size());
     for (size_t i = 0; i < keys.size(); i++) {
         ASSERT_EQ(keys[i], sortedarr[i].key);
         ASSERT_EQ(values[i], sortedarr[i].value);
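With the template generalized, any 64-bit key type pairs with any 64-bit index type, and floating-point keys keep their NaN round-trip. A usage sketch (assumed vectors):

    #include "src/avx512-64bit-keyvaluesort.hpp"
    #include <vector>

    std::vector<double> keys = {2.0, 1.0, 3.0}; // hypothetical keys
    std::vector<uint64_t> vals = {10, 20, 30};  // payload carried alongside
    // Sorts keys ascending and applies the same permutation to vals;
    // T1/T2 are deduced from the pointer arguments.
    avx512_qsort_kv(keys.data(), vals.data(), (int64_t)keys.size());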
From bdd0af6f3aea10bc8b543a16056441a45a20f4f9 Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Tue, 8 Aug 2023 10:49:54 -0700
Subject: [PATCH 7/8] Remove template specializations for arg methods

---
 src/avx512-64bit-argsort.hpp | 116 +++++++++--------------------------
 src/avx512-common-argsort.h  |  24 --------
 2 files changed, 30 insertions(+), 110 deletions(-)

diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp
index 8d00449c..c6499a6c 100644
--- a/src/avx512-64bit-argsort.hpp
+++ b/src/avx512-64bit-argsort.hpp
@@ -348,113 +348,57 @@ static void argselect_64bit_(type_t *arr,
 /* argsort methods for 32-bit and 64-bit dtypes */
 template <typename T>
 void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize)
 {
+    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t), ymm_vector<T>,
+                                              zmm_vector<T>>::type;
     if (arrsize > 1) {
-        argsort_64bit_<zmm_vector<T>>(
-                arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
-}
-
-template <>
-void avx512_argsort(double *arr, int64_t *arg, int64_t arrsize)
-{
-    if (arrsize > 1) {
-        if (has_nan<zmm_vector<double>>(arr, arrsize)) {
-            std_argsort_withnan(arr, arg, 0, arrsize);
+        if constexpr (std::is_floating_point_v<T>) {
+            if (has_nan<vectype>(arr, arrsize)) {
+                std_argsort_withnan(arr, arg, 0, arrsize);
+                return;
+            }
         }
-        else {
-            argsort_64bit_<zmm_vector<double>>(
-                    arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        }
-    }
-}
-
-template <>
-void avx512_argsort(int32_t *arr, int64_t *arg, int64_t arrsize)
-{
-    if (arrsize > 1) {
-        argsort_64bit_<ymm_vector<int32_t>>(
+        argsort_64bit_<vectype>(
                 arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
     }
 }
 
-template <>
-void avx512_argsort(uint32_t *arr, int64_t *arg, int64_t arrsize)
-{
-    if (arrsize > 1) {
-        argsort_64bit_<ymm_vector<uint32_t>>(
-                arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
-}
-
-template <>
-void avx512_argsort(float *arr, int64_t *arg, int64_t arrsize)
+template <typename T>
+std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
 {
-    if (arrsize > 1) {
-        if (has_nan<ymm_vector<float>>(arr, arrsize)) {
-            std_argsort_withnan(arr, arg, 0, arrsize);
-        }
-        else {
-            argsort_64bit_<ymm_vector<float>>(
-                    arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        }
-    }
+    std::vector<int64_t> indices(arrsize);
+    std::iota(indices.begin(), indices.end(), 0);
+    avx512_argsort(arr, indices.data(), arrsize);
+    return indices;
 }
 
 /* argselect methods for 32-bit and 64-bit dtypes */
 template <typename T>
 void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize)
 {
-    if (arrsize > 1) {
-        argselect_64bit_<zmm_vector<T>>(
-                arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
-}
+    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t), ymm_vector<T>,
+                                              zmm_vector<T>>::type;
 
-template <>
-void avx512_argselect(double *arr, int64_t *arg, int64_t k, int64_t arrsize)
-{
     if (arrsize > 1) {
-        if (has_nan<zmm_vector<double>>(arr, arrsize)) {
-            std_argselect_withnan(arr, arg, k, 0, arrsize);
-        }
-        else {
-            argselect_64bit_<zmm_vector<double>>(
-                    arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        if constexpr (std::is_floating_point_v<T>) {
+            if (has_nan<vectype>(arr, arrsize)) {
+                std_argselect_withnan(arr, arg, k, 0, arrsize);
+                return;
+            }
         }
-    }
-}
-
-template <>
-void avx512_argselect(int32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
-{
-    if (arrsize > 1) {
-        argselect_64bit_<ymm_vector<int32_t>>(
+        argselect_64bit_<vectype>(
                 arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
     }
 }
 
-template <>
-void avx512_argselect(uint32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
-{
-    if (arrsize > 1) {
-        argselect_64bit_<ymm_vector<uint32_t>>(
-                arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-    }
-}
-
-template <>
-void avx512_argselect(float *arr, int64_t *arg, int64_t k, int64_t arrsize)
+template <typename T>
+std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
 {
-    if (arrsize > 1) {
-        if (has_nan<ymm_vector<float>>(arr, arrsize)) {
-            std_argselect_withnan(arr, arg, k, 0, arrsize);
-        }
-        else {
-            argselect_64bit_<ymm_vector<float>>(
-                    arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
-        }
-    }
+    std::vector<int64_t> indices(arrsize);
+    std::iota(indices.begin(), indices.end(), 0);
+    avx512_argselect(arr, indices.data(), k, arrsize);
+    return indices;
 }
-
 #endif // AVX512_ARGSORT_64BIT
diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h
index 628925a2..9bdaf464 100644
--- a/src/avx512-common-argsort.h
+++ b/src/avx512-common-argsort.h
@@ -15,30 +15,6 @@
 using argtype = zmm_vector<int64_t>;
 using argzmm_t = typename argtype::zmm_t;
 
-template <typename T>
-void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize);
-
-template <typename T>
-void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize);
-
-template <typename T>
-std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
-{
-    std::vector<int64_t> indices(arrsize);
-    std::iota(indices.begin(), indices.end(), 0);
-    avx512_argsort(arr, indices.data(), arrsize);
-    return indices;
-}
-
-template <typename T>
-std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
-{
-    std::vector<int64_t> indices(arrsize);
-    std::iota(indices.begin(), indices.end(), 0);
-    avx512_argselect(arr, indices.data(), k, arrsize);
-    return indices;
-}
-
 /*
  * Parition one ZMM register based on the pivot and returns the index of the
  * last element that is less than equal to the pivot.
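The folded argsort/argselect bodies pick their vector type with std::conditional: 32-bit elements ride in a ymm register so that eight keys line up one-to-one with the eight 64-bit indices held in a zmm register. The selection in isolation (T stands for the element type):

    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
                                              ymm_vector<T>,  // 8 x 32-bit keys
                                              zmm_vector<T>   // 8 x 64-bit keys
                                              >::type;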
From 4efed2a8b56a355cbcf12d0e316f039c0e4e95c7 Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Tue, 8 Aug 2023 10:57:48 -0700
Subject: [PATCH 8/8] Style format

---
 src/avx512-16bit-qsort.hpp        |  25 ++++--
 src/avx512-32bit-qsort.hpp        |  24 ++++--
 src/avx512-64bit-common.h         | 134 +++++++++++++-----------------
 src/avx512-64bit-keyvaluesort.hpp |   3 +-
 src/avx512-64bit-qsort.hpp        |  24 ++++--
 src/avx512-common-argsort.h       |   6 +-
 src/avx512-common-qsort.h         |  34 +++++---
 src/avx512fp16-16bit-qsort.hpp    |  20 +++--
 8 files changed, 150 insertions(+), 120 deletions(-)

diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp
index d7121f2d..2cdb45e7 100644
--- a/src/avx512-16bit-qsort.hpp
+++ b/src/avx512-16bit-qsort.hpp
@@ -377,9 +377,9 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
     //return npy_half_to_float(a) < npy_half_to_float(b);
 }
 
-template<>
-int64_t
-replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr, int64_t arrsize)
+template <>
+int64_t replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr,
+                                                  int64_t arrsize)
 {
     int64_t nan_count = 0;
     __mmask16 loadmask = 0xFFFF;
@@ -405,13 +405,19 @@ bool is_a_nan<uint16_t>(uint16_t elem)
 
 /* Specialized template function for 16-bit qsort_ funcs*/
 template <>
-void qsort_<zmm_vector<int16_t>>(int16_t* arr, int64_t left, int64_t right, int64_t maxiters)
+void qsort_<zmm_vector<int16_t>>(int16_t *arr,
+                                 int64_t left,
+                                 int64_t right,
+                                 int64_t maxiters)
 {
     qsort_16bit_<zmm_vector<int16_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void qsort_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t left, int64_t right, int64_t maxiters)
+void qsort_<zmm_vector<uint16_t>>(uint16_t *arr,
+                                  int64_t left,
+                                  int64_t right,
+                                  int64_t maxiters)
 {
     qsort_16bit_<zmm_vector<uint16_t>>(arr, left, right, maxiters);
 }
@@ -419,7 +425,8 @@ void qsort_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t left, int64_t right, in
 void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
 {
     if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(arr, arrsize);
+        int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(
+                arr, arrsize);
         qsort_16bit_<zmm_vector<float16>, uint16_t>(
                 arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
         replace_inf_with_nan(arr, arrsize, nan_count);
@@ -428,13 +435,15 @@ void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
 
 /* Specialized template function for 16-bit qselect_ funcs*/
 template <>
-void qselect_<zmm_vector<int16_t>>(int16_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<int16_t>>(
+        int16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
     qselect_16bit_<zmm_vector<int16_t>>(arr, k, left, right, maxiters);
 }
 
 template <>
-void qselect_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<uint16_t>>(
+        uint16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
     qselect_16bit_<zmm_vector<uint16_t>>(arr, k, left, right, maxiters);
 }
diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp
index 79c49f32..054e4b26 100644
--- a/src/avx512-32bit-qsort.hpp
+++ b/src/avx512-32bit-qsort.hpp
@@ -704,38 +704,50 @@ static void qselect_32bit_(type_t *arr,
 
 /* Specialized template function for 32-bit qselect_ funcs*/
 template <>
-void qselect_<zmm_vector<int32_t>>(int32_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<int32_t>>(
+        int32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
     qselect_32bit_<zmm_vector<int32_t>>(arr, k, left, right, maxiters);
 }
 
 template <>
-void qselect_<zmm_vector<uint32_t>>(uint32_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<uint32_t>>(
+        uint32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
     qselect_32bit_<zmm_vector<uint32_t>>(arr, k, left, right, maxiters);
 }
 
 template <>
-void qselect_<zmm_vector<float>>(float* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<float>>(
+        float *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
     qselect_32bit_<zmm_vector<float>>(arr, k, left, right, maxiters);
 }
 
 /* Specialized template function for 32-bit qsort_ funcs*/
 template <>
-void qsort_<zmm_vector<int32_t>>(int32_t* arr, int64_t left, int64_t right, int64_t maxiters)
+void qsort_<zmm_vector<int32_t>>(int32_t *arr,
+                                 int64_t left,
+                                 int64_t right,
+                                 int64_t maxiters)
 {
     qsort_32bit_<zmm_vector<int32_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void qsort_<zmm_vector<uint32_t>>(uint32_t* arr, int64_t left, int64_t right, int64_t maxiters)
+void qsort_<zmm_vector<uint32_t>>(uint32_t *arr,
+                                  int64_t left,
+                                  int64_t right,
+                                  int64_t maxiters)
 {
     qsort_32bit_<zmm_vector<uint32_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void qsort_<zmm_vector<float>>(float* arr, int64_t left, int64_t right, int64_t maxiters)
+void qsort_<zmm_vector<float>>(float *arr,
+                               int64_t left,
+                               int64_t right,
+                               int64_t maxiters)
 {
     qsort_32bit_<zmm_vector<float>>(arr, left, right, maxiters);
 }
diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h
index edaf6b84..fbd4a88f 100644
--- a/src/avx512-64bit-common.h
+++ b/src/avx512-64bit-common.h
@@ -40,14 +40,8 @@ struct ymm_vector<float> {
         return _mm256_set1_ps(type_max());
     }
 
-    static zmmi_t seti(int v1,
-                       int v2,
-                       int v3,
-                       int v4,
-                       int v5,
-                       int v6,
-                       int v7,
-                       int v8)
+    static zmmi_t
+    seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
         return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
     }
@@ -93,7 +87,7 @@ struct ymm_vector<float> {
     }
     static zmm_t loadu(void const *mem)
    {
-        return _mm256_loadu_ps((float*) mem);
+        return _mm256_loadu_ps((float *)mem);
     }
     static zmm_t max(zmm_t x, zmm_t y)
     {
@@ -129,16 +123,22 @@ struct ymm_vector<float> {
     }
     static type_t reducemax(zmm_t v)
     {
-        __m128 v128 = _mm_max_ps(_mm256_castps256_ps128(v), _mm256_extractf32x4_ps (v, 1));
-        __m128 v64 = _mm_max_ps(v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(1, 0, 3, 2)));
-        __m128 v32 = _mm_max_ps(v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1)));
+        __m128 v128 = _mm_max_ps(_mm256_castps256_ps128(v),
+                                 _mm256_extractf32x4_ps(v, 1));
+        __m128 v64 = _mm_max_ps(
+                v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(1, 0, 3, 2)));
+        __m128 v32 = _mm_max_ps(
+                v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1)));
         return _mm_cvtss_f32(v32);
     }
     static type_t reducemin(zmm_t v)
     {
-        __m128 v128 = _mm_min_ps(_mm256_castps256_ps128(v), _mm256_extractf32x4_ps(v, 1));
-        __m128 v64 = _mm_min_ps(v128, _mm_shuffle_ps(v128,_MM_SHUFFLE(1, 0, 3, 2)));
-        __m128 v32 = _mm_min_ps(v64, _mm_shuffle_ps(v64,_MM_SHUFFLE(0, 0, 0, 1)));
+        __m128 v128 = _mm_min_ps(_mm256_castps256_ps128(v),
+                                 _mm256_extractf32x4_ps(v, 1));
+        __m128 v64 = _mm_min_ps(
+                v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(1, 0, 3, 2)));
+        __m128 v32 = _mm_min_ps(
+                v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1)));
         return _mm_cvtss_f32(v32);
     }
     static zmm_t set1(type_t v)
@@ -160,7 +160,7 @@ struct ymm_vector<float> {
     }
     static void storeu(void *mem, zmm_t x)
     {
-        _mm256_storeu_ps((float*)mem, x);
+        _mm256_storeu_ps((float *)mem, x);
     }
 };
 template <>
@@ -184,14 +184,8 @@ struct ymm_vector<uint32_t> {
         return _mm256_set1_epi32(type_max());
     }
 
-    static zmmi_t seti(int v1,
-                       int v2,
-                       int v3,
-                       int v4,
-                       int v5,
-                       int v6,
-                       int v7,
-                       int v8)
+    static zmmi_t
+    seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
         return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
     }
@@ -228,7 +222,7 @@ struct ymm_vector<uint32_t> {
     }
     static zmm_t loadu(void const *mem)
     {
-        return _mm256_loadu_si256((__m256i*) mem);
+        return _mm256_loadu_si256((__m256i *)mem);
     }
     static zmm_t max(zmm_t x, zmm_t y)
     {
@@ -264,16 +258,22 @@ struct ymm_vector<uint32_t> {
     }
     static type_t reducemax(zmm_t v)
     {
-        __m128i v128 = _mm_max_epu32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
-        __m128i v64 = _mm_max_epu32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
-        __m128i v32 = _mm_max_epu32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
+        __m128i v128 = _mm_max_epu32(_mm256_castsi256_si128(v),
+                                     _mm256_extracti128_si256(v, 1));
+        __m128i v64 = _mm_max_epu32(
+                v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
+        __m128i v32 = _mm_max_epu32(
+                v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
         return (type_t)_mm_cvtsi128_si32(v32);
     }
     static type_t reducemin(zmm_t v)
     {
-        __m128i v128 = _mm_min_epu32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
-        __m128i v64 = _mm_min_epu32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
-        __m128i v32 = _mm_min_epu32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
+        __m128i v128 = _mm_min_epu32(_mm256_castsi256_si128(v),
+                                     _mm256_extracti128_si256(v, 1));
+        __m128i v64 = _mm_min_epu32(
+                v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
+        __m128i v32 = _mm_min_epu32(
+                v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
         return (type_t)_mm_cvtsi128_si32(v32);
     }
     static zmm_t set1(type_t v)
@@ -289,7 +289,7 @@ struct ymm_vector<uint32_t> {
     }
     static void storeu(void *mem, zmm_t x)
     {
-        _mm256_storeu_si256((__m256i*) mem, x);
+        _mm256_storeu_si256((__m256i *)mem, x);
     }
 };
 template <>
@@ -313,14 +313,8 @@ struct ymm_vector<int32_t> {
         return _mm256_set1_epi32(type_max());
     }
     // TODO: this should broadcast bits as is?
-    static zmmi_t seti(int v1,
-                       int v2,
-                       int v3,
-                       int v4,
-                       int v5,
-                       int v6,
-                       int v7,
-                       int v8)
+    static zmmi_t
+    seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
         return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
     }
@@ -357,7 +351,7 @@ struct ymm_vector<int32_t> {
     }
     static zmm_t loadu(void const *mem)
     {
-        return _mm256_loadu_si256((__m256i*) mem);
+        return _mm256_loadu_si256((__m256i *)mem);
     }
     static zmm_t max(zmm_t x, zmm_t y)
     {
@@ -393,16 +387,22 @@ struct ymm_vector<int32_t> {
     }
     static type_t reducemax(zmm_t v)
     {
-        __m128i v128 = _mm_max_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
-        __m128i v64 = _mm_max_epi32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
-        __m128i v32 = _mm_max_epi32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
+        __m128i v128 = _mm_max_epi32(_mm256_castsi256_si128(v),
+                                     _mm256_extracti128_si256(v, 1));
+        __m128i v64 = _mm_max_epi32(
+                v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
+        __m128i v32 = _mm_max_epi32(
+                v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
         return (type_t)_mm_cvtsi128_si32(v32);
     }
     static type_t reducemin(zmm_t v)
     {
-        __m128i v128 = _mm_min_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
-        __m128i v64 = _mm_min_epi32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
-        __m128i v32 = _mm_min_epi32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
+        __m128i v128 = _mm_min_epi32(_mm256_castsi256_si128(v),
+                                     _mm256_extracti128_si256(v, 1));
+        __m128i v64 = _mm_min_epi32(
+                v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
+        __m128i v32 = _mm_min_epi32(
+                v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
         return (type_t)_mm_cvtsi128_si32(v32);
     }
     static zmm_t set1(type_t v)
@@ -418,7 +418,7 @@ struct ymm_vector<int32_t> {
     }
     static void storeu(void *mem, zmm_t x)
     {
-        _mm256_storeu_si256((__m256i*) mem, x);
+        _mm256_storeu_si256((__m256i *)mem, x);
     }
 };
 template <>
@@ -443,14 +443,8 @@ struct zmm_vector<int64_t> {
        return _mm512_set1_epi64(type_max());
    }
    // TODO: this should broadcast bits as is?
-    static zmmi_t seti(int v1,
-                       int v2,
-                       int v3,
-                       int v4,
-                       int v5,
-                       int v6,
-                       int v7,
-                       int v8)
+    static zmmi_t
+    seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
         return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
     }
@@ -567,14 +561,8 @@ struct zmm_vector<uint64_t> {
         return _mm512_set1_epi64(type_max());
     }
 
-    static zmmi_t seti(int v1,
-                       int v2,
-                       int v3,
-                       int v4,
-                       int v5,
-                       int v6,
-                       int v7,
-                       int v8)
+    static zmmi_t
+    seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
         return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
     }
@@ -679,14 +667,8 @@ struct zmm_vector<double> {
         return _mm512_set1_pd(type_max());
     }
 
-    static zmmi_t seti(int v1,
-                       int v2,
-                       int v3,
-                       int v4,
-                       int v5,
-                       int v6,
-                       int v7,
-                       int v8)
+    static zmmi_t
+    seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
         return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
     }
@@ -793,16 +775,12 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm)
     zmm = cmp_merge<vtype>(
             zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
     zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm),
-            0xCC);
+            zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm), 0xCC);
     zmm = cmp_merge<vtype>(
             zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
     zmm = cmp_merge<vtype>(zmm, vtype::permutexvar(rev_index, zmm), 0xF0);
     zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm),
-            0xCC);
+            zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm), 0xCC);
     zmm = cmp_merge<vtype>(
             zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
     return zmm;
 }
diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp
index 31b6aaba..16f8d354 100644
--- a/src/avx512-64bit-keyvaluesort.hpp
+++ b/src/avx512-64bit-keyvaluesort.hpp
@@ -444,7 +444,8 @@ void avx512_qsort_kv(T1 *keys, T2 *indexes, int64_t arrsize)
 {
     if (arrsize > 1) {
         if constexpr (std::is_floating_point_v<T1>) {
-            int64_t nan_count = replace_nan_with_inf<zmm_vector<T1>>(keys, arrsize);
+            int64_t nan_count
+                    = replace_nan_with_inf<zmm_vector<T1>>(keys, arrsize);
             qsort_64bit_<zmm_vector<T1>, zmm_vector<T2>>(
                     keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
             replace_inf_with_nan(keys, arrsize, nan_count);
diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp
index d9e6aff9..626e672e 100644
--- a/src/avx512-64bit-qsort.hpp
+++ b/src/avx512-64bit-qsort.hpp
@@ -785,38 +785,50 @@ static void qselect_64bit_(type_t *arr,
 
 /* Specialized template function for 64-bit qselect_ funcs*/
 template <>
-void qselect_<zmm_vector<int64_t>>(int64_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<int64_t>>(
+        int64_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
     qselect_64bit_<zmm_vector<int64_t>>(arr, k, left, right, maxiters);
 }
 
 template <>
-void qselect_<zmm_vector<uint64_t>>(uint64_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<uint64_t>>(
+        uint64_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
     qselect_64bit_<zmm_vector<uint64_t>>(arr, k, left, right, maxiters);
 }
 
 template <>
-void qselect_<zmm_vector<double>>(double* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<double>>(
+        double *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
     qselect_64bit_<zmm_vector<double>>(arr, k, left, right, maxiters);
 }
 
 /* Specialized template function for 64-bit qsort_ funcs*/
 template <>
-void qsort_<zmm_vector<int64_t>>(int64_t* arr, int64_t left, int64_t right, int64_t maxiters)
+void qsort_<zmm_vector<int64_t>>(int64_t *arr,
+                                 int64_t left,
+                                 int64_t right,
+                                 int64_t maxiters)
 {
     qsort_64bit_<zmm_vector<int64_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void qsort_<zmm_vector<uint64_t>>(uint64_t* arr, int64_t left, int64_t right, int64_t maxiters)
+void qsort_<zmm_vector<uint64_t>>(uint64_t *arr,
+                                  int64_t left,
+                                  int64_t right,
+                                  int64_t maxiters)
 {
     qsort_64bit_<zmm_vector<uint64_t>>(arr, left, right, maxiters);
 }
 
 template <>
-void qsort_<zmm_vector<double>>(double* arr, int64_t left, int64_t right, int64_t maxiters)
+void qsort_<zmm_vector<double>>(double *arr,
+                                int64_t left,
+                                int64_t right,
+                                int64_t maxiters)
 {
     qsort_64bit_<zmm_vector<double>>(arr, left, right, maxiters);
 }
diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h
index 9bdaf464..e829ab62 100644
--- a/src/avx512-common-argsort.h
+++ b/src/avx512-common-argsort.h
@@ -34,7 +34,8 @@ static inline int32_t partition_vec(type_t *arg,
     int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask);
     argtype::mask_compressstoreu(
             arg + left, vtype::knot_opmask(gt_mask), arg_vec);
-    argtype::mask_compressstoreu(arg + right - amount_gt_pivot, gt_mask, arg_vec);
+    argtype::mask_compressstoreu(
+            arg + right - amount_gt_pivot, gt_mask, arg_vec);
     *smallest_vec = vtype::min(curr_vec, *smallest_vec);
     *biggest_vec = vtype::max(curr_vec, *biggest_vec);
     return amount_gt_pivot;
@@ -225,7 +226,8 @@ static inline int64_t partition_avx512_unrolled(type_t *arr,
         right -= num_unroll * vtype::numlanes;
 #pragma GCC unroll 8
         for (int ii = 0; ii < num_unroll; ++ii) {
-            arg_vec[ii] = argtype::loadu(arg + right + ii * vtype::numlanes);
+            arg_vec[ii]
+                    = argtype::loadu(arg + right + ii * vtype::numlanes);
             curr_vec[ii] = vtype::template i64gather<sizeof(type_t)>(
                     arg_vec[ii], arr);
         }
diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h
index e98ba1cf..6e5cd15e 100644
--- a/src/avx512-common-qsort.h
+++ b/src/avx512-common-qsort.h
@@ -85,8 +85,8 @@
 #define X86_SIMD_SORT_FINLINE static
 #endif
 
-#define LIKELY(x) __builtin_expect((x),1)
-#define UNLIKELY(x) __builtin_expect((x),0)
+#define LIKELY(x) __builtin_expect((x), 1)
+#define UNLIKELY(x) __builtin_expect((x), 0)
 
 template <typename type_t>
 struct zmm_vector;
@@ -152,7 +152,7 @@ bool has_nan(type_t *arr, int64_t arrsize)
     return found_nan;
 }
 
-template<typename type_t>
+template <typename type_t>
 void replace_inf_with_nan(type_t *arr, int64_t arrsize, int64_t nan_count)
 {
     for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
@@ -171,7 +171,7 @@ void replace_inf_with_nan(type_t *arr, int64_t arrsize, int64_t nan_count)
  * in the array which is not a nan
 */
 template <typename T>
-int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize)
+int64_t move_nans_to_end_of_array(T *arr, int64_t arrsize)
 {
     int64_t jj = arrsize - 1;
     int64_t ii = 0;
@@ -186,7 +186,7 @@ int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize)
             ii += 1;
         }
     }
-    return arrsize-count-1;
+    return arrsize - count - 1;
 }
 
 template <typename vtype, typename type_t>
@@ -671,10 +671,14 @@ static inline int64_t partition_avx512(type_t1 *keys,
 }
 
 template <typename vtype, typename type_t>
-void qsort_(type_t* arr, int64_t left, int64_t right, int64_t maxiters);
+void qsort_(type_t *arr, int64_t left, int64_t right, int64_t maxiters);
 
 template <typename vtype, typename type_t>
-void qselect_(type_t* arr, int64_t pos, int64_t left, int64_t right, int64_t maxiters);
+void qselect_(type_t *arr,
+              int64_t pos,
+              int64_t left,
+              int64_t right,
+              int64_t maxiters);
 
 // Regular quicksort routines:
 template <typename T>
@@ -683,7 +687,8 @@ void avx512_qsort(T *arr, int64_t arrsize)
     if (arrsize > 1) {
         /* std::is_floating_point_v<_Float16> == False, unless c++-23*/
         if constexpr (std::is_floating_point_v<T>) {
-            int64_t nan_count = replace_nan_with_inf<zmm_vector<T>>(arr, arrsize);
+            int64_t nan_count
+                    = replace_nan_with_inf<zmm_vector<T>>(arr, arrsize);
             qsort_<zmm_vector<T>, T>(
                     arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
             replace_inf_with_nan(arr, arrsize, nan_count);
@@ -713,15 +718,22 @@ void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
     }
 }
 
-void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
+void avx512_qselect_fp16(uint16_t *arr,
+                         int64_t k,
+                         int64_t arrsize,
+                         bool hasnan = false);
 
 template <typename T>
-inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
+inline void
+avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
 {
     avx512_qselect(arr, k - 1, arrsize, hasnan);
     avx512_qsort(arr, k - 1);
 }
-inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
+inline void avx512_partial_qsort_fp16(uint16_t *arr,
+                                      int64_t k,
+                                      int64_t arrsize,
+                                      bool hasnan = false)
 {
     avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
     avx512_qsort_fp16(arr, k - 1);
 }
diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp
index 1876be4f..505561c4 100644
--- a/src/avx512fp16-16bit-qsort.hpp
+++ b/src/avx512fp16-16bit-qsort.hpp
@@ -75,8 +75,7 @@ struct zmm_vector<_Float16> {
     }
     static zmm_t maskz_loadu(opmask_t mask, void const *mem)
     {
-        return _mm512_castsi512_ph(
-                _mm512_maskz_loadu_epi16(mask, mem));
+        return _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, mem));
     }
     static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
     {
@@ -134,36 +133,41 @@ bool is_a_nan<_Float16>(_Float16 elem)
     return (temp.i_ & 0x7c00) == 0x7c00;
 }
 
-template<>
+template <>
 void replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
 {
     memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
 }
 
 template <>
-void qselect_<zmm_vector<_Float16>>(_Float16* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
+void qselect_<zmm_vector<_Float16>>(
+        _Float16 *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
 {
     qselect_16bit_<zmm_vector<_Float16>>(arr, k, left, right, maxiters);
 }
-
 template <>
-void qsort_<zmm_vector<_Float16>>(_Float16* arr, int64_t left, int64_t right, int64_t maxiters)
+void qsort_<zmm_vector<_Float16>>(_Float16 *arr,
+                                  int64_t left,
+                                  int64_t right,
+                                  int64_t maxiters)
 {
     qsort_16bit_<zmm_vector<_Float16>>(arr, left, right, maxiters);
 }
 
 /* Specialized template function for _Float16 qsort_*/
-template<>
+template <>
 void avx512_qsort(_Float16 *arr, int64_t arrsize)
 {
     if (arrsize > 1) {
-        int64_t nan_count = replace_nan_with_inf<zmm_vector<_Float16>, _Float16>(arr, arrsize);
+        int64_t nan_count
+                = replace_nan_with_inf<zmm_vector<_Float16>, _Float16>(arr,
+                                                                       arrsize);
         qsort_16bit_<zmm_vector<_Float16>, _Float16>(
                 arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
         replace_inf_with_nan(arr, arrsize, nan_count);