diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index 0c819946..e51ac14a 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -8,6 +8,7 @@ #define AVX512_16BIT_COMMON #include "avx512-common-qsort.h" +#include "xss-network-qsort.hpp" /* * Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic @@ -33,8 +34,8 @@ static const uint16_t network[6][32] * Assumes zmm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg */ -template -X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit(zmm_t zmm) +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit(reg_t zmm) { // Level 1 zmm = cmp_merge( @@ -93,8 +94,8 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit(zmm_t zmm) } // Assumes zmm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm) +template +X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_16bit(reg_t zmm) { // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc .. zmm = cmp_merge( @@ -118,208 +119,4 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm) return zmm; } -// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2) -{ - // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2); - zmm_t zmm3 = vtype::min(zmm1, zmm2); - zmm_t zmm4 = vtype::max(zmm1, zmm2); - // 2) Recursive half cleaner for each - zmm1 = bitonic_merge_zmm_16bit(zmm3); - zmm2 = bitonic_merge_zmm_16bit(zmm4); -} - -// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive -// half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm) -{ - zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]); - zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r); - zmm_t zmm_t3 = vtype::permutexvar(vtype::get_network(4), - vtype::max(zmm[1], zmm2r)); - zmm_t zmm_t4 = vtype::permutexvar(vtype::get_network(4), - vtype::max(zmm[0], zmm3r)); - zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2); - zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2); - zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4); - zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4); - zmm[0] = bitonic_merge_zmm_16bit(zmm0); - zmm[1] = bitonic_merge_zmm_16bit(zmm1); - zmm[2] = bitonic_merge_zmm_16bit(zmm2); - zmm[3] = bitonic_merge_zmm_16bit(zmm3); -} - -template -X86_SIMD_SORT_INLINE void sort_32_16bit(type_t *arr, int32_t N) -{ - typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF; - typename vtype::zmm_t zmm - = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr); - vtype::mask_storeu(arr, load_mask, sort_zmm_16bit(zmm)); -} - -template -X86_SIMD_SORT_INLINE void sort_64_16bit(type_t *arr, int32_t N) -{ - if (N <= 32) { - sort_32_16bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - typename vtype::opmask_t load_mask - = ((0x1ull << (N - 32)) - 0x1ull) & 0xFFFFFFFF; - zmm_t zmm1 = vtype::loadu(arr); - zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 32); - zmm1 = sort_zmm_16bit(zmm1); - zmm2 = sort_zmm_16bit(zmm2); - bitonic_merge_two_zmm_16bit(zmm1, zmm2); - vtype::storeu(arr, zmm1); - vtype::mask_storeu(arr + 32, load_mask, zmm2); -} - -template -X86_SIMD_SORT_INLINE void sort_128_16bit(type_t *arr, int32_t N) -{ - if (N <= 64) { - 
sort_64_16bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[4]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 32); - opmask_t load_mask1 = 0xFFFFFFFF, load_mask2 = 0xFFFFFFFF; - if (N != 128) { - uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; - load_mask1 = combined_mask & 0xFFFFFFFF; - load_mask2 = (combined_mask >> 32) & 0xFFFFFFFF; - } - zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64); - zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 96); - zmm[0] = sort_zmm_16bit(zmm[0]); - zmm[1] = sort_zmm_16bit(zmm[1]); - zmm[2] = sort_zmm_16bit(zmm[2]); - zmm[3] = sort_zmm_16bit(zmm[3]); - bitonic_merge_two_zmm_16bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_16bit(zmm[2], zmm[3]); - bitonic_merge_four_zmm_16bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 32, zmm[1]); - vtype::mask_storeu(arr + 64, load_mask1, zmm[2]); - vtype::mask_storeu(arr + 96, load_mask2, zmm[3]); -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr, - const int64_t left, - const int64_t right) -{ - // median of 32 - int64_t size = (right - left) / 32; - type_t vec_arr[32] = {arr[left], - arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size], - arr[left + 9 * size], - arr[left + 10 * size], - arr[left + 11 * size], - arr[left + 12 * size], - arr[left + 13 * size], - arr[left + 14 * size], - arr[left + 15 * size], - arr[left + 16 * size], - arr[left + 17 * size], - arr[left + 18 * size], - arr[left + 19 * size], - arr[left + 20 * size], - arr[left + 21 * size], - arr[left + 22 * size], - arr[left + 23 * size], - arr[left + 24 * size], - arr[left + 25 * size], - arr[left + 26 * size], - arr[left + 27 * size], - arr[left + 28 * size], - arr[left + 29 * size], - arr[left + 30 * size], - arr[left + 31 * size]}; - typename vtype::zmm_t rand_vec = vtype::loadu(vec_arr); - typename vtype::zmm_t sort = sort_zmm_16bit(rand_vec); - return ((type_t *)&sort)[16]; -} - -template -static void -qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1, comparison_func); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 128 - */ - if (right + 1 - left <= 128) { - sort_128_16bit(arr + left, (int32_t)(right + 1 - left)); - return; - } - - type_t pivot = get_pivot_16bit(arr, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512( - arr, left, right + 1, pivot, &smallest, &biggest); - if (pivot != smallest) - qsort_16bit_(arr, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) - qsort_16bit_(arr, pivot_index, right, max_iters - 1); -} - -template -static void qselect_16bit_(type_t *arr, - int64_t pos, - int64_t left, - int64_t right, - int64_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1, comparison_func); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 128 - */ - if (right + 1 - left <= 128) { - sort_128_16bit(arr + left, (int32_t)(right + 1 - left)); - return; - } - - type_t pivot = get_pivot_16bit(arr, left, right); - type_t smallest = vtype::type_max(); - type_t 
biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512( - arr, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - qselect_16bit_(arr, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - qselect_16bit_(arr, pos, pivot_index, right, max_iters - 1); -} - #endif // AVX512_16BIT_COMMON diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 2cdb45e7..13b732d0 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -16,12 +16,14 @@ struct float16 { template <> struct zmm_vector { using type_t = uint16_t; - using zmm_t = __m512i; - using ymm_t = __m256i; + using reg_t = __m512i; + using halfreg_t = __m256i; using opmask_t = __mmask32; static const uint8_t numlanes = 32; + static constexpr int network_sort_threshold = 512; + static constexpr int partition_unroll_factor = 0; - static zmm_t get_network(int index) + static reg_t get_network(int index) { return _mm512_loadu_si512(&network[index - 1][0]); } @@ -33,7 +35,7 @@ struct zmm_vector { { return X86_SIMD_SORT_NEGINFINITYH; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm512_set1_epi16(type_max()); } @@ -42,14 +44,14 @@ struct zmm_vector { return _knot_mask32(x); } - static opmask_t ge(zmm_t x, zmm_t y) - { - zmm_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000)); - zmm_t sign_y = _mm512_and_si512(y, _mm512_set1_epi16(0x8000)); - zmm_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00)); - zmm_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00)); - zmm_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff)); - zmm_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff)); + static opmask_t ge(reg_t x, reg_t y) + { + reg_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000)); + reg_t sign_y = _mm512_and_si512(y, _mm512_set1_epi16(0x8000)); + reg_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00)); + reg_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00)); + reg_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff)); + reg_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff)); __mmask32 mask_ge = _mm512_cmp_epu16_mask( sign_x, sign_y, _MM_CMPINT_LT); // only greater than @@ -72,37 +74,37 @@ struct zmm_vector { exp_eq, mant_x, mant_y, _MM_CMPINT_NLT); return _kxor_mask32(mask_ge, neg); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm512_mask_mov_epi16(y, ge(x, y), x); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { // AVX512_VBMI2 return _mm512_mask_compressstoreu_epi16(mem, mask, x); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { // AVX512BW return _mm512_mask_loadu_epi16(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm512_mask_mov_epi16(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_epi16(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm512_mask_mov_epi16(x, ge(x, y), y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) + static reg_t permutexvar(__m512i 
idx, reg_t zmm) { return _mm512_permutexvar_epi16(idx, zmm); } @@ -122,7 +124,7 @@ struct zmm_vector { __m128i xmm2 = _mm_cvtps_ph(xmm, _MM_FROUND_NO_EXC); return _mm_extract_epi16(xmm2, 0); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0)); __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1)); @@ -130,7 +132,7 @@ struct zmm_vector { float hi_max = _mm512_reduce_max_ps(hi); return float_to_uint16(std::max(lo_max, hi_max)); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0)); __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1)); @@ -138,31 +140,46 @@ struct zmm_vector { float hi_max = _mm512_reduce_min_ps(hi); return float_to_uint16(std::min(lo_max, hi_max)); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm512_set1_epi16(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { return _mm512_storeu_si512(mem, x); } + static reg_t reverse(reg_t zmm) + { + const auto rev_index = get_network(4); + return permutexvar(rev_index, zmm); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_zmm_16bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_zmm_16bit>(x); + } }; template <> struct zmm_vector { using type_t = int16_t; - using zmm_t = __m512i; - using ymm_t = __m256i; + using reg_t = __m512i; + using halfreg_t = __m256i; using opmask_t = __mmask32; static const uint8_t numlanes = 32; + static constexpr int network_sort_threshold = 512; + static constexpr int partition_unroll_factor = 0; - static zmm_t get_network(int index) + static reg_t get_network(int index) { return _mm512_loadu_si512(&network[index - 1][0]); } @@ -174,7 +191,7 @@ struct zmm_vector { { return X86_SIMD_SORT_MIN_INT16; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm512_set1_epi16(type_max()); } @@ -183,84 +200,99 @@ struct zmm_vector { return _knot_mask32(x); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm512_max_epi16(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { // AVX512_VBMI2 return _mm512_mask_compressstoreu_epi16(mem, mask, x); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { // AVX512BW return _mm512_mask_loadu_epi16(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm512_mask_mov_epi16(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_epi16(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm512_min_epi16(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) + static reg_t permutexvar(__m512i idx, reg_t zmm) { return 
_mm512_permutexvar_epi16(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { - zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); + reg_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); + reg_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo); type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi); return std::max(lo_max, hi_max); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { - zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); + reg_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); + reg_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo); type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi); return std::min(lo_min, hi_min); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm512_set1_epi16(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { return _mm512_storeu_si512(mem, x); } + static reg_t reverse(reg_t zmm) + { + const auto rev_index = get_network(4); + return permutexvar(rev_index, zmm); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_zmm_16bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_zmm_16bit>(x); + } }; template <> struct zmm_vector { using type_t = uint16_t; - using zmm_t = __m512i; - using ymm_t = __m256i; + using reg_t = __m512i; + using halfreg_t = __m256i; using opmask_t = __mmask32; static const uint8_t numlanes = 32; + static constexpr int network_sort_threshold = 512; + static constexpr int partition_unroll_factor = 0; - static zmm_t get_network(int index) + static reg_t get_network(int index) { return _mm512_loadu_si512(&network[index - 1][0]); } @@ -272,7 +304,7 @@ struct zmm_vector { { return 0; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm512_set1_epi16(type_max()); } @@ -281,72 +313,85 @@ struct zmm_vector { { return _knot_mask32(x); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm512_max_epu16(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_compressstoreu_epi16(mem, mask, x); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_epi16(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm512_mask_mov_epi16(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_epi16(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm512_min_epu16(x, 
y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) + static reg_t permutexvar(__m512i idx, reg_t zmm) { return _mm512_permutexvar_epi16(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { - zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); + reg_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); + reg_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo); type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi); return std::max(lo_max, hi_max); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { - zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); + reg_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); + reg_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo); type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi); return std::min(lo_min, hi_min); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm512_set1_epi16(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { return _mm512_storeu_si512(mem, x); } + static reg_t reverse(reg_t zmm) + { + const auto rev_index = get_network(4); + return permutexvar(rev_index, zmm); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_zmm_16bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_zmm_16bit>(x); + } }; template <> @@ -403,51 +448,17 @@ bool is_a_nan(uint16_t elem) return (elem & 0x7c00) == 0x7c00; } -/* Specialized template function for 16-bit qsort_ funcs*/ -template <> -void qsort_>(int16_t *arr, - int64_t left, - int64_t right, - int64_t maxiters) -{ - qsort_16bit_>(arr, left, right, maxiters); -} - -template <> -void qsort_>(uint16_t *arr, - int64_t left, - int64_t right, - int64_t maxiters) -{ - qsort_16bit_>(arr, left, right, maxiters); -} - void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf, uint16_t>( arr, arrsize); - qsort_16bit_, uint16_t>( + qsort_, uint16_t>( arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); replace_inf_with_nan(arr, arrsize, nan_count); } } -/* Specialized template function for 16-bit qselect_ funcs*/ -template <> -void qselect_>( - int16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) -{ - qselect_16bit_>(arr, k, left, right, maxiters); -} - -template <> -void qselect_>( - uint16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) -{ - qselect_16bit_>(arr, k, left, right, maxiters); -} - void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) { int64_t indx_last_elem = arrsize - 1; @@ -455,7 +466,7 @@ void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { - qselect_16bit_, uint16_t>( + qselect_, uint16_t>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 054e4b26..fd427c28 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ 
-9,6 +9,7 @@ #define AVX512_QSORT_32BIT #include "avx512-common-qsort.h" +#include "xss-network-qsort.hpp" /* * Constants used in sorting 16 elements in a ZMM registers. Based on Bitonic @@ -23,13 +24,21 @@ #define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 #define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm); + +template +X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_32bit(reg_t zmm); + template <> struct zmm_vector { using type_t = int32_t; - using zmm_t = __m512i; - using ymm_t = __m256i; + using reg_t = __m512i; + using halfreg_t = __m256i; using opmask_t = __mmask16; static const uint8_t numlanes = 16; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 2; static type_t type_max() { @@ -39,7 +48,7 @@ struct zmm_vector { { return X86_SIMD_SORT_MIN_INT32; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm512_set1_epi32(type_max()); } @@ -48,90 +57,105 @@ struct zmm_vector { { return _mm512_knot(x); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); } template - static ymm_t i64gather(__m512i index, void const *base) + static halfreg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_epi32(index, base, scale); } - static zmm_t merge(ymm_t y1, ymm_t y2) + static reg_t merge(halfreg_t y1, halfreg_t y2) { - zmm_t z1 = _mm512_castsi256_si512(y1); + reg_t z1 = _mm512_castsi256_si512(y1); return _mm512_inserti32x8(z1, y2, 1); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_compressstoreu_epi32(mem, mask, x); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_epi32(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm512_mask_mov_epi32(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_epi32(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm512_min_epi32(x, y); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm512_max_epi32(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) + static reg_t permutexvar(__m512i idx, reg_t zmm) { return _mm512_permutexvar_epi32(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { return _mm512_reduce_max_epi32(v); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { return _mm512_reduce_min_epi32(v); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm512_set1_epi32(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { return _mm512_storeu_si512(mem, x); } - static ymm_t max(ymm_t x, ymm_t y) + static halfreg_t max(halfreg_t x, halfreg_t y) { return _mm256_max_epi32(x, y); } - static ymm_t min(ymm_t x, ymm_t y) + static 
halfreg_t min(halfreg_t x, halfreg_t y) { return _mm256_min_epi32(x, y); } + static reg_t reverse(reg_t zmm) + { + const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5); + return permutexvar(rev_index, zmm); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_zmm_32bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_zmm_32bit>(x); + } }; template <> struct zmm_vector { using type_t = uint32_t; - using zmm_t = __m512i; - using ymm_t = __m256i; + using reg_t = __m512i; + using halfreg_t = __m256i; using opmask_t = __mmask16; static const uint8_t numlanes = 16; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 2; static type_t type_max() { @@ -141,99 +165,114 @@ struct zmm_vector { { return 0; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm512_set1_epi32(type_max()); } // TODO: this should broadcast bits as is? template - static ymm_t i64gather(__m512i index, void const *base) + static halfreg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_epi32(index, base, scale); } - static zmm_t merge(ymm_t y1, ymm_t y2) + static reg_t merge(halfreg_t y1, halfreg_t y2) { - zmm_t z1 = _mm512_castsi256_si512(y1); + reg_t z1 = _mm512_castsi256_si512(y1); return _mm512_inserti32x8(z1, y2, 1); } static opmask_t knot_opmask(opmask_t x) { return _mm512_knot(x); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm512_max_epu32(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_compressstoreu_epi32(mem, mask, x); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_epi32(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm512_mask_mov_epi32(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_epi32(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm512_min_epu32(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) + static reg_t permutexvar(__m512i idx, reg_t zmm) { return _mm512_permutexvar_epi32(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { return _mm512_reduce_max_epu32(v); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { return _mm512_reduce_min_epu32(v); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm512_set1_epi32(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { return _mm512_storeu_si512(mem, x); } - static ymm_t max(ymm_t x, ymm_t y) + static halfreg_t max(halfreg_t x, halfreg_t y) { return _mm256_max_epu32(x, y); } - static ymm_t min(ymm_t x, ymm_t y) + static halfreg_t min(halfreg_t x, halfreg_t y) { return _mm256_min_epu32(x, y); } + static reg_t reverse(reg_t 
zmm) + { + const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5); + return permutexvar(rev_index, zmm); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_zmm_32bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_zmm_32bit>(x); + } }; template <> struct zmm_vector { using type_t = float; - using zmm_t = __m512; - using ymm_t = __m256; + using reg_t = __m512; + using halfreg_t = __m256; using opmask_t = __mmask16; static const uint8_t numlanes = 16; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 2; static type_t type_max() { @@ -243,7 +282,7 @@ struct zmm_vector { { return -X86_SIMD_SORT_INFINITYF; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm512_set1_ps(type_max()); } @@ -252,7 +291,7 @@ struct zmm_vector { { return _mm512_knot(x); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); } @@ -261,95 +300,108 @@ struct zmm_vector { return (0x0001 << size) - 0x0001; } template - static opmask_t fpclass(zmm_t x) + static opmask_t fpclass(reg_t x) { return _mm512_fpclass_ps_mask(x, type); } template - static ymm_t i64gather(__m512i index, void const *base) + static halfreg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_ps(index, base, scale); } - static zmm_t merge(ymm_t y1, ymm_t y2) + static reg_t merge(halfreg_t y1, halfreg_t y2) { - zmm_t z1 = _mm512_castsi512_ps( + reg_t z1 = _mm512_castsi512_ps( _mm512_castsi256_si512(_mm256_castps_si256(y1))); return _mm512_insertf32x8(z1, y2, 1); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm512_loadu_ps(mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm512_max_ps(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_compressstoreu_ps(mem, mask, x); } - static zmm_t maskz_loadu(opmask_t mask, void const *mem) + static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm512_maskz_loadu_ps(mask, mem); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_ps(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm512_mask_mov_ps(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_ps(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm512_min_ps(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) + static reg_t permutexvar(__m512i idx, reg_t zmm) { return _mm512_permutexvar_ps(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { return _mm512_reduce_max_ps(v); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { return _mm512_reduce_min_ps(v); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm512_set1_ps(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { return _mm512_shuffle_ps(zmm, zmm, (_MM_PERM_ENUM)mask); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { return _mm512_storeu_ps(mem, x); } - static ymm_t max(ymm_t x, ymm_t y) + static 
halfreg_t max(halfreg_t x, halfreg_t y) { return _mm256_max_ps(x, y); } - static ymm_t min(ymm_t x, ymm_t y) + static halfreg_t min(halfreg_t x, halfreg_t y) { return _mm256_min_ps(x, y); } + static reg_t reverse(reg_t zmm) + { + const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5); + return permutexvar(rev_index, zmm); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_zmm_32bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_zmm_32bit>(x); + } }; /* * Assumes zmm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg */ -template -X86_SIMD_SORT_INLINE zmm_t sort_zmm_32bit(zmm_t zmm) +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm) { zmm = cmp_merge( zmm, @@ -395,8 +447,8 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_32bit(zmm_t zmm) } // Assumes zmm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_32bit(zmm_t zmm) +template +X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_32bit(reg_t zmm) { // 1) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc .. zmm = cmp_merge( @@ -421,334 +473,4 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_32bit(zmm_t zmm) return zmm; } -// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_32bit(zmm_t *zmm1, zmm_t *zmm2) -{ - // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - *zmm2 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), *zmm2); - zmm_t zmm3 = vtype::min(*zmm1, *zmm2); - zmm_t zmm4 = vtype::max(*zmm1, *zmm2); - // 2) Recursive half cleaner for each - *zmm1 = bitonic_merge_zmm_32bit(zmm3); - *zmm2 = bitonic_merge_zmm_32bit(zmm4); -} - -// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive -// half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_32bit(zmm_t *zmm) -{ - zmm_t zmm2r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[2]); - zmm_t zmm3r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[3]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r); - zmm_t zmm_t3 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[1], zmm2r)); - zmm_t zmm_t4 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[0], zmm3r)); - zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2); - zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2); - zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4); - zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4); - zmm[0] = bitonic_merge_zmm_32bit(zmm0); - zmm[1] = bitonic_merge_zmm_32bit(zmm1); - zmm[2] = bitonic_merge_zmm_32bit(zmm2); - zmm[3] = bitonic_merge_zmm_32bit(zmm3); -} - -template -X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_32bit(zmm_t *zmm) -{ - zmm_t zmm4r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[4]); - zmm_t zmm5r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[5]); - zmm_t zmm6r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[6]); - zmm_t zmm7r = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm[7]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm7r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm6r); - zmm_t zmm_t3 = vtype::min(zmm[2], zmm5r); - zmm_t zmm_t4 = vtype::min(zmm[3], zmm4r); - zmm_t zmm_t5 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[3], zmm4r)); - zmm_t zmm_t6 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[2], zmm5r)); - zmm_t zmm_t7 = 
vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[1], zmm6r)); - zmm_t zmm_t8 = vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), - vtype::max(zmm[0], zmm7r)); - COEX(zmm_t1, zmm_t3); - COEX(zmm_t2, zmm_t4); - COEX(zmm_t5, zmm_t7); - COEX(zmm_t6, zmm_t8); - COEX(zmm_t1, zmm_t2); - COEX(zmm_t3, zmm_t4); - COEX(zmm_t5, zmm_t6); - COEX(zmm_t7, zmm_t8); - zmm[0] = bitonic_merge_zmm_32bit(zmm_t1); - zmm[1] = bitonic_merge_zmm_32bit(zmm_t2); - zmm[2] = bitonic_merge_zmm_32bit(zmm_t3); - zmm[3] = bitonic_merge_zmm_32bit(zmm_t4); - zmm[4] = bitonic_merge_zmm_32bit(zmm_t5); - zmm[5] = bitonic_merge_zmm_32bit(zmm_t6); - zmm[6] = bitonic_merge_zmm_32bit(zmm_t7); - zmm[7] = bitonic_merge_zmm_32bit(zmm_t8); -} - -template -X86_SIMD_SORT_INLINE void sort_16_32bit(type_t *arr, int32_t N) -{ - typename vtype::opmask_t load_mask = (0x0001 << N) - 0x0001; - typename vtype::zmm_t zmm - = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr); - vtype::mask_storeu(arr, load_mask, sort_zmm_32bit(zmm)); -} - -template -X86_SIMD_SORT_INLINE void sort_32_32bit(type_t *arr, int32_t N) -{ - if (N <= 16) { - sort_16_32bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - zmm_t zmm1 = vtype::loadu(arr); - typename vtype::opmask_t load_mask = (0x0001 << (N - 16)) - 0x0001; - zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 16); - zmm1 = sort_zmm_32bit(zmm1); - zmm2 = sort_zmm_32bit(zmm2); - bitonic_merge_two_zmm_32bit(&zmm1, &zmm2); - vtype::storeu(arr, zmm1); - vtype::mask_storeu(arr + 16, load_mask, zmm2); -} - -template -X86_SIMD_SORT_INLINE void sort_64_32bit(type_t *arr, int32_t N) -{ - if (N <= 32) { - sort_32_32bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[4]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 16); - opmask_t load_mask1 = 0xFFFF, load_mask2 = 0xFFFF; - uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; - load_mask1 &= combined_mask & 0xFFFF; - load_mask2 &= (combined_mask >> 16) & 0xFFFF; - zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 32); - zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 48); - zmm[0] = sort_zmm_32bit(zmm[0]); - zmm[1] = sort_zmm_32bit(zmm[1]); - zmm[2] = sort_zmm_32bit(zmm[2]); - zmm[3] = sort_zmm_32bit(zmm[3]); - bitonic_merge_two_zmm_32bit(&zmm[0], &zmm[1]); - bitonic_merge_two_zmm_32bit(&zmm[2], &zmm[3]); - bitonic_merge_four_zmm_32bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 16, zmm[1]); - vtype::mask_storeu(arr + 32, load_mask1, zmm[2]); - vtype::mask_storeu(arr + 48, load_mask2, zmm[3]); -} - -template -X86_SIMD_SORT_INLINE void sort_128_32bit(type_t *arr, int32_t N) -{ - if (N <= 64) { - sort_64_32bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[8]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 16); - zmm[2] = vtype::loadu(arr + 32); - zmm[3] = vtype::loadu(arr + 48); - zmm[0] = sort_zmm_32bit(zmm[0]); - zmm[1] = sort_zmm_32bit(zmm[1]); - zmm[2] = sort_zmm_32bit(zmm[2]); - zmm[3] = sort_zmm_32bit(zmm[3]); - opmask_t load_mask1 = 0xFFFF, load_mask2 = 0xFFFF; - opmask_t load_mask3 = 0xFFFF, load_mask4 = 0xFFFF; - if (N != 128) { - uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; - load_mask1 &= combined_mask & 0xFFFF; - load_mask2 &= (combined_mask >> 16) & 0xFFFF; - load_mask3 &= (combined_mask >> 32) & 0xFFFF; - load_mask4 &= (combined_mask >> 48) & 0xFFFF; - } - zmm[4] = 
vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64); - zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 80); - zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 96); - zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 112); - zmm[4] = sort_zmm_32bit(zmm[4]); - zmm[5] = sort_zmm_32bit(zmm[5]); - zmm[6] = sort_zmm_32bit(zmm[6]); - zmm[7] = sort_zmm_32bit(zmm[7]); - bitonic_merge_two_zmm_32bit(&zmm[0], &zmm[1]); - bitonic_merge_two_zmm_32bit(&zmm[2], &zmm[3]); - bitonic_merge_two_zmm_32bit(&zmm[4], &zmm[5]); - bitonic_merge_two_zmm_32bit(&zmm[6], &zmm[7]); - bitonic_merge_four_zmm_32bit(zmm); - bitonic_merge_four_zmm_32bit(zmm + 4); - bitonic_merge_eight_zmm_32bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 16, zmm[1]); - vtype::storeu(arr + 32, zmm[2]); - vtype::storeu(arr + 48, zmm[3]); - vtype::mask_storeu(arr + 64, load_mask1, zmm[4]); - vtype::mask_storeu(arr + 80, load_mask2, zmm[5]); - vtype::mask_storeu(arr + 96, load_mask3, zmm[6]); - vtype::mask_storeu(arr + 112, load_mask4, zmm[7]); -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, - const int64_t left, - const int64_t right) -{ - // median of 16 - int64_t size = (right - left) / 16; - using zmm_t = typename vtype::zmm_t; - using ymm_t = typename vtype::ymm_t; - __m512i rand_index1 = _mm512_set_epi64(left + size, - left + 2 * size, - left + 3 * size, - left + 4 * size, - left + 5 * size, - left + 6 * size, - left + 7 * size, - left + 8 * size); - __m512i rand_index2 = _mm512_set_epi64(left + 9 * size, - left + 10 * size, - left + 11 * size, - left + 12 * size, - left + 13 * size, - left + 14 * size, - left + 15 * size, - left + 16 * size); - ymm_t rand_vec1 - = vtype::template i64gather(rand_index1, arr); - ymm_t rand_vec2 - = vtype::template i64gather(rand_index2, arr); - zmm_t rand_vec = vtype::merge(rand_vec1, rand_vec2); - zmm_t sort = sort_zmm_32bit(rand_vec); - // pivot will never be a nan, since there are no nan's! 
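For context, the get_pivot_32bit routine removed in this hunk implements a median-of-16 sample: 16 evenly spaced elements are gathered into one register, sorted with the in-register network, and lane 8 is read out as the pivot. A minimal scalar sketch of the same idea, for illustration only (sample_median_pivot is a hypothetical name, not part of the library):

#include <algorithm>
#include <cstdint>

template <typename T>
T sample_median_pivot(const T *arr, int64_t left, int64_t right)
{
    // Take 16 evenly spaced samples from [left, right], mirroring the
    // rand_index1/rand_index2 gathers above.
    int64_t step = (right - left) / 16;
    T samples[16];
    for (int i = 0; i < 16; ++i) {
        samples[i] = arr[left + (i + 1) * step];
    }
    // std::nth_element stands in for sort_zmm_32bit; the element that lands in
    // slot 8 is the same lane the vectorized code returns below.
    std::nth_element(samples, samples + 8, samples + 16);
    return samples[8];
}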
- return ((type_t *)&sort)[8]; -} - -template -static void -qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 128 - */ - if (right + 1 - left <= 128) { - sort_128_32bit(arr + left, (int32_t)(right + 1 - left)); - return; - } - - type_t pivot = get_pivot_32bit(arr, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( - arr, left, right + 1, pivot, &smallest, &biggest); - if (pivot != smallest) - qsort_32bit_(arr, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) - qsort_32bit_(arr, pivot_index, right, max_iters - 1); -} - -template -static void qselect_32bit_(type_t *arr, - int64_t pos, - int64_t left, - int64_t right, - int64_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 128 - */ - if (right + 1 - left <= 128) { - sort_128_32bit(arr + left, (int32_t)(right + 1 - left)); - return; - } - - type_t pivot = get_pivot_32bit(arr, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( - arr, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - qselect_32bit_(arr, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - qselect_32bit_(arr, pos, pivot_index, right, max_iters - 1); -} - -/* Specialized template function for 32-bit qselect_ funcs*/ -template <> -void qselect_>( - int32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) -{ - qselect_32bit_>(arr, k, left, right, maxiters); -} - -template <> -void qselect_>( - uint32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) -{ - qselect_32bit_>(arr, k, left, right, maxiters); -} - -template <> -void qselect_>( - float *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) -{ - qselect_32bit_>(arr, k, left, right, maxiters); -} - -/* Specialized template function for 32-bit qsort_ funcs*/ -template <> -void qsort_>(int32_t *arr, - int64_t left, - int64_t right, - int64_t maxiters) -{ - qsort_32bit_>(arr, left, right, maxiters); -} - -template <> -void qsort_>(uint32_t *arr, - int64_t left, - int64_t right, - int64_t maxiters) -{ - qsort_32bit_>(arr, left, right, maxiters); -} - -template <> -void qsort_>(float *arr, - int64_t left, - int64_t right, - int64_t maxiters) -{ - qsort_32bit_>(arr, left, right, maxiters); -} #endif //AVX512_QSORT_32BIT diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 79001d33..e8254ce1 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -65,10 +65,10 @@ void std_argsort(T *arr, int64_t *arg, int64_t left, int64_t right) template X86_SIMD_SORT_INLINE void argsort_8_64bit(type_t *arr, int64_t *arg, int32_t N) { - using zmm_t = typename vtype::zmm_t; + using reg_t = typename vtype::reg_t; typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; argzmm_t argzmm = argtype::maskz_loadu(load_mask, arg); - zmm_t arrzmm = vtype::template mask_i64gather( + reg_t arrzmm = vtype::template mask_i64gather( vtype::zmm_max(), load_mask, argzmm, arr); 
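    // Note: the gathered keys (arrzmm) and their source indices (argzmm) are
    // handed to the sorting network together, so both registers are reordered
    // consistently; storing argzmm afterwards is what yields the argsort order.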
arrzmm = sort_zmm_64bit(arrzmm, argzmm); argtype::mask_storeu(arg, load_mask, argzmm); @@ -81,12 +81,12 @@ X86_SIMD_SORT_INLINE void argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N) argsort_8_64bit(arr, arg, N); return; } - using zmm_t = typename vtype::zmm_t; + using reg_t = typename vtype::reg_t; typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; argzmm_t argzmm1 = argtype::loadu(arg); argzmm_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8); - zmm_t arrzmm1 = vtype::template i64gather(argzmm1, arr); - zmm_t arrzmm2 = vtype::template mask_i64gather( + reg_t arrzmm1 = vtype::template i64gather(argzmm1, arr); + reg_t arrzmm2 = vtype::template mask_i64gather( vtype::zmm_max(), load_mask, argzmm2, arr); arrzmm1 = sort_zmm_64bit(arrzmm1, argzmm1); arrzmm2 = sort_zmm_64bit(arrzmm2, argzmm2); @@ -103,9 +103,9 @@ X86_SIMD_SORT_INLINE void argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N) argsort_16_64bit(arr, arg, N); return; } - using zmm_t = typename vtype::zmm_t; + using reg_t = typename vtype::reg_t; using opmask_t = typename vtype::opmask_t; - zmm_t arrzmm[4]; + reg_t arrzmm[4]; argzmm_t argzmm[4]; X86_SIMD_SORT_UNROLL_LOOP(2) @@ -146,9 +146,9 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N) argsort_32_64bit(arr, arg, N); return; } - using zmm_t = typename vtype::zmm_t; + using reg_t = typename vtype::reg_t; using opmask_t = typename vtype::opmask_t; - zmm_t arrzmm[8]; + reg_t arrzmm[8]; argzmm_t argzmm[8]; X86_SIMD_SORT_UNROLL_LOOP(4) @@ -198,9 +198,9 @@ X86_SIMD_SORT_UNROLL_LOOP(4) // argsort_64_64bit(arr, arg, N); // return; // } -// using zmm_t = typename vtype::zmm_t; +// using reg_t = typename vtype::reg_t; // using opmask_t = typename vtype::opmask_t; -// zmm_t arrzmm[16]; +// reg_t arrzmm[16]; // argzmm_t argzmm[16]; // //X86_SIMD_SORT_UNROLL_LOOP(8) @@ -256,7 +256,7 @@ type_t get_pivot_64bit(type_t *arr, if (right - left >= vtype::numlanes) { // median of 8 int64_t size = (right - left) / 8; - using zmm_t = typename vtype::zmm_t; + using reg_t = typename vtype::reg_t; // TODO: Use gather here too: __m512i rand_index = _mm512_set_epi64(arg[left + size], arg[left + 2 * size], @@ -266,10 +266,10 @@ type_t get_pivot_64bit(type_t *arr, arg[left + 6 * size], arg[left + 7 * size], arg[left + 8 * size]); - zmm_t rand_vec + reg_t rand_vec = vtype::template i64gather(rand_index, arr); // pivot will never be a nan, since there are no nan's! 
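Comments like the one above rely on NaNs having been dealt with before any pivot is sampled: in this patch, avx512_qsort_fp16 replaces NaNs with +inf (counting them) before sorting and writes NaNs back into the trailing slots afterwards, while avx512_qselect_fp16 first moves NaNs to the end of the array. A scalar sketch of that replace/restore idea, for illustration only (these helper names are hypothetical, not the library's templated replace_nan_with_inf / replace_inf_with_nan):

#include <cmath>
#include <cstdint>
#include <limits>

// Overwrite NaNs with +inf so they sort to the very end; return how many were replaced.
static int64_t scalar_replace_nan_with_inf(float *arr, int64_t size)
{
    int64_t nan_count = 0;
    for (int64_t i = 0; i < size; ++i) {
        if (std::isnan(arr[i])) {
            arr[i] = std::numeric_limits<float>::infinity();
            ++nan_count;
        }
    }
    return nan_count;
}

// After sorting, the replaced values occupy the last nan_count slots; put the
// NaNs back so the caller sees them again.
static void scalar_replace_inf_with_nan(float *arr, int64_t size, int64_t nan_count)
{
    for (int64_t i = size - nan_count; i < size; ++i) {
        arr[i] = std::numeric_limits<float>::quiet_NaN();
    }
}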
- zmm_t sort = sort_zmm_64bit(rand_vec); + reg_t sort = sort_zmm_64bit(rand_vec); return ((type_t *)&sort)[4]; } else { diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index fbd4a88f..0353507e 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -19,10 +19,16 @@ #define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 #define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm); + +template +X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit(reg_t zmm); + template <> struct ymm_vector { using type_t = float; - using zmm_t = __m256; + using reg_t = __m256; using zmmi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -35,7 +41,7 @@ struct ymm_vector { { return -X86_SIMD_SORT_INFINITYF; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm256_set1_ps(type_max()); } @@ -53,15 +59,15 @@ struct ymm_vector { { return _knot_mask8(x); } - static opmask_t le(zmm_t x, zmm_t y) + static opmask_t le(reg_t x, reg_t y) { return _mm256_cmp_ps_mask(x, y, _CMP_LE_OQ); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm256_cmp_ps_mask(x, y, _CMP_GE_OQ); } - static opmask_t eq(zmm_t x, zmm_t y) + static opmask_t eq(reg_t x, reg_t y) { return _mm256_cmp_ps_mask(x, y, _CMP_EQ_OQ); } @@ -70,58 +76,58 @@ struct ymm_vector { return (0x01 << size) - 0x01; } template - static opmask_t fpclass(zmm_t x) + static opmask_t fpclass(reg_t x) { return _mm256_fpclass_ps_mask(x, type); } template - static zmm_t - mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m512i index, void const *base) { return _mm512_mask_i64gather_ps(src, mask, index, base, scale); } template - static zmm_t i64gather(__m512i index, void const *base) + static reg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_ps(index, base, scale); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm256_loadu_ps((float *)mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm256_max_ps(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm256_mask_compressstoreu_ps(mem, mask, x); } - static zmm_t maskz_loadu(opmask_t mask, void const *mem) + static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm256_maskz_loadu_ps(mask, mem); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm256_mask_loadu_ps(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm256_mask_mov_ps(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm256_mask_storeu_ps(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm256_min_ps(x, y); } - static zmm_t permutexvar(__m256i idx, zmm_t zmm) + static reg_t permutexvar(__m256i idx, reg_t zmm) { return _mm256_permutexvar_ps(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { __m128 v128 = _mm_max_ps(_mm256_castps256_ps128(v), _mm256_extractf32x4_ps(v, 1)); @@ -131,7 +137,7 @@ struct ymm_vector { v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 
0, 1))); return _mm_cvtss_f32(v32); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { __m128 v128 = _mm_min_ps(_mm256_castps256_ps128(v), _mm256_extractf32x4_ps(v, 1)); @@ -141,12 +147,12 @@ struct ymm_vector { v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1))); return _mm_cvtss_f32(v32); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm256_set1_ps(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { /* Hack!: have to make shuffles within 128-bit lanes work for both * 32-bit and 64-bit */ @@ -158,7 +164,7 @@ struct ymm_vector { // return _mm256_shuffle_ps(zmm, zmm, mask); //} } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { _mm256_storeu_ps((float *)mem, x); } @@ -166,7 +172,7 @@ struct ymm_vector { template <> struct ymm_vector { using type_t = uint32_t; - using zmm_t = __m256i; + using reg_t = __m256i; using zmmi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -179,7 +185,7 @@ struct ymm_vector { { return 0; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm256_set1_epi32(type_max()); } @@ -197,66 +203,66 @@ struct ymm_vector { { return _knot_mask8(x); } - static opmask_t le(zmm_t x, zmm_t y) + static opmask_t le(reg_t x, reg_t y) { return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_LE); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); } - static opmask_t eq(zmm_t x, zmm_t y) + static opmask_t eq(reg_t x, reg_t y) { return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_EQ); } template - static zmm_t - mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m512i index, void const *base) { return _mm512_mask_i64gather_epi32(src, mask, index, base, scale); } template - static zmm_t i64gather(__m512i index, void const *base) + static reg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_epi32(index, base, scale); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm256_loadu_si256((__m256i *)mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm256_max_epu32(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm256_mask_compressstoreu_epi32(mem, mask, x); } - static zmm_t maskz_loadu(opmask_t mask, void const *mem) + static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm256_maskz_loadu_epi32(mask, mem); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm256_mask_loadu_epi32(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm256_mask_mov_epi32(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm256_mask_storeu_epi32(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm256_min_epu32(x, y); } - static zmm_t permutexvar(__m256i idx, zmm_t zmm) + static reg_t permutexvar(__m256i idx, reg_t zmm) { return _mm256_permutexvar_epi32(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { __m128i v128 = 
_mm_max_epu32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); @@ -266,7 +272,7 @@ struct ymm_vector { v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); return (type_t)_mm_cvtsi128_si32(v32); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { __m128i v128 = _mm_min_epu32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); @@ -276,18 +282,18 @@ struct ymm_vector { v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); return (type_t)_mm_cvtsi128_si32(v32); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm256_set1_epi32(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { /* Hack!: have to make shuffles within 128-bit lanes work for both * 32-bit and 64-bit */ return _mm256_shuffle_epi32(zmm, 0b10110001); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { _mm256_storeu_si256((__m256i *)mem, x); } @@ -295,7 +301,7 @@ struct ymm_vector { template <> struct ymm_vector { using type_t = int32_t; - using zmm_t = __m256i; + using reg_t = __m256i; using zmmi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -308,7 +314,7 @@ struct ymm_vector { { return X86_SIMD_SORT_MIN_INT32; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm256_set1_epi32(type_max()); } // TODO: this should broadcast bits as is? @@ -326,66 +332,66 @@ struct ymm_vector { { return _knot_mask8(x); } - static opmask_t le(zmm_t x, zmm_t y) + static opmask_t le(reg_t x, reg_t y) { return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_LE); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); } - static opmask_t eq(zmm_t x, zmm_t y) + static opmask_t eq(reg_t x, reg_t y) { return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_EQ); } template - static zmm_t - mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m512i index, void const *base) { return _mm512_mask_i64gather_epi32(src, mask, index, base, scale); } template - static zmm_t i64gather(__m512i index, void const *base) + static reg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_epi32(index, base, scale); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm256_loadu_si256((__m256i *)mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm256_max_epi32(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm256_mask_compressstoreu_epi32(mem, mask, x); } - static zmm_t maskz_loadu(opmask_t mask, void const *mem) + static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm256_maskz_loadu_epi32(mask, mem); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm256_mask_loadu_epi32(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm256_mask_mov_epi32(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm256_mask_storeu_epi32(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm256_min_epi32(x, y); } - static zmm_t permutexvar(__m256i 
idx, zmm_t zmm) + static reg_t permutexvar(__m256i idx, reg_t zmm) { return _mm256_permutexvar_epi32(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { __m128i v128 = _mm_max_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); @@ -395,7 +401,7 @@ struct ymm_vector { v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); return (type_t)_mm_cvtsi128_si32(v32); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { __m128i v128 = _mm_min_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); @@ -405,18 +411,18 @@ struct ymm_vector { v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); return (type_t)_mm_cvtsi128_si32(v32); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm256_set1_epi32(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { /* Hack!: have to make shuffles within 128-bit lanes work for both * 32-bit and 64-bit */ return _mm256_shuffle_epi32(zmm, 0b10110001); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { _mm256_storeu_si256((__m256i *)mem, x); } @@ -424,11 +430,13 @@ struct ymm_vector { template <> struct zmm_vector { using type_t = int64_t; - using zmm_t = __m512i; + using reg_t = __m512i; using zmmi_t = __m512i; - using ymm_t = __m512i; + using halfreg_t = __m512i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 8; static type_t type_max() { @@ -438,7 +446,7 @@ struct zmm_vector { { return X86_SIMD_SORT_MIN_INT64; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm512_set1_epi64(type_max()); } // TODO: this should broadcast bits as is? @@ -456,97 +464,112 @@ struct zmm_vector { { return _knot_mask8(x); } - static opmask_t le(zmm_t x, zmm_t y) + static opmask_t le(reg_t x, reg_t y) { return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_LE); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); } - static opmask_t eq(zmm_t x, zmm_t y) + static opmask_t eq(reg_t x, reg_t y) { return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); } template - static zmm_t - mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m512i index, void const *base) { return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); } template - static zmm_t i64gather(__m512i index, void const *base) + static reg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_epi64(index, base, scale); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm512_max_epi64(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_compressstoreu_epi64(mem, mask, x); } - static zmm_t maskz_loadu(opmask_t mask, void const *mem) + static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm512_maskz_loadu_epi64(mask, mem); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_epi64(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, 
opmask_t mask, reg_t y) { return _mm512_mask_mov_epi64(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_epi64(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm512_min_epi64(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) + static reg_t permutexvar(__m512i idx, reg_t zmm) { return _mm512_permutexvar_epi64(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { return _mm512_reduce_max_epi64(v); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { return _mm512_reduce_min_epi64(v); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm512_set1_epi64(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { __m512d temp = _mm512_castsi512_pd(zmm); return _mm512_castpd_si512( _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { _mm512_storeu_si512(mem, x); } + static reg_t reverse(reg_t zmm) + { + const zmmi_t rev_index = seti(NETWORK_64BIT_2); + return permutexvar(rev_index, zmm); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_zmm_64bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_zmm_64bit>(x); + } }; template <> struct zmm_vector { using type_t = uint64_t; - using zmm_t = __m512i; + using reg_t = __m512i; using zmmi_t = __m512i; - using ymm_t = __m512i; + using halfreg_t = __m512i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 8; static type_t type_max() { @@ -556,7 +579,7 @@ struct zmm_vector { { return 0; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm512_set1_epi64(type_max()); } @@ -567,13 +590,13 @@ struct zmm_vector { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } template - static zmm_t - mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m512i index, void const *base) { return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); } template - static zmm_t i64gather(__m512i index, void const *base) + static reg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_epi64(index, base, scale); } @@ -581,78 +604,93 @@ struct zmm_vector { { return _knot_mask8(x); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); } - static opmask_t eq(zmm_t x, zmm_t y) + static opmask_t eq(reg_t x, reg_t y) { return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm512_max_epu64(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_compressstoreu_epi64(mem, mask, x); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_epi64(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return 
_mm512_mask_mov_epi64(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_epi64(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm512_min_epu64(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) + static reg_t permutexvar(__m512i idx, reg_t zmm) { return _mm512_permutexvar_epi64(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { return _mm512_reduce_max_epu64(v); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { return _mm512_reduce_min_epu64(v); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm512_set1_epi64(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { __m512d temp = _mm512_castsi512_pd(zmm); return _mm512_castpd_si512( _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { _mm512_storeu_si512(mem, x); } + static reg_t reverse(reg_t zmm) + { + const zmmi_t rev_index = seti(NETWORK_64BIT_2); + return permutexvar(rev_index, zmm); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_zmm_64bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_zmm_64bit>(x); + } }; template <> struct zmm_vector { using type_t = double; - using zmm_t = __m512d; + using reg_t = __m512d; using zmmi_t = __m512i; - using ymm_t = __m512d; + using halfreg_t = __m512d; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 8; static type_t type_max() { @@ -662,7 +700,7 @@ struct zmm_vector { { return -X86_SIMD_SORT_INFINITY; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm512_set1_pd(type_max()); } @@ -673,7 +711,7 @@ struct zmm_vector { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } - static zmm_t maskz_loadu(opmask_t mask, void const *mem) + static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm512_maskz_loadu_pd(mask, mem); } @@ -681,11 +719,11 @@ struct zmm_vector { { return _knot_mask8(x); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); } - static opmask_t eq(zmm_t x, zmm_t y) + static opmask_t eq(reg_t x, reg_t y) { return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); } @@ -694,82 +732,95 @@ struct zmm_vector { return (0x01 << size) - 0x01; } template - static opmask_t fpclass(zmm_t x) + static opmask_t fpclass(reg_t x) { return _mm512_fpclass_pd_mask(x, type); } template - static zmm_t - mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m512i index, void const *base) { return _mm512_mask_i64gather_pd(src, mask, index, base, scale); } template - static zmm_t i64gather(__m512i index, void const *base) + static reg_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_pd(index, base, scale); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm512_loadu_pd(mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm512_max_pd(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_compressstoreu_pd(mem, mask, x); } - 
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_pd(x, mask, mem); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm512_mask_mov_pd(x, mask, y); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_pd(mem, mask, x); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm512_min_pd(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) + static reg_t permutexvar(__m512i idx, reg_t zmm) { return _mm512_permutexvar_pd(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { return _mm512_reduce_max_pd(v); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { return _mm512_reduce_min_pd(v); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm512_set1_pd(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { _mm512_storeu_pd(mem, x); } + static reg_t reverse(reg_t zmm) + { + const zmmi_t rev_index = seti(NETWORK_64BIT_2); + return permutexvar(rev_index, zmm); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_zmm_64bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_zmm_64bit>(x); + } }; /* * Assumes zmm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg */ -template -X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm) { const typename vtype::zmmi_t rev_index = vtype::seti(NETWORK_64BIT_2); zmm = cmp_merge( @@ -786,26 +837,25 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) return zmm; } -template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, - const int64_t left, - const int64_t right) +// Assumes zmm is bitonic and performs a recursive half cleaner +template +X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit(reg_t zmm) { - // median of 8 - int64_t size = (right - left) / 8; - using zmm_t = typename vtype::zmm_t; - __m512i rand_index = _mm512_set_epi64(left + size, - left + 2 * size, - left + 3 * size, - left + 4 * size, - left + 5 * size, - left + 6 * size, - left + 7 * size, - left + 8 * size); - zmm_t rand_vec = vtype::template i64gather(rand_index, arr); - // pivot will never be a nan, since there are no nan's! 
- zmm_t sort = sort_zmm_64bit(rand_vec); - return ((type_t *)&sort)[4]; + + // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 + zmm = cmp_merge( + zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), zmm), + 0xF0); + // 2) half_cleaner[4] + zmm = cmp_merge( + zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), + 0xCC); + // 3) half_cleaner[1] + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + return zmm; } #endif diff --git a/src/avx512-64bit-keyvalue-networks.hpp b/src/avx512-64bit-keyvalue-networks.hpp index b930a42b..bae8fa6d 100644 --- a/src/avx512-64bit-keyvalue-networks.hpp +++ b/src/avx512-64bit-keyvalue-networks.hpp @@ -1,9 +1,9 @@ template -X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) + typename reg_t = typename vtype1::reg_t, + typename index_type = typename vtype2::reg_t> +X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t key_zmm, index_type &index_zmm) { const typename vtype1::zmmi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); const typename vtype2::zmmi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); @@ -48,9 +48,9 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) // Assumes zmm is bitonic and performs a recursive half cleaner template -X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm, + typename reg_t = typename vtype1::reg_t, + typename index_type = typename vtype2::reg_t> +X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit(reg_t key_zmm, index_type &index_zmm) { @@ -80,10 +80,10 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm, // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner template -X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, - zmm_t &key_zmm2, + typename reg_t = typename vtype1::reg_t, + typename index_type = typename vtype2::reg_t> +X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(reg_t &key_zmm1, + reg_t &key_zmm2, index_type &index_zmm1, index_type &index_zmm2) { @@ -93,8 +93,8 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, key_zmm2 = vtype1::permutexvar(rev_index1, key_zmm2); index_zmm2 = vtype2::permutexvar(rev_index2, index_zmm2); - zmm_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); - zmm_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); + reg_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); + reg_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); typename vtype1::opmask_t movmask = vtype1::eq(key_zmm3, key_zmm1); @@ -115,23 +115,23 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, // half cleaner template -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, + typename reg_t = typename vtype1::reg_t, + typename index_type = typename vtype2::reg_t> +X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(reg_t *key_zmm, index_type *index_zmm) { const typename vtype1::zmmi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); const typename vtype2::zmmi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); // 1) First step of a merging network - zmm_t key_zmm2r = vtype1::permutexvar(rev_index1, key_zmm[2]); - zmm_t key_zmm3r = vtype1::permutexvar(rev_index1, key_zmm[3]); + reg_t key_zmm2r = vtype1::permutexvar(rev_index1, key_zmm[2]); + reg_t key_zmm3r = vtype1::permutexvar(rev_index1, key_zmm[3]); index_type index_zmm2r = vtype2::permutexvar(rev_index2, index_zmm[2]); index_type index_zmm3r = vtype2::permutexvar(rev_index2, index_zmm[3]); - zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm3r); - zmm_t key_zmm_t2 = 
vtype1::min(key_zmm[1], key_zmm2r); - zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); - zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); + reg_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm3r); + reg_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm2r); + reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); + reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]); typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]); @@ -146,15 +146,15 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm2r); // 2) Recursive half clearer: 16 - zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2); - zmm_t key_zmm_t4 = vtype1::permutexvar(rev_index1, key_zmm_m1); + reg_t key_zmm_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2); + reg_t key_zmm_t4 = vtype1::permutexvar(rev_index1, key_zmm_m1); index_type index_zmm_t3 = vtype2::permutexvar(rev_index2, index_zmm_m2); index_type index_zmm_t4 = vtype2::permutexvar(rev_index2, index_zmm_m1); - zmm_t key_zmm0 = vtype1::min(key_zmm_t1, key_zmm_t2); - zmm_t key_zmm1 = vtype1::max(key_zmm_t1, key_zmm_t2); - zmm_t key_zmm2 = vtype1::min(key_zmm_t3, key_zmm_t4); - zmm_t key_zmm3 = vtype1::max(key_zmm_t3, key_zmm_t4); + reg_t key_zmm0 = vtype1::min(key_zmm_t1, key_zmm_t2); + reg_t key_zmm1 = vtype1::max(key_zmm_t1, key_zmm_t2); + reg_t key_zmm2 = vtype1::min(key_zmm_t3, key_zmm_t4); + reg_t key_zmm3 = vtype1::max(key_zmm_t3, key_zmm_t4); movmask1 = vtype1::eq(key_zmm0, key_zmm_t1); movmask2 = vtype1::eq(key_zmm2, key_zmm_t3); @@ -181,31 +181,31 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, template -X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, + typename reg_t = typename vtype1::reg_t, + typename index_type = typename vtype2::reg_t> +X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(reg_t *key_zmm, index_type *index_zmm) { const typename vtype1::zmmi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); const typename vtype2::zmmi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - zmm_t key_zmm4r = vtype1::permutexvar(rev_index1, key_zmm[4]); - zmm_t key_zmm5r = vtype1::permutexvar(rev_index1, key_zmm[5]); - zmm_t key_zmm6r = vtype1::permutexvar(rev_index1, key_zmm[6]); - zmm_t key_zmm7r = vtype1::permutexvar(rev_index1, key_zmm[7]); + reg_t key_zmm4r = vtype1::permutexvar(rev_index1, key_zmm[4]); + reg_t key_zmm5r = vtype1::permutexvar(rev_index1, key_zmm[5]); + reg_t key_zmm6r = vtype1::permutexvar(rev_index1, key_zmm[6]); + reg_t key_zmm7r = vtype1::permutexvar(rev_index1, key_zmm[7]); index_type index_zmm4r = vtype2::permutexvar(rev_index2, index_zmm[4]); index_type index_zmm5r = vtype2::permutexvar(rev_index2, index_zmm[5]); index_type index_zmm6r = vtype2::permutexvar(rev_index2, index_zmm[6]); index_type index_zmm7r = vtype2::permutexvar(rev_index2, index_zmm[7]); - zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm7r); - zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm6r); - zmm_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm5r); - zmm_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm4r); + reg_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm7r); + reg_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm6r); + reg_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm5r); + reg_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm4r); - zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); - zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); - zmm_t 
key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); - zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); + reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); + reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); + reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); + reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]); typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]); @@ -229,10 +229,10 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, index_type index_zmm_m4 = vtype2::mask_mov(index_zmm[3], movmask4, index_zmm4r); - zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4); - zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3); - zmm_t key_zmm_t7 = vtype1::permutexvar(rev_index1, key_zmm_m2); - zmm_t key_zmm_t8 = vtype1::permutexvar(rev_index1, key_zmm_m1); + reg_t key_zmm_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4); + reg_t key_zmm_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3); + reg_t key_zmm_t7 = vtype1::permutexvar(rev_index1, key_zmm_m2); + reg_t key_zmm_t8 = vtype1::permutexvar(rev_index1, key_zmm_m1); index_type index_zmm_t5 = vtype2::permutexvar(rev_index2, index_zmm_m4); index_type index_zmm_t6 = vtype2::permutexvar(rev_index2, index_zmm_m3); index_type index_zmm_t7 = vtype2::permutexvar(rev_index2, index_zmm_m2); @@ -275,21 +275,21 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, template -X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, + typename reg_t = typename vtype1::reg_t, + typename index_type = typename vtype2::reg_t> +X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(reg_t *key_zmm, index_type *index_zmm) { const typename vtype1::zmmi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); const typename vtype2::zmmi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - zmm_t key_zmm8r = vtype1::permutexvar(rev_index1, key_zmm[8]); - zmm_t key_zmm9r = vtype1::permutexvar(rev_index1, key_zmm[9]); - zmm_t key_zmm10r = vtype1::permutexvar(rev_index1, key_zmm[10]); - zmm_t key_zmm11r = vtype1::permutexvar(rev_index1, key_zmm[11]); - zmm_t key_zmm12r = vtype1::permutexvar(rev_index1, key_zmm[12]); - zmm_t key_zmm13r = vtype1::permutexvar(rev_index1, key_zmm[13]); - zmm_t key_zmm14r = vtype1::permutexvar(rev_index1, key_zmm[14]); - zmm_t key_zmm15r = vtype1::permutexvar(rev_index1, key_zmm[15]); + reg_t key_zmm8r = vtype1::permutexvar(rev_index1, key_zmm[8]); + reg_t key_zmm9r = vtype1::permutexvar(rev_index1, key_zmm[9]); + reg_t key_zmm10r = vtype1::permutexvar(rev_index1, key_zmm[10]); + reg_t key_zmm11r = vtype1::permutexvar(rev_index1, key_zmm[11]); + reg_t key_zmm12r = vtype1::permutexvar(rev_index1, key_zmm[12]); + reg_t key_zmm13r = vtype1::permutexvar(rev_index1, key_zmm[13]); + reg_t key_zmm14r = vtype1::permutexvar(rev_index1, key_zmm[14]); + reg_t key_zmm15r = vtype1::permutexvar(rev_index1, key_zmm[15]); index_type index_zmm8r = vtype2::permutexvar(rev_index2, index_zmm[8]); index_type index_zmm9r = vtype2::permutexvar(rev_index2, index_zmm[9]); @@ -300,23 +300,23 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, index_type index_zmm14r = vtype2::permutexvar(rev_index2, index_zmm[14]); index_type index_zmm15r = vtype2::permutexvar(rev_index2, index_zmm[15]); - zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm15r); - zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm14r); - zmm_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm13r); - zmm_t 
key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm12r); - zmm_t key_zmm_t5 = vtype1::min(key_zmm[4], key_zmm11r); - zmm_t key_zmm_t6 = vtype1::min(key_zmm[5], key_zmm10r); - zmm_t key_zmm_t7 = vtype1::min(key_zmm[6], key_zmm9r); - zmm_t key_zmm_t8 = vtype1::min(key_zmm[7], key_zmm8r); - - zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm15r); - zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm14r); - zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm13r); - zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm12r); - zmm_t key_zmm_m5 = vtype1::max(key_zmm[4], key_zmm11r); - zmm_t key_zmm_m6 = vtype1::max(key_zmm[5], key_zmm10r); - zmm_t key_zmm_m7 = vtype1::max(key_zmm[6], key_zmm9r); - zmm_t key_zmm_m8 = vtype1::max(key_zmm[7], key_zmm8r); + reg_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm15r); + reg_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm14r); + reg_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm13r); + reg_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm12r); + reg_t key_zmm_t5 = vtype1::min(key_zmm[4], key_zmm11r); + reg_t key_zmm_t6 = vtype1::min(key_zmm[5], key_zmm10r); + reg_t key_zmm_t7 = vtype1::min(key_zmm[6], key_zmm9r); + reg_t key_zmm_t8 = vtype1::min(key_zmm[7], key_zmm8r); + + reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm15r); + reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm14r); + reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm13r); + reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm12r); + reg_t key_zmm_m5 = vtype1::max(key_zmm[4], key_zmm11r); + reg_t key_zmm_m6 = vtype1::max(key_zmm[5], key_zmm10r); + reg_t key_zmm_m7 = vtype1::max(key_zmm[6], key_zmm9r); + reg_t key_zmm_m8 = vtype1::max(key_zmm[7], key_zmm8r); index_type index_zmm_t1 = vtype2::mask_mov( index_zmm15r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); @@ -352,14 +352,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, index_type index_zmm_m8 = vtype2::mask_mov( index_zmm[7], vtype1::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); - zmm_t key_zmm_t9 = vtype1::permutexvar(rev_index1, key_zmm_m8); - zmm_t key_zmm_t10 = vtype1::permutexvar(rev_index1, key_zmm_m7); - zmm_t key_zmm_t11 = vtype1::permutexvar(rev_index1, key_zmm_m6); - zmm_t key_zmm_t12 = vtype1::permutexvar(rev_index1, key_zmm_m5); - zmm_t key_zmm_t13 = vtype1::permutexvar(rev_index1, key_zmm_m4); - zmm_t key_zmm_t14 = vtype1::permutexvar(rev_index1, key_zmm_m3); - zmm_t key_zmm_t15 = vtype1::permutexvar(rev_index1, key_zmm_m2); - zmm_t key_zmm_t16 = vtype1::permutexvar(rev_index1, key_zmm_m1); + reg_t key_zmm_t9 = vtype1::permutexvar(rev_index1, key_zmm_m8); + reg_t key_zmm_t10 = vtype1::permutexvar(rev_index1, key_zmm_m7); + reg_t key_zmm_t11 = vtype1::permutexvar(rev_index1, key_zmm_m6); + reg_t key_zmm_t12 = vtype1::permutexvar(rev_index1, key_zmm_m5); + reg_t key_zmm_t13 = vtype1::permutexvar(rev_index1, key_zmm_m4); + reg_t key_zmm_t14 = vtype1::permutexvar(rev_index1, key_zmm_m3); + reg_t key_zmm_t15 = vtype1::permutexvar(rev_index1, key_zmm_m2); + reg_t key_zmm_t16 = vtype1::permutexvar(rev_index1, key_zmm_m1); index_type index_zmm_t9 = vtype2::permutexvar(rev_index2, index_zmm_m8); index_type index_zmm_t10 = vtype2::permutexvar(rev_index2, index_zmm_m7); index_type index_zmm_t11 = vtype2::permutexvar(rev_index2, index_zmm_m6); diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 16f8d354..05a69c87 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -19,10 +19,10 @@ X86_SIMD_SORT_INLINE void sort_8_64bit(type1_t *keys, type2_t 
*indexes, int32_t N) { typename vtype1::opmask_t load_mask = (0x01 << N) - 0x01; - typename vtype1::zmm_t key_zmm + typename vtype1::reg_t key_zmm = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys); - typename vtype2::zmm_t index_zmm + typename vtype2::reg_t index_zmm = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes); vtype1::mask_storeu(keys, load_mask, @@ -41,13 +41,13 @@ sort_16_64bit(type1_t *keys, type2_t *indexes, int32_t N) sort_8_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype1::zmm_t; - using index_type = typename vtype2::zmm_t; + using reg_t = typename vtype1::reg_t; + using index_type = typename vtype2::reg_t; typename vtype1::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - zmm_t key_zmm1 = vtype1::loadu(keys); - zmm_t key_zmm2 = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys + 8); + reg_t key_zmm1 = vtype1::loadu(keys); + reg_t key_zmm2 = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys + 8); index_type index_zmm1 = vtype2::loadu(indexes); index_type index_zmm2 @@ -76,10 +76,10 @@ sort_32_64bit(type1_t *keys, type2_t *indexes, int32_t N) sort_16_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype1::zmm_t; + using reg_t = typename vtype1::reg_t; using opmask_t = typename vtype2::opmask_t; - using index_type = typename vtype2::zmm_t; - zmm_t key_zmm[4]; + using index_type = typename vtype2::reg_t; + reg_t key_zmm[4]; index_type index_zmm[4]; key_zmm[0] = vtype1::loadu(keys); @@ -134,10 +134,10 @@ sort_64_64bit(type1_t *keys, type2_t *indexes, int32_t N) sort_32_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype1::zmm_t; + using reg_t = typename vtype1::reg_t; using opmask_t = typename vtype1::opmask_t; - using index_type = typename vtype2::zmm_t; - zmm_t key_zmm[8]; + using index_type = typename vtype2::reg_t; + reg_t key_zmm[8]; index_type index_zmm[8]; key_zmm[0] = vtype1::loadu(keys); @@ -222,10 +222,10 @@ sort_128_64bit(type1_t *keys, type2_t *indexes, int32_t N) sort_64_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype1::zmm_t; - using index_type = typename vtype2::zmm_t; + using reg_t = typename vtype1::reg_t; + using index_type = typename vtype2::reg_t; using opmask_t = typename vtype1::opmask_t; - zmm_t key_zmm[16]; + reg_t key_zmm[16]; index_type index_zmm[16]; key_zmm[0] = vtype1::loadu(keys); @@ -424,7 +424,7 @@ void qsort_64bit_(type1_t *keys, return; } - type1_t pivot = get_pivot_64bit(keys, left, right); + type1_t pivot = get_pivot(keys, left, right); type1_t smallest = vtype1::type_max(); type1_t biggest = vtype1::type_min(); int64_t pivot_index = partition_avx512( diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 626e672e..2dae1622 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -8,828 +8,6 @@ #define AVX512_QSORT_64BIT #include "avx512-64bit-common.h" +#include "xss-network-qsort.hpp" -// Assumes zmm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) -{ - - // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 - zmm = cmp_merge( - zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), zmm), - 0xF0); - // 2) half_cleaner[4] - zmm = cmp_merge( - zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), - 0xCC); - // 3) half_cleaner[1] - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - return zmm; -} -// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE void 
bitonic_merge_two_zmm_64bit(zmm_t &zmm1, zmm_t &zmm2) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - zmm2 = vtype::permutexvar(rev_index, zmm2); - zmm_t zmm3 = vtype::min(zmm1, zmm2); - zmm_t zmm4 = vtype::max(zmm1, zmm2); - // 2) Recursive half cleaner for each - zmm1 = bitonic_merge_zmm_64bit(zmm3); - zmm2 = bitonic_merge_zmm_64bit(zmm4); -} -// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive -// half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - // 1) First step of a merging network - zmm_t zmm2r = vtype::permutexvar(rev_index, zmm[2]); - zmm_t zmm3r = vtype::permutexvar(rev_index, zmm[3]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r); - // 2) Recursive half clearer: 16 - zmm_t zmm_t3 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm2r)); - zmm_t zmm_t4 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm3r)); - zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2); - zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2); - zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4); - zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4); - zmm[0] = bitonic_merge_zmm_64bit(zmm0); - zmm[1] = bitonic_merge_zmm_64bit(zmm1); - zmm[2] = bitonic_merge_zmm_64bit(zmm2); - zmm[3] = bitonic_merge_zmm_64bit(zmm3); -} -template -X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t zmm4r = vtype::permutexvar(rev_index, zmm[4]); - zmm_t zmm5r = vtype::permutexvar(rev_index, zmm[5]); - zmm_t zmm6r = vtype::permutexvar(rev_index, zmm[6]); - zmm_t zmm7r = vtype::permutexvar(rev_index, zmm[7]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm7r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm6r); - zmm_t zmm_t3 = vtype::min(zmm[2], zmm5r); - zmm_t zmm_t4 = vtype::min(zmm[3], zmm4r); - zmm_t zmm_t5 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm4r)); - zmm_t zmm_t6 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm5r)); - zmm_t zmm_t7 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm6r)); - zmm_t zmm_t8 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm7r)); - COEX(zmm_t1, zmm_t3); - COEX(zmm_t2, zmm_t4); - COEX(zmm_t5, zmm_t7); - COEX(zmm_t6, zmm_t8); - COEX(zmm_t1, zmm_t2); - COEX(zmm_t3, zmm_t4); - COEX(zmm_t5, zmm_t6); - COEX(zmm_t7, zmm_t8); - zmm[0] = bitonic_merge_zmm_64bit(zmm_t1); - zmm[1] = bitonic_merge_zmm_64bit(zmm_t2); - zmm[2] = bitonic_merge_zmm_64bit(zmm_t3); - zmm[3] = bitonic_merge_zmm_64bit(zmm_t4); - zmm[4] = bitonic_merge_zmm_64bit(zmm_t5); - zmm[5] = bitonic_merge_zmm_64bit(zmm_t6); - zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); - zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); -} -template -X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t zmm8r = vtype::permutexvar(rev_index, zmm[8]); - zmm_t zmm9r = vtype::permutexvar(rev_index, zmm[9]); - zmm_t zmm10r = vtype::permutexvar(rev_index, zmm[10]); - zmm_t zmm11r = vtype::permutexvar(rev_index, zmm[11]); - zmm_t zmm12r = vtype::permutexvar(rev_index, zmm[12]); - zmm_t zmm13r = vtype::permutexvar(rev_index, zmm[13]); - zmm_t zmm14r = vtype::permutexvar(rev_index, zmm[14]); - zmm_t zmm15r = vtype::permutexvar(rev_index, zmm[15]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm15r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm14r); - zmm_t zmm_t3 = 
vtype::min(zmm[2], zmm13r); - zmm_t zmm_t4 = vtype::min(zmm[3], zmm12r); - zmm_t zmm_t5 = vtype::min(zmm[4], zmm11r); - zmm_t zmm_t6 = vtype::min(zmm[5], zmm10r); - zmm_t zmm_t7 = vtype::min(zmm[6], zmm9r); - zmm_t zmm_t8 = vtype::min(zmm[7], zmm8r); - zmm_t zmm_t9 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm8r)); - zmm_t zmm_t10 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm9r)); - zmm_t zmm_t11 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm10r)); - zmm_t zmm_t12 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm11r)); - zmm_t zmm_t13 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm12r)); - zmm_t zmm_t14 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm13r)); - zmm_t zmm_t15 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm14r)); - zmm_t zmm_t16 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm15r)); - // Recusive half clear 16 zmm regs - COEX(zmm_t1, zmm_t5); - COEX(zmm_t2, zmm_t6); - COEX(zmm_t3, zmm_t7); - COEX(zmm_t4, zmm_t8); - COEX(zmm_t9, zmm_t13); - COEX(zmm_t10, zmm_t14); - COEX(zmm_t11, zmm_t15); - COEX(zmm_t12, zmm_t16); - // - COEX(zmm_t1, zmm_t3); - COEX(zmm_t2, zmm_t4); - COEX(zmm_t5, zmm_t7); - COEX(zmm_t6, zmm_t8); - COEX(zmm_t9, zmm_t11); - COEX(zmm_t10, zmm_t12); - COEX(zmm_t13, zmm_t15); - COEX(zmm_t14, zmm_t16); - // - COEX(zmm_t1, zmm_t2); - COEX(zmm_t3, zmm_t4); - COEX(zmm_t5, zmm_t6); - COEX(zmm_t7, zmm_t8); - COEX(zmm_t9, zmm_t10); - COEX(zmm_t11, zmm_t12); - COEX(zmm_t13, zmm_t14); - COEX(zmm_t15, zmm_t16); - // - zmm[0] = bitonic_merge_zmm_64bit(zmm_t1); - zmm[1] = bitonic_merge_zmm_64bit(zmm_t2); - zmm[2] = bitonic_merge_zmm_64bit(zmm_t3); - zmm[3] = bitonic_merge_zmm_64bit(zmm_t4); - zmm[4] = bitonic_merge_zmm_64bit(zmm_t5); - zmm[5] = bitonic_merge_zmm_64bit(zmm_t6); - zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); - zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); - zmm[8] = bitonic_merge_zmm_64bit(zmm_t9); - zmm[9] = bitonic_merge_zmm_64bit(zmm_t10); - zmm[10] = bitonic_merge_zmm_64bit(zmm_t11); - zmm[11] = bitonic_merge_zmm_64bit(zmm_t12); - zmm[12] = bitonic_merge_zmm_64bit(zmm_t13); - zmm[13] = bitonic_merge_zmm_64bit(zmm_t14); - zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); - zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); -} - -template -X86_SIMD_SORT_INLINE void bitonic_merge_32_zmm_64bit(zmm_t *zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t zmm16r = vtype::permutexvar(rev_index, zmm[16]); - zmm_t zmm17r = vtype::permutexvar(rev_index, zmm[17]); - zmm_t zmm18r = vtype::permutexvar(rev_index, zmm[18]); - zmm_t zmm19r = vtype::permutexvar(rev_index, zmm[19]); - zmm_t zmm20r = vtype::permutexvar(rev_index, zmm[20]); - zmm_t zmm21r = vtype::permutexvar(rev_index, zmm[21]); - zmm_t zmm22r = vtype::permutexvar(rev_index, zmm[22]); - zmm_t zmm23r = vtype::permutexvar(rev_index, zmm[23]); - zmm_t zmm24r = vtype::permutexvar(rev_index, zmm[24]); - zmm_t zmm25r = vtype::permutexvar(rev_index, zmm[25]); - zmm_t zmm26r = vtype::permutexvar(rev_index, zmm[26]); - zmm_t zmm27r = vtype::permutexvar(rev_index, zmm[27]); - zmm_t zmm28r = vtype::permutexvar(rev_index, zmm[28]); - zmm_t zmm29r = vtype::permutexvar(rev_index, zmm[29]); - zmm_t zmm30r = vtype::permutexvar(rev_index, zmm[30]); - zmm_t zmm31r = vtype::permutexvar(rev_index, zmm[31]); - zmm_t zmm_t1 = vtype::min(zmm[0], zmm31r); - zmm_t zmm_t2 = vtype::min(zmm[1], zmm30r); - zmm_t zmm_t3 = vtype::min(zmm[2], zmm29r); - zmm_t zmm_t4 = vtype::min(zmm[3], zmm28r); - zmm_t zmm_t5 = vtype::min(zmm[4], zmm27r); - zmm_t zmm_t6 = vtype::min(zmm[5], 
zmm26r); - zmm_t zmm_t7 = vtype::min(zmm[6], zmm25r); - zmm_t zmm_t8 = vtype::min(zmm[7], zmm24r); - zmm_t zmm_t9 = vtype::min(zmm[8], zmm23r); - zmm_t zmm_t10 = vtype::min(zmm[9], zmm22r); - zmm_t zmm_t11 = vtype::min(zmm[10], zmm21r); - zmm_t zmm_t12 = vtype::min(zmm[11], zmm20r); - zmm_t zmm_t13 = vtype::min(zmm[12], zmm19r); - zmm_t zmm_t14 = vtype::min(zmm[13], zmm18r); - zmm_t zmm_t15 = vtype::min(zmm[14], zmm17r); - zmm_t zmm_t16 = vtype::min(zmm[15], zmm16r); - zmm_t zmm_t17 = vtype::permutexvar(rev_index, vtype::max(zmm[15], zmm16r)); - zmm_t zmm_t18 = vtype::permutexvar(rev_index, vtype::max(zmm[14], zmm17r)); - zmm_t zmm_t19 = vtype::permutexvar(rev_index, vtype::max(zmm[13], zmm18r)); - zmm_t zmm_t20 = vtype::permutexvar(rev_index, vtype::max(zmm[12], zmm19r)); - zmm_t zmm_t21 = vtype::permutexvar(rev_index, vtype::max(zmm[11], zmm20r)); - zmm_t zmm_t22 = vtype::permutexvar(rev_index, vtype::max(zmm[10], zmm21r)); - zmm_t zmm_t23 = vtype::permutexvar(rev_index, vtype::max(zmm[9], zmm22r)); - zmm_t zmm_t24 = vtype::permutexvar(rev_index, vtype::max(zmm[8], zmm23r)); - zmm_t zmm_t25 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm24r)); - zmm_t zmm_t26 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm25r)); - zmm_t zmm_t27 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm26r)); - zmm_t zmm_t28 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm27r)); - zmm_t zmm_t29 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm28r)); - zmm_t zmm_t30 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm29r)); - zmm_t zmm_t31 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm30r)); - zmm_t zmm_t32 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm31r)); - // Recusive half clear 16 zmm regs - COEX(zmm_t1, zmm_t9); - COEX(zmm_t2, zmm_t10); - COEX(zmm_t3, zmm_t11); - COEX(zmm_t4, zmm_t12); - COEX(zmm_t5, zmm_t13); - COEX(zmm_t6, zmm_t14); - COEX(zmm_t7, zmm_t15); - COEX(zmm_t8, zmm_t16); - COEX(zmm_t17, zmm_t25); - COEX(zmm_t18, zmm_t26); - COEX(zmm_t19, zmm_t27); - COEX(zmm_t20, zmm_t28); - COEX(zmm_t21, zmm_t29); - COEX(zmm_t22, zmm_t30); - COEX(zmm_t23, zmm_t31); - COEX(zmm_t24, zmm_t32); - // - COEX(zmm_t1, zmm_t5); - COEX(zmm_t2, zmm_t6); - COEX(zmm_t3, zmm_t7); - COEX(zmm_t4, zmm_t8); - COEX(zmm_t9, zmm_t13); - COEX(zmm_t10, zmm_t14); - COEX(zmm_t11, zmm_t15); - COEX(zmm_t12, zmm_t16); - COEX(zmm_t17, zmm_t21); - COEX(zmm_t18, zmm_t22); - COEX(zmm_t19, zmm_t23); - COEX(zmm_t20, zmm_t24); - COEX(zmm_t25, zmm_t29); - COEX(zmm_t26, zmm_t30); - COEX(zmm_t27, zmm_t31); - COEX(zmm_t28, zmm_t32); - // - COEX(zmm_t1, zmm_t3); - COEX(zmm_t2, zmm_t4); - COEX(zmm_t5, zmm_t7); - COEX(zmm_t6, zmm_t8); - COEX(zmm_t9, zmm_t11); - COEX(zmm_t10, zmm_t12); - COEX(zmm_t13, zmm_t15); - COEX(zmm_t14, zmm_t16); - COEX(zmm_t17, zmm_t19); - COEX(zmm_t18, zmm_t20); - COEX(zmm_t21, zmm_t23); - COEX(zmm_t22, zmm_t24); - COEX(zmm_t25, zmm_t27); - COEX(zmm_t26, zmm_t28); - COEX(zmm_t29, zmm_t31); - COEX(zmm_t30, zmm_t32); - // - COEX(zmm_t1, zmm_t2); - COEX(zmm_t3, zmm_t4); - COEX(zmm_t5, zmm_t6); - COEX(zmm_t7, zmm_t8); - COEX(zmm_t9, zmm_t10); - COEX(zmm_t11, zmm_t12); - COEX(zmm_t13, zmm_t14); - COEX(zmm_t15, zmm_t16); - COEX(zmm_t17, zmm_t18); - COEX(zmm_t19, zmm_t20); - COEX(zmm_t21, zmm_t22); - COEX(zmm_t23, zmm_t24); - COEX(zmm_t25, zmm_t26); - COEX(zmm_t27, zmm_t28); - COEX(zmm_t29, zmm_t30); - COEX(zmm_t31, zmm_t32); - // - zmm[0] = bitonic_merge_zmm_64bit(zmm_t1); - zmm[1] = bitonic_merge_zmm_64bit(zmm_t2); - zmm[2] = bitonic_merge_zmm_64bit(zmm_t3); - zmm[3] = 
bitonic_merge_zmm_64bit(zmm_t4); - zmm[4] = bitonic_merge_zmm_64bit(zmm_t5); - zmm[5] = bitonic_merge_zmm_64bit(zmm_t6); - zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); - zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); - zmm[8] = bitonic_merge_zmm_64bit(zmm_t9); - zmm[9] = bitonic_merge_zmm_64bit(zmm_t10); - zmm[10] = bitonic_merge_zmm_64bit(zmm_t11); - zmm[11] = bitonic_merge_zmm_64bit(zmm_t12); - zmm[12] = bitonic_merge_zmm_64bit(zmm_t13); - zmm[13] = bitonic_merge_zmm_64bit(zmm_t14); - zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); - zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); - zmm[16] = bitonic_merge_zmm_64bit(zmm_t17); - zmm[17] = bitonic_merge_zmm_64bit(zmm_t18); - zmm[18] = bitonic_merge_zmm_64bit(zmm_t19); - zmm[19] = bitonic_merge_zmm_64bit(zmm_t20); - zmm[20] = bitonic_merge_zmm_64bit(zmm_t21); - zmm[21] = bitonic_merge_zmm_64bit(zmm_t22); - zmm[22] = bitonic_merge_zmm_64bit(zmm_t23); - zmm[23] = bitonic_merge_zmm_64bit(zmm_t24); - zmm[24] = bitonic_merge_zmm_64bit(zmm_t25); - zmm[25] = bitonic_merge_zmm_64bit(zmm_t26); - zmm[26] = bitonic_merge_zmm_64bit(zmm_t27); - zmm[27] = bitonic_merge_zmm_64bit(zmm_t28); - zmm[28] = bitonic_merge_zmm_64bit(zmm_t29); - zmm[29] = bitonic_merge_zmm_64bit(zmm_t30); - zmm[30] = bitonic_merge_zmm_64bit(zmm_t31); - zmm[31] = bitonic_merge_zmm_64bit(zmm_t32); -} - -template -X86_SIMD_SORT_INLINE void sort_8_64bit(type_t *arr, int32_t N) -{ - typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; - typename vtype::zmm_t zmm - = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr); - vtype::mask_storeu(arr, load_mask, sort_zmm_64bit(zmm)); -} - -template -X86_SIMD_SORT_INLINE void sort_16_64bit(type_t *arr, int32_t N) -{ - if (N <= 8) { - sort_8_64bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - zmm_t zmm1 = vtype::loadu(arr); - typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 8); - zmm1 = sort_zmm_64bit(zmm1); - zmm2 = sort_zmm_64bit(zmm2); - bitonic_merge_two_zmm_64bit(zmm1, zmm2); - vtype::storeu(arr, zmm1); - vtype::mask_storeu(arr + 8, load_mask, zmm2); -} - -template -X86_SIMD_SORT_INLINE void sort_32_64bit(type_t *arr, int32_t N) -{ - if (N <= 16) { - sort_16_64bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[4]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 16); - zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 24); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_four_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::mask_storeu(arr + 16, load_mask1, zmm[2]); - vtype::mask_storeu(arr + 24, load_mask2, zmm[3]); -} - -template -X86_SIMD_SORT_INLINE void sort_64_64bit(type_t *arr, int32_t N) -{ - if (N <= 32) { - sort_32_64bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[8]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); - zmm[2] = vtype::loadu(arr + 16); - zmm[3] = 
vtype::loadu(arr + 24); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - // N-32 >= 1 - uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 32); - zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 40); - zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 48); - zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 56); - zmm[4] = sort_zmm_64bit(zmm[4]); - zmm[5] = sort_zmm_64bit(zmm[5]); - zmm[6] = sort_zmm_64bit(zmm[6]); - zmm[7] = sort_zmm_64bit(zmm[7]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); - bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); - bitonic_merge_four_zmm_64bit(zmm); - bitonic_merge_four_zmm_64bit(zmm + 4); - bitonic_merge_eight_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::storeu(arr + 16, zmm[2]); - vtype::storeu(arr + 24, zmm[3]); - vtype::mask_storeu(arr + 32, load_mask1, zmm[4]); - vtype::mask_storeu(arr + 40, load_mask2, zmm[5]); - vtype::mask_storeu(arr + 48, load_mask3, zmm[6]); - vtype::mask_storeu(arr + 56, load_mask4, zmm[7]); -} - -template -X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N) -{ - if (N <= 64) { - sort_64_64bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[16]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); - zmm[2] = vtype::loadu(arr + 16); - zmm[3] = vtype::loadu(arr + 24); - zmm[4] = vtype::loadu(arr + 32); - zmm[5] = vtype::loadu(arr + 40); - zmm[6] = vtype::loadu(arr + 48); - zmm[7] = vtype::loadu(arr + 56); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); - zmm[4] = sort_zmm_64bit(zmm[4]); - zmm[5] = sort_zmm_64bit(zmm[5]); - zmm[6] = sort_zmm_64bit(zmm[6]); - zmm[7] = sort_zmm_64bit(zmm[7]); - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; - opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; - if (N != 128) { - uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - load_mask5 = (combined_mask >> 32) & 0xFF; - load_mask6 = (combined_mask >> 40) & 0xFF; - load_mask7 = (combined_mask >> 48) & 0xFF; - load_mask8 = (combined_mask >> 56) & 0xFF; - } - zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64); - zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 72); - zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 80); - zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 88); - zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 96); - zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 104); - zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 112); - zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr 
+ 120); - zmm[8] = sort_zmm_64bit(zmm[8]); - zmm[9] = sort_zmm_64bit(zmm[9]); - zmm[10] = sort_zmm_64bit(zmm[10]); - zmm[11] = sort_zmm_64bit(zmm[11]); - zmm[12] = sort_zmm_64bit(zmm[12]); - zmm[13] = sort_zmm_64bit(zmm[13]); - zmm[14] = sort_zmm_64bit(zmm[14]); - zmm[15] = sort_zmm_64bit(zmm[15]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); - bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); - bitonic_merge_two_zmm_64bit(zmm[8], zmm[9]); - bitonic_merge_two_zmm_64bit(zmm[10], zmm[11]); - bitonic_merge_two_zmm_64bit(zmm[12], zmm[13]); - bitonic_merge_two_zmm_64bit(zmm[14], zmm[15]); - bitonic_merge_four_zmm_64bit(zmm); - bitonic_merge_four_zmm_64bit(zmm + 4); - bitonic_merge_four_zmm_64bit(zmm + 8); - bitonic_merge_four_zmm_64bit(zmm + 12); - bitonic_merge_eight_zmm_64bit(zmm); - bitonic_merge_eight_zmm_64bit(zmm + 8); - bitonic_merge_sixteen_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::storeu(arr + 16, zmm[2]); - vtype::storeu(arr + 24, zmm[3]); - vtype::storeu(arr + 32, zmm[4]); - vtype::storeu(arr + 40, zmm[5]); - vtype::storeu(arr + 48, zmm[6]); - vtype::storeu(arr + 56, zmm[7]); - vtype::mask_storeu(arr + 64, load_mask1, zmm[8]); - vtype::mask_storeu(arr + 72, load_mask2, zmm[9]); - vtype::mask_storeu(arr + 80, load_mask3, zmm[10]); - vtype::mask_storeu(arr + 88, load_mask4, zmm[11]); - vtype::mask_storeu(arr + 96, load_mask5, zmm[12]); - vtype::mask_storeu(arr + 104, load_mask6, zmm[13]); - vtype::mask_storeu(arr + 112, load_mask7, zmm[14]); - vtype::mask_storeu(arr + 120, load_mask8, zmm[15]); -} - -template -X86_SIMD_SORT_INLINE void sort_256_64bit(type_t *arr, int32_t N) -{ - if (N <= 128) { - sort_128_64bit(arr, N); - return; - } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - zmm_t zmm[32]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); - zmm[2] = vtype::loadu(arr + 16); - zmm[3] = vtype::loadu(arr + 24); - zmm[4] = vtype::loadu(arr + 32); - zmm[5] = vtype::loadu(arr + 40); - zmm[6] = vtype::loadu(arr + 48); - zmm[7] = vtype::loadu(arr + 56); - zmm[8] = vtype::loadu(arr + 64); - zmm[9] = vtype::loadu(arr + 72); - zmm[10] = vtype::loadu(arr + 80); - zmm[11] = vtype::loadu(arr + 88); - zmm[12] = vtype::loadu(arr + 96); - zmm[13] = vtype::loadu(arr + 104); - zmm[14] = vtype::loadu(arr + 112); - zmm[15] = vtype::loadu(arr + 120); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); - zmm[4] = sort_zmm_64bit(zmm[4]); - zmm[5] = sort_zmm_64bit(zmm[5]); - zmm[6] = sort_zmm_64bit(zmm[6]); - zmm[7] = sort_zmm_64bit(zmm[7]); - zmm[8] = sort_zmm_64bit(zmm[8]); - zmm[9] = sort_zmm_64bit(zmm[9]); - zmm[10] = sort_zmm_64bit(zmm[10]); - zmm[11] = sort_zmm_64bit(zmm[11]); - zmm[12] = sort_zmm_64bit(zmm[12]); - zmm[13] = sort_zmm_64bit(zmm[13]); - zmm[14] = sort_zmm_64bit(zmm[14]); - zmm[15] = sort_zmm_64bit(zmm[15]); - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; - opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; - opmask_t load_mask9 = 0xFF, load_mask10 = 0xFF; - opmask_t load_mask11 = 0xFF, load_mask12 = 0xFF; - opmask_t load_mask13 = 0xFF, load_mask14 = 0xFF; - opmask_t load_mask15 = 0xFF, load_mask16 = 0xFF; - if (N != 256) { - uint64_t combined_mask; - if (N < 192) { - combined_mask = (0x1ull << (N - 128)) - 0x1ull; - 
load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - load_mask5 = (combined_mask >> 32) & 0xFF; - load_mask6 = (combined_mask >> 40) & 0xFF; - load_mask7 = (combined_mask >> 48) & 0xFF; - load_mask8 = (combined_mask >> 56) & 0xFF; - load_mask9 = 0x00; - load_mask10 = 0x0; - load_mask11 = 0x00; - load_mask12 = 0x00; - load_mask13 = 0x00; - load_mask14 = 0x00; - load_mask15 = 0x00; - load_mask16 = 0x00; - } - else { - combined_mask = (0x1ull << (N - 192)) - 0x1ull; - load_mask9 = (combined_mask)&0xFF; - load_mask10 = (combined_mask >> 8) & 0xFF; - load_mask11 = (combined_mask >> 16) & 0xFF; - load_mask12 = (combined_mask >> 24) & 0xFF; - load_mask13 = (combined_mask >> 32) & 0xFF; - load_mask14 = (combined_mask >> 40) & 0xFF; - load_mask15 = (combined_mask >> 48) & 0xFF; - load_mask16 = (combined_mask >> 56) & 0xFF; - } - } - zmm[16] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 128); - zmm[17] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 136); - zmm[18] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 144); - zmm[19] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 152); - zmm[20] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 160); - zmm[21] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 168); - zmm[22] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 176); - zmm[23] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 184); - if (N < 192) { - zmm[24] = vtype::zmm_max(); - zmm[25] = vtype::zmm_max(); - zmm[26] = vtype::zmm_max(); - zmm[27] = vtype::zmm_max(); - zmm[28] = vtype::zmm_max(); - zmm[29] = vtype::zmm_max(); - zmm[30] = vtype::zmm_max(); - zmm[31] = vtype::zmm_max(); - } - else { - zmm[24] = vtype::mask_loadu(vtype::zmm_max(), load_mask9, arr + 192); - zmm[25] = vtype::mask_loadu(vtype::zmm_max(), load_mask10, arr + 200); - zmm[26] = vtype::mask_loadu(vtype::zmm_max(), load_mask11, arr + 208); - zmm[27] = vtype::mask_loadu(vtype::zmm_max(), load_mask12, arr + 216); - zmm[28] = vtype::mask_loadu(vtype::zmm_max(), load_mask13, arr + 224); - zmm[29] = vtype::mask_loadu(vtype::zmm_max(), load_mask14, arr + 232); - zmm[30] = vtype::mask_loadu(vtype::zmm_max(), load_mask15, arr + 240); - zmm[31] = vtype::mask_loadu(vtype::zmm_max(), load_mask16, arr + 248); - } - zmm[16] = sort_zmm_64bit(zmm[16]); - zmm[17] = sort_zmm_64bit(zmm[17]); - zmm[18] = sort_zmm_64bit(zmm[18]); - zmm[19] = sort_zmm_64bit(zmm[19]); - zmm[20] = sort_zmm_64bit(zmm[20]); - zmm[21] = sort_zmm_64bit(zmm[21]); - zmm[22] = sort_zmm_64bit(zmm[22]); - zmm[23] = sort_zmm_64bit(zmm[23]); - zmm[24] = sort_zmm_64bit(zmm[24]); - zmm[25] = sort_zmm_64bit(zmm[25]); - zmm[26] = sort_zmm_64bit(zmm[26]); - zmm[27] = sort_zmm_64bit(zmm[27]); - zmm[28] = sort_zmm_64bit(zmm[28]); - zmm[29] = sort_zmm_64bit(zmm[29]); - zmm[30] = sort_zmm_64bit(zmm[30]); - zmm[31] = sort_zmm_64bit(zmm[31]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); - bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); - bitonic_merge_two_zmm_64bit(zmm[8], zmm[9]); - bitonic_merge_two_zmm_64bit(zmm[10], zmm[11]); - bitonic_merge_two_zmm_64bit(zmm[12], zmm[13]); - bitonic_merge_two_zmm_64bit(zmm[14], zmm[15]); - bitonic_merge_two_zmm_64bit(zmm[16], zmm[17]); - bitonic_merge_two_zmm_64bit(zmm[18], zmm[19]); - bitonic_merge_two_zmm_64bit(zmm[20], zmm[21]); - bitonic_merge_two_zmm_64bit(zmm[22], zmm[23]); 
- bitonic_merge_two_zmm_64bit(zmm[24], zmm[25]); - bitonic_merge_two_zmm_64bit(zmm[26], zmm[27]); - bitonic_merge_two_zmm_64bit(zmm[28], zmm[29]); - bitonic_merge_two_zmm_64bit(zmm[30], zmm[31]); - bitonic_merge_four_zmm_64bit(zmm); - bitonic_merge_four_zmm_64bit(zmm + 4); - bitonic_merge_four_zmm_64bit(zmm + 8); - bitonic_merge_four_zmm_64bit(zmm + 12); - bitonic_merge_four_zmm_64bit(zmm + 16); - bitonic_merge_four_zmm_64bit(zmm + 20); - bitonic_merge_four_zmm_64bit(zmm + 24); - bitonic_merge_four_zmm_64bit(zmm + 28); - bitonic_merge_eight_zmm_64bit(zmm); - bitonic_merge_eight_zmm_64bit(zmm + 8); - bitonic_merge_eight_zmm_64bit(zmm + 16); - bitonic_merge_eight_zmm_64bit(zmm + 24); - bitonic_merge_sixteen_zmm_64bit(zmm); - bitonic_merge_sixteen_zmm_64bit(zmm + 16); - bitonic_merge_32_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::storeu(arr + 16, zmm[2]); - vtype::storeu(arr + 24, zmm[3]); - vtype::storeu(arr + 32, zmm[4]); - vtype::storeu(arr + 40, zmm[5]); - vtype::storeu(arr + 48, zmm[6]); - vtype::storeu(arr + 56, zmm[7]); - vtype::storeu(arr + 64, zmm[8]); - vtype::storeu(arr + 72, zmm[9]); - vtype::storeu(arr + 80, zmm[10]); - vtype::storeu(arr + 88, zmm[11]); - vtype::storeu(arr + 96, zmm[12]); - vtype::storeu(arr + 104, zmm[13]); - vtype::storeu(arr + 112, zmm[14]); - vtype::storeu(arr + 120, zmm[15]); - vtype::mask_storeu(arr + 128, load_mask1, zmm[16]); - vtype::mask_storeu(arr + 136, load_mask2, zmm[17]); - vtype::mask_storeu(arr + 144, load_mask3, zmm[18]); - vtype::mask_storeu(arr + 152, load_mask4, zmm[19]); - vtype::mask_storeu(arr + 160, load_mask5, zmm[20]); - vtype::mask_storeu(arr + 168, load_mask6, zmm[21]); - vtype::mask_storeu(arr + 176, load_mask7, zmm[22]); - vtype::mask_storeu(arr + 184, load_mask8, zmm[23]); - if (N > 192) { - vtype::mask_storeu(arr + 192, load_mask9, zmm[24]); - vtype::mask_storeu(arr + 200, load_mask10, zmm[25]); - vtype::mask_storeu(arr + 208, load_mask11, zmm[26]); - vtype::mask_storeu(arr + 216, load_mask12, zmm[27]); - vtype::mask_storeu(arr + 224, load_mask13, zmm[28]); - vtype::mask_storeu(arr + 232, load_mask14, zmm[29]); - vtype::mask_storeu(arr + 240, load_mask15, zmm[30]); - vtype::mask_storeu(arr + 248, load_mask16, zmm[31]); - } -} - -template -static void -qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 128 - */ - if (right + 1 - left <= 256) { - sort_256_64bit(arr + left, (int32_t)(right + 1 - left)); - return; - } - - type_t pivot = get_pivot_64bit(arr, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( - arr, left, right + 1, pivot, &smallest, &biggest); - if (pivot != smallest) - qsort_64bit_(arr, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) - qsort_64bit_(arr, pivot_index, right, max_iters - 1); -} - -template -static void qselect_64bit_(type_t *arr, - int64_t pos, - int64_t left, - int64_t right, - int64_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 128 - */ - if (right + 1 - left <= 128) { - sort_128_64bit(arr + left, (int32_t)(right + 1 - left)); - return; - } - - 
type_t pivot = get_pivot_64bit(arr, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( - arr, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - qselect_64bit_(arr, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - qselect_64bit_(arr, pos, pivot_index, right, max_iters - 1); -} - -/* Specialized template function for 64-bit qselect_ funcs*/ -template <> -void qselect_>( - int64_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) -{ - qselect_64bit_>(arr, k, left, right, maxiters); -} - -template <> -void qselect_>( - uint64_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) -{ - qselect_64bit_>(arr, k, left, right, maxiters); -} - -template <> -void qselect_>( - double *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) -{ - qselect_64bit_>(arr, k, left, right, maxiters); -} - -/* Specialized template function for 64-bit qsort_ funcs*/ -template <> -void qsort_>(int64_t *arr, - int64_t left, - int64_t right, - int64_t maxiters) -{ - qsort_64bit_>(arr, left, right, maxiters); -} - -template <> -void qsort_>(uint64_t *arr, - int64_t left, - int64_t right, - int64_t maxiters) -{ - qsort_64bit_>(arr, left, right, maxiters); -} - -template <> -void qsort_>(double *arr, - int64_t left, - int64_t right, - int64_t maxiters) -{ - qsort_64bit_>(arr, left, right, maxiters); -} #endif // AVX512_QSORT_64BIT diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h index c45b6130..015a6bd7 100644 --- a/src/avx512-common-argsort.h +++ b/src/avx512-common-argsort.h @@ -13,21 +13,21 @@ #include using argtype = zmm_vector; -using argzmm_t = typename argtype::zmm_t; +using argzmm_t = typename argtype::reg_t; /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. 
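// --- Illustrative aside (not part of this patch) ---------------------------
// The argsort routines in this header never move the keys themselves: they
// permute an index array so that keys[idx[0]] <= keys[idx[1]] <= ...  A scalar
// reference with the same semantics (the function name is ours, not the
// library's API), useful for checking the vectorized partition below:
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int64_t> scalar_argsort(const double *keys, int64_t n)
{
    std::vector<int64_t> idx(n);
    std::iota(idx.begin(), idx.end(), 0); // 0, 1, 2, ...
    std::sort(idx.begin(), idx.end(), [keys](int64_t a, int64_t b) {
        return keys[a] < keys[b];
    });
    return idx;
}
// ---------------------------------------------------------------------------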
*/ -template +template static inline int32_t partition_vec(type_t *arg, int64_t left, int64_t right, const argzmm_t arg_vec, - const zmm_t curr_vec, - const zmm_t pivot_vec, - zmm_t *smallest_vec, - zmm_t *biggest_vec) + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) { /* which elements are larger than the pivot */ typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); @@ -68,14 +68,14 @@ static inline int64_t partition_avx512(type_t *arr, if (left == right) return left; /* less than vtype::numlanes elements in the array */ - using zmm_t = typename vtype::zmm_t; - zmm_t pivot_vec = vtype::set1(pivot); - zmm_t min_vec = vtype::set1(*smallest); - zmm_t max_vec = vtype::set1(*biggest); + using reg_t = typename vtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); if (right - left == vtype::numlanes) { argzmm_t argvec = argtype::loadu(arg + left); - zmm_t vec = vtype::template i64gather(argvec, arr); + reg_t vec = vtype::template i64gather(argvec, arr); int32_t amount_gt_pivot = partition_vec(arg, left, left + vtype::numlanes, @@ -91,10 +91,10 @@ static inline int64_t partition_avx512(type_t *arr, // first and last vtype::numlanes values are partitioned at the end argzmm_t argvec_left = argtype::loadu(arg + left); - zmm_t vec_left + reg_t vec_left = vtype::template i64gather(argvec_left, arr); argzmm_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); - zmm_t vec_right + reg_t vec_right = vtype::template i64gather(argvec_right, arr); // store points of the vectors int64_t r_store = right - vtype::numlanes; @@ -104,7 +104,7 @@ static inline int64_t partition_avx512(type_t *arr, right -= vtype::numlanes; while (right - left != 0) { argzmm_t arg_vec; - zmm_t curr_vec; + reg_t curr_vec; /* * if fewer elements are stored on the right side of the array, * then next elements are loaded from the right side, @@ -190,13 +190,13 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, if (left == right) return left; /* less than vtype::numlanes elements in the array */ - using zmm_t = typename vtype::zmm_t; - zmm_t pivot_vec = vtype::set1(pivot); - zmm_t min_vec = vtype::set1(*smallest); - zmm_t max_vec = vtype::set1(*biggest); + using reg_t = typename vtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); // first and last vtype::numlanes values are partitioned at the end - zmm_t vec_left[num_unroll], vec_right[num_unroll]; + reg_t vec_left[num_unroll], vec_right[num_unroll]; argzmm_t argvec_left[num_unroll], argvec_right[num_unroll]; X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { @@ -216,7 +216,7 @@ X86_SIMD_SORT_UNROLL_LOOP(8) right -= num_unroll * vtype::numlanes; while (right - left != 0) { argzmm_t arg_vec[num_unroll]; - zmm_t curr_vec[num_unroll]; + reg_t curr_vec[num_unroll]; /* * if fewer elements are stored on the right side of the array, * then next elements are loaded from the right side, diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 5d3b1c2c..70b63af3 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -117,9 +117,9 @@ int64_t replace_nan_with_inf(T *arr, int64_t arrsize) { int64_t nan_count = 0; using opmask_t = typename vtype::opmask_t; - using zmm_t = typename vtype::zmm_t; + using reg_t = typename vtype::reg_t; opmask_t loadmask; - zmm_t in; + reg_t in; while (arrsize > 0) { if 
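// --- Illustrative aside (not part of this patch) ---------------------------
// partition_vec boils down to one compare and two compress-stores: lanes below
// the pivot are packed to one side, lanes >= pivot to the other.  A standalone
// sketch for 8 doubles with hypothetical names and two separate output
// cursors; the real code writes both groups back into the array and also
// tracks a running min/max:
#include <immintrin.h>

// Returns how many lanes were >= pivot.
static inline int partition_one_vec(double *lt_dst, double *ge_dst,
                                    __m512d curr, __m512d pivot_vec)
{
    __mmask8 ge_mask = _mm512_cmp_pd_mask(curr, pivot_vec, _CMP_GE_OQ);
    int amount_ge = _mm_popcnt_u32((unsigned)ge_mask);
    // lanes < pivot are packed to the front of lt_dst ...
    _mm512_mask_compressstoreu_pd(lt_dst, (__mmask8)~ge_mask, curr);
    // ... and lanes >= pivot are packed to the front of ge_dst.
    _mm512_mask_compressstoreu_pd(ge_dst, ge_mask, curr);
    return amount_ge;
}
// ---------------------------------------------------------------------------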
(arrsize < vtype::numlanes) { loadmask = vtype::get_partial_loadmask(arrsize); @@ -141,10 +141,10 @@ template bool has_nan(type_t *arr, int64_t arrsize) { using opmask_t = typename vtype::opmask_t; - using zmm_t = typename vtype::zmm_t; + using reg_t = typename vtype::reg_t; bool found_nan = false; opmask_t loadmask; - zmm_t in; + reg_t in; while (arrsize > 0) { if (arrsize < vtype::numlanes) { loadmask = vtype::get_partial_loadmask(arrsize); @@ -218,26 +218,26 @@ static void COEX(mm_t &a, mm_t &b) b = vtype::max(temp, b); } template -static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2, opmask_t mask) +static inline reg_t cmp_merge(reg_t in1, reg_t in2, opmask_t mask) { - zmm_t min = vtype::min(in2, in1); - zmm_t max = vtype::max(in2, in1); + reg_t min = vtype::min(in2, in1); + reg_t max = vtype::max(in2, in1); return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max } /* * Parition one ZMM register based on the pivot and returns the * number of elements that are greater than or equal to the pivot. */ -template +template static inline int32_t partition_vec(type_t *arr, int64_t left, int64_t right, - const zmm_t curr_vec, - const zmm_t pivot_vec, - zmm_t *smallest_vec, - zmm_t *biggest_vec) + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) { /* which elements are larger than or equal to the pivot */ typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); @@ -277,13 +277,13 @@ static inline int64_t partition_avx512(type_t *arr, if (left == right) return left; /* less than vtype::numlanes elements in the array */ - using zmm_t = typename vtype::zmm_t; - zmm_t pivot_vec = vtype::set1(pivot); - zmm_t min_vec = vtype::set1(*smallest); - zmm_t max_vec = vtype::set1(*biggest); + using reg_t = typename vtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); if (right - left == vtype::numlanes) { - zmm_t vec = vtype::loadu(arr + left); + reg_t vec = vtype::loadu(arr + left); int32_t amount_ge_pivot = partition_vec(arr, left, left + vtype::numlanes, @@ -297,8 +297,8 @@ static inline int64_t partition_avx512(type_t *arr, } // first and last vtype::numlanes values are partitioned at the end - zmm_t vec_left = vtype::loadu(arr + left); - zmm_t vec_right = vtype::loadu(arr + (right - vtype::numlanes)); + reg_t vec_left = vtype::loadu(arr + left); + reg_t vec_right = vtype::loadu(arr + (right - vtype::numlanes)); // store points of the vectors int64_t r_store = right - vtype::numlanes; int64_t l_store = left; @@ -306,7 +306,7 @@ static inline int64_t partition_avx512(type_t *arr, left += vtype::numlanes; right -= vtype::numlanes; while (right - left != 0) { - zmm_t curr_vec; + reg_t curr_vec; /* * if fewer elements are stored on the right side of the array, * then next elements are loaded from the right side, @@ -366,6 +366,11 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, type_t *smallest, type_t *biggest) { + if constexpr (num_unroll == 0) { + return partition_avx512( + arr, left, right, pivot, smallest, biggest); + } + if (right - left <= 2 * num_unroll * vtype::numlanes) { return partition_avx512( arr, left, right, pivot, smallest, biggest); @@ -386,14 +391,14 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, if (left == right) return left; /* less than vtype::numlanes elements in the array */ - using zmm_t = typename vtype::zmm_t; - zmm_t pivot_vec = vtype::set1(pivot); - zmm_t min_vec = vtype::set1(*smallest); - zmm_t max_vec = 
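// --- Illustrative aside (not part of this patch) ---------------------------
// cmp_merge above is a masked compare-exchange: every lane receives the
// pairwise minimum, except the lanes selected by `mask`, which receive the
// maximum.  Spelled out directly with __m512d intrinsics (our helper name);
// the sorting networks pass a mask that encodes which positions of the
// network take the larger element:
#include <immintrin.h>

static inline __m512d cmp_merge_pd(__m512d in1, __m512d in2, __mmask8 mask)
{
    __m512d lo = _mm512_min_pd(in2, in1);
    __m512d hi = _mm512_max_pd(in2, in1);
    return _mm512_mask_mov_pd(lo, mask, hi); // mask bit 0 -> min, 1 -> max
}
// ---------------------------------------------------------------------------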
vtype::set1(*biggest); + using reg_t = typename vtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); // We will now have atleast 16 registers worth of data to process: // left and right vtype::numlanes values are partitioned at the end - zmm_t vec_left[num_unroll], vec_right[num_unroll]; + reg_t vec_left[num_unroll], vec_right[num_unroll]; X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { vec_left[ii] = vtype::loadu(arr + left + vtype::numlanes * ii); @@ -407,7 +412,7 @@ X86_SIMD_SORT_UNROLL_LOOP(8) left += num_unroll * vtype::numlanes; right -= num_unroll * vtype::numlanes; while (right - left != 0) { - zmm_t curr_vec[num_unroll]; + reg_t curr_vec[num_unroll]; /* * if fewer elements are stored on the right side of the array, * then next elements are loaded from the right side, @@ -479,8 +484,8 @@ X86_SIMD_SORT_UNROLL_LOOP(8) template + typename zmm_t1 = typename vtype1::reg_t, + typename zmm_t2 = typename vtype2::reg_t> static void COEX(zmm_t1 &key1, zmm_t1 &key2, zmm_t2 &index1, zmm_t2 &index2) { zmm_t1 key_t1 = vtype1::min(key1, key2); @@ -498,8 +503,8 @@ static void COEX(zmm_t1 &key1, zmm_t1 &key2, zmm_t2 &index1, zmm_t2 &index2) } template static inline zmm_t1 cmp_merge(zmm_t1 in1, zmm_t1 in2, @@ -520,8 +525,8 @@ template + typename zmm_t1 = typename vtype1::reg_t, + typename zmm_t2 = typename vtype2::reg_t> static inline int32_t partition_vec(type_t1 *keys, type_t2 *indexes, int64_t left, @@ -555,8 +560,8 @@ template + typename zmm_t1 = typename vtype1::reg_t, + typename zmm_t2 = typename vtype2::reg_t> static inline int64_t partition_avx512(type_t1 *keys, type_t2 *indexes, int64_t left, @@ -683,14 +688,211 @@ static inline int64_t partition_avx512(type_t1 *keys, } template -void qsort_(type_t *arr, int64_t left, int64_t right, int64_t maxiters); +X86_SIMD_SORT_INLINE type_t get_pivot_scalar(type_t *arr, + const int64_t left, + const int64_t right) +{ + constexpr int64_t numSamples = vtype::numlanes; + type_t samples[numSamples]; + + int64_t delta = (right - left) / numSamples; + + for (int i = 0; i < numSamples; i++) { + samples[i] = arr[left + i * delta]; + } + + auto vec = vtype::loadu(samples); + vec = vtype::sort_vec(vec); + return ((type_t *)&vec)[numSamples / 2]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr, + const int64_t left, + const int64_t right) +{ + // median of 32 + int64_t size = (right - left) / 32; + type_t vec_arr[32] = {arr[left], + arr[left + size], + arr[left + 2 * size], + arr[left + 3 * size], + arr[left + 4 * size], + arr[left + 5 * size], + arr[left + 6 * size], + arr[left + 7 * size], + arr[left + 8 * size], + arr[left + 9 * size], + arr[left + 10 * size], + arr[left + 11 * size], + arr[left + 12 * size], + arr[left + 13 * size], + arr[left + 14 * size], + arr[left + 15 * size], + arr[left + 16 * size], + arr[left + 17 * size], + arr[left + 18 * size], + arr[left + 19 * size], + arr[left + 20 * size], + arr[left + 21 * size], + arr[left + 22 * size], + arr[left + 23 * size], + arr[left + 24 * size], + arr[left + 25 * size], + arr[left + 26 * size], + arr[left + 27 * size], + arr[left + 28 * size], + arr[left + 29 * size], + arr[left + 30 * size], + arr[left + 31 * size]}; + typename vtype::reg_t rand_vec = vtype::loadu(vec_arr); + typename vtype::reg_t sort = vtype::sort_vec(rand_vec); + return ((type_t *)&sort)[16]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, + const int64_t left, + const int64_t 
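// --- Illustrative aside (not part of this patch) ---------------------------
// Every get_pivot_* variant below does the same thing: take numlanes evenly
// spaced samples, sort them (with the vector sorting network), and use the
// middle sample as the pivot.  A scalar reference using std::nth_element
// instead of the network (our name; assumes right - left >= nsamples):
#include <algorithm>
#include <cstdint>
#include <vector>

template <typename T>
T median_of_samples(const T *arr, int64_t left, int64_t right, int64_t nsamples)
{
    std::vector<T> samples(nsamples);
    int64_t delta = (right - left) / nsamples;
    for (int64_t i = 0; i < nsamples; i++)
        samples[i] = arr[left + i * delta];
    // Only the middle element needs to land in its sorted position.
    std::nth_element(samples.begin(), samples.begin() + nsamples / 2, samples.end());
    return samples[nsamples / 2];
}
// ---------------------------------------------------------------------------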
right) +{ + // median of 16 + int64_t size = (right - left) / 16; + using zmm_t = typename vtype::reg_t; + using ymm_t = typename vtype::halfreg_t; + __m512i rand_index1 = _mm512_set_epi64(left + size, + left + 2 * size, + left + 3 * size, + left + 4 * size, + left + 5 * size, + left + 6 * size, + left + 7 * size, + left + 8 * size); + __m512i rand_index2 = _mm512_set_epi64(left + 9 * size, + left + 10 * size, + left + 11 * size, + left + 12 * size, + left + 13 * size, + left + 14 * size, + left + 15 * size, + left + 16 * size); + ymm_t rand_vec1 + = vtype::template i64gather(rand_index1, arr); + ymm_t rand_vec2 + = vtype::template i64gather(rand_index2, arr); + zmm_t rand_vec = vtype::merge(rand_vec1, rand_vec2); + zmm_t sort = vtype::sort_vec(rand_vec); + // pivot will never be a nan, since there are no nan's! + return ((type_t *)&sort)[8]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, + const int64_t left, + const int64_t right) +{ + // median of 8 + int64_t size = (right - left) / 8; + using zmm_t = typename vtype::reg_t; + __m512i rand_index = _mm512_set_epi64(left + size, + left + 2 * size, + left + 3 * size, + left + 4 * size, + left + 5 * size, + left + 6 * size, + left + 7 * size, + left + 8 * size); + zmm_t rand_vec = vtype::template i64gather(rand_index, arr); + // pivot will never be a nan, since there are no nan's! + zmm_t sort = vtype::sort_vec(rand_vec); + return ((type_t *)&sort)[4]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, + const int64_t left, + const int64_t right) +{ + if constexpr (vtype::numlanes == 8) + return get_pivot_64bit(arr, left, right); + else if constexpr (vtype::numlanes == 16) + return get_pivot_32bit(arr, left, right); + else if constexpr (vtype::numlanes == 32) + return get_pivot_16bit(arr, left, right); + else + return get_pivot_scalar(arr, left, right); +} + +template +X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N); + +template +static void qsort_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std::sort(arr + left, arr + right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= vtype::network_sort_threshold + */ + if (right + 1 - left <= vtype::network_sort_threshold) { + sort_n( + arr + left, (int32_t)(right + 1 - left)); + return; + } + + type_t pivot = get_pivot(arr, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + + int64_t pivot_index + = partition_avx512_unrolled( + arr, left, right + 1, pivot, &smallest, &biggest); + + if (pivot != smallest) + qsort_(arr, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) qsort_(arr, pivot_index, right, max_iters - 1); +} template -void qselect_(type_t *arr, - int64_t pos, - int64_t left, - int64_t right, - int64_t maxiters); +static void qselect_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std::sort(arr + left, arr + right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= vtype::network_sort_threshold + */ + if (right + 1 - left <= vtype::network_sort_threshold) { + sort_n( + arr + left, (int32_t)(right + 1 - left)); + return; + } + + type_t pivot = get_pivot(arr, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + + int64_t pivot_index + 
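// --- Illustrative aside (not part of this patch) ---------------------------
// Stripped of the SIMD details, the new generic qsort_ above is an
// introsort-style skeleton: a depth budget of 2*log2(n) guards against
// quadratic behaviour, a small-array threshold hands off to the sorting
// network, and everything else is plain quicksort.  Scalar sketch with our
// names; std::partition stands in for partition_avx512_unrolled and std::sort
// stands in for both fallbacks:
#include <algorithm>
#include <cstdint>

template <typename T>
void scalar_qsort(T *arr, int64_t left, int64_t right, int64_t max_iters,
                  int64_t small_threshold = 64)
{
    if (max_iters <= 0) { // quicksort is not making progress: bail out
        std::sort(arr + left, arr + right + 1);
        return;
    }
    if (right + 1 - left <= small_threshold) { // base case (network sort)
        std::sort(arr + left, arr + right + 1);
        return;
    }
    T pivot = arr[left + (right - left) / 2];
    T *mid = std::partition(arr + left, arr + right + 1,
                            [pivot](const T &v) { return v < pivot; });
    int64_t pivot_index = mid - arr;
    // Skip empty sides; a degenerate split simply burns one depth credit.
    if (pivot_index > left)
        scalar_qsort(arr, left, pivot_index - 1, max_iters - 1, small_threshold);
    if (pivot_index <= right)
        scalar_qsort(arr, pivot_index, right, max_iters - 1, small_threshold);
}
// ---------------------------------------------------------------------------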
= partition_avx512_unrolled( + arr, left, right + 1, pivot, &smallest, &biggest); + + if ((pivot != smallest) && (pos < pivot_index)) + qselect_(arr, pos, left, pivot_index - 1, max_iters - 1); + else if ((pivot != biggest) && (pos >= pivot_index)) + qselect_(arr, pos, pivot_index, right, max_iters - 1); +} // Regular quicksort routines: template @@ -750,4 +952,5 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, avx512_qselect_fp16(arr, k - 1, arrsize, hasnan); avx512_qsort_fp16(arr, k - 1); } + #endif // AVX512_QSORT_COMMON diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 505561c4..9874b6fd 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -8,6 +8,7 @@ #define AVX512FP16_QSORT_16BIT #include "avx512-16bit-common.h" +#include "xss-network-qsort.hpp" typedef union { _Float16 f_; @@ -17,10 +18,12 @@ typedef union { template <> struct zmm_vector<_Float16> { using type_t = _Float16; - using zmm_t = __m512h; - using ymm_t = __m256h; + using reg_t = __m512h; + using halfreg_t = __m256h; using opmask_t = __mmask32; static const uint8_t numlanes = 32; + static constexpr int network_sort_threshold = 128; + static constexpr int partition_unroll_factor = 0; static __m512i get_network(int index) { @@ -38,7 +41,7 @@ struct zmm_vector<_Float16> { val.i_ = X86_SIMD_SORT_NEGINFINITYH; return val.f_; } - static zmm_t zmm_max() + static reg_t zmm_max() { return _mm512_set1_ph(type_max()); } @@ -46,7 +49,7 @@ struct zmm_vector<_Float16> { { return _knot_mask32(x); } - static opmask_t ge(zmm_t x, zmm_t y) + static opmask_t ge(reg_t x, reg_t y) { return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ); } @@ -55,75 +58,88 @@ struct zmm_vector<_Float16> { return (0x00000001 << size) - 0x00000001; } template - static opmask_t fpclass(zmm_t x) + static opmask_t fpclass(reg_t x) { return _mm512_fpclass_ph_mask(x, type); } - static zmm_t loadu(void const *mem) + static reg_t loadu(void const *mem) { return _mm512_loadu_ph(mem); } - static zmm_t max(zmm_t x, zmm_t y) + static reg_t max(reg_t x, reg_t y) { return _mm512_max_ph(x, y); } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) { __m512i temp = _mm512_castph_si512(x); // AVX512_VBMI2 return _mm512_mask_compressstoreu_epi16(mem, mask, temp); } - static zmm_t maskz_loadu(opmask_t mask, void const *mem) + static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, mem)); } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { // AVX512BW return _mm512_castsi512_ph( _mm512_mask_loadu_epi16(_mm512_castph_si512(x), mask, mem)); } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm512_castsi512_ph(_mm512_mask_mov_epi16( _mm512_castph_si512(x), mask, _mm512_castph_si512(y))); } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + static void mask_storeu(void *mem, opmask_t mask, reg_t x) { return _mm512_mask_storeu_epi16(mem, mask, _mm512_castph_si512(x)); } - static zmm_t min(zmm_t x, zmm_t y) + static reg_t min(reg_t x, reg_t y) { return _mm512_min_ph(x, y); } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) + static reg_t permutexvar(__m512i idx, reg_t zmm) { return _mm512_permutexvar_ph(idx, zmm); } - static type_t reducemax(zmm_t v) + static type_t reducemax(reg_t v) { return 
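// --- Illustrative aside (not part of this patch) ---------------------------
// qselect_ above differs from qsort_ in exactly one way: after partitioning it
// recurses into only the side that contains position `pos`, so the k-th
// element lands in place without fully sorting either half.  Scalar sketch of
// that decision (our names), with a degenerate-split bailout in place of the
// pivot != smallest / pivot != biggest checks:
#include <algorithm>
#include <cstdint>

template <typename T>
void scalar_qselect(T *arr, int64_t pos, int64_t left, int64_t right)
{
    while (left < right) {
        T pivot = arr[left + (right - left) / 2];
        T *mid = std::partition(arr + left, arr + right + 1,
                                [pivot](const T &v) { return v < pivot; });
        int64_t pivot_index = mid - arr;
        if (pivot_index == left || pivot_index == right + 1) {
            std::nth_element(arr + left, arr + pos, arr + right + 1);
            return;
        }
        if (pos < pivot_index)
            right = pivot_index - 1; // answer is among the < pivot elements
        else
            left = pivot_index;      // answer is among the >= pivot elements
    }
}
// ---------------------------------------------------------------------------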
_mm512_reduce_max_ph(v); } - static type_t reducemin(zmm_t v) + static type_t reducemin(reg_t v) { return _mm512_reduce_min_ph(v); } - static zmm_t set1(type_t v) + static reg_t set1(type_t v) { return _mm512_set1_ph(v); } template - static zmm_t shuffle(zmm_t zmm) + static reg_t shuffle(reg_t zmm) { __m512i temp = _mm512_shufflehi_epi16(_mm512_castph_si512(zmm), (_MM_PERM_ENUM)mask); return _mm512_castsi512_ph( _mm512_shufflelo_epi16(temp, (_MM_PERM_ENUM)mask)); } - static void storeu(void *mem, zmm_t x) + static void storeu(void *mem, reg_t x) { return _mm512_storeu_ph(mem, x); } + static reg_t reverse(reg_t zmm) + { + const auto rev_index = get_network(4); + return permutexvar(rev_index, zmm); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_zmm_16bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_zmm_16bit>(x); + } }; template <> @@ -140,22 +156,6 @@ void replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count) memset(arr + arrsize - nan_count, 0xFF, nan_count * 2); } -template <> -void qselect_>( - _Float16 *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters) -{ - qselect_16bit_>(arr, k, left, right, maxiters); -} - -template <> -void qsort_>(_Float16 *arr, - int64_t left, - int64_t right, - int64_t maxiters) -{ - qsort_16bit_>(arr, left, right, maxiters); -} - /* Specialized template function for _Float16 qsort_*/ template <> void avx512_qsort(_Float16 *arr, int64_t arrsize) @@ -164,7 +164,7 @@ void avx512_qsort(_Float16 *arr, int64_t arrsize) int64_t nan_count = replace_nan_with_inf, _Float16>(arr, arrsize); - qsort_16bit_, _Float16>( + qsort_, _Float16>( arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); replace_inf_with_nan(arr, arrsize, nan_count); } diff --git a/src/xss-network-qsort.hpp b/src/xss-network-qsort.hpp new file mode 100644 index 00000000..701ee774 --- /dev/null +++ b/src/xss-network-qsort.hpp @@ -0,0 +1,137 @@ +#ifndef XSS_NETWORK_QSORT +#define XSS_NETWORK_QSORT + +#include "avx512-common-qsort.h" + +template +X86_SIMD_SORT_INLINE void bitonic_clean_n_vec(reg_t *regs) +{ +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int num = numVecs / 2; num >= 2; num /= 2) { +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int j = 0; j < numVecs; j += num) { +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < num / 2; i++) { + COEX(regs[i + j], regs[i + j + num / 2]); + } + } + } +} + +template +X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(reg_t *regs) +{ + // Do the reverse part + if constexpr (numVecs == 2) { + regs[1] = vtype::reverse(regs[1]); + COEX(regs[0], regs[1]); + } + else if constexpr (numVecs > 2) { +// Reverse upper half +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + reg_t rev = vtype::reverse(regs[numVecs - i - 1]); + reg_t maxV = vtype::max(regs[i], rev); + reg_t minV = vtype::min(regs[i], rev); + regs[numVecs - i - 1] = vtype::reverse(maxV); + regs[i] = minV; + } + } + + // Call cleaner + bitonic_clean_n_vec(regs); + +// Now do bitonic_merge +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + regs[i] = vtype::bitonic_merge(regs[i]); + } +} + +template +X86_SIMD_SORT_INLINE void bitonic_fullmerge_n_vec(reg_t *regs) +{ + if constexpr (numPer > numVecs) + return; + else { +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / numPer; i++) { + bitonic_merge_n_vec(regs + i * numPer); + } + bitonic_fullmerge_n_vec(regs); + } +} + +template +X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int32_t N) +{ + if (numVecs > 1 && N * 2 <= numVecs * vtype::numlanes) 
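// --- Illustrative aside (not part of this patch) ---------------------------
// bitonic_merge_n_vec above generalizes the classic two-list bitonic merge:
// reverse the upper half so the whole range becomes bitonic, then run log2(n)
// "half cleaner" passes of compare-exchange at shrinking strides.  The same
// structure on a plain scalar array (our function name):
#include <algorithm>
#include <cstddef>

template <typename T>
void scalar_bitonic_merge(T *a, size_t n) // n must be a power of two;
{                                         // a[0..n/2) and a[n/2..n) are sorted
    std::reverse(a + n / 2, a + n);       // make the whole sequence bitonic
    for (size_t stride = n / 2; stride >= 1; stride /= 2)      // half cleaners
        for (size_t block = 0; block < n; block += 2 * stride)
            for (size_t i = block; i < block + stride; i++)
                if (a[i] > a[i + stride])
                    std::swap(a[i], a[i + stride]); // scalar COEX
}
// ---------------------------------------------------------------------------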
{ + sort_n_vec(arr, N); + return; + } + + reg_t vecs[numVecs]; + + // Generate masks for loading and storing + typename vtype::opmask_t ioMasks[numVecs - numVecs / 2]; +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + int64_t num_to_read + = std::min((int64_t)std::max(0, N - i * vtype::numlanes), + (int64_t)vtype::numlanes); + ioMasks[j] = ((0x1ull << num_to_read) - 0x1ull); + } + +// Unmasked part of the load +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + vecs[i] = vtype::loadu(arr + i * vtype::numlanes); + } +// Masked part of the load +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + vecs[i] = vtype::mask_loadu( + vtype::zmm_max(), ioMasks[j], arr + i * vtype::numlanes); + } + +// Sort each loaded vector +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + vecs[i] = vtype::sort_vec(vecs[i]); + } + + // Run the full merger + bitonic_fullmerge_n_vec(&vecs[0]); + +// Unmasked part of the store +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + vtype::storeu(arr + i * vtype::numlanes, vecs[i]); + } +// Masked part of the store +X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + vtype::mask_storeu(arr + i * vtype::numlanes, ioMasks[j], vecs[i]); + } +} + +template +X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N) +{ + constexpr int numVecs = maxN / vtype::numlanes; + constexpr bool isMultiple = (maxN == (vtype::numlanes * numVecs)); + constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); + static_assert(powerOfTwo == true && isMultiple == true, + "maxN must be vtype::numlanes times a power of 2"); + + sort_n_vec(arr, N); +} + +#endif \ No newline at end of file
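// --- Illustrative aside (not part of this patch) ---------------------------
// The static_assert in sort_n only checks that maxN is vtype::numlanes times a
// power of two, which is what the recursive merge above requires.  An
// equivalent standalone check (helper name and the sample sizes below are
// ours, shown purely as examples of the constraint):
#include <cstdint>

constexpr bool valid_network_size(int maxN, int numlanes)
{
    int numVecs = maxN / numlanes;
    bool isMultiple = (maxN == numlanes * numVecs);
    bool powerOfTwo = (numVecs != 0) && ((numVecs & (numVecs - 1)) == 0);
    return isMultiple && powerOfTwo;
}

static_assert(valid_network_size(256, 8), "32 vectors of 8 x 64-bit lanes");
static_assert(valid_network_size(128, 32), "4 vectors of 32 x 16-bit lanes");
static_assert(!valid_network_size(96, 8), "12 vectors is not a power of two");
// ---------------------------------------------------------------------------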