From f9e3db7f969b13f53c920ebae345d0b3badbc61d Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 13 Apr 2023 10:10:10 -0700 Subject: [PATCH 1/4] Move classes to a separate header file --- src/avx512-16bit-common.h | 15 - src/avx512-16bit-qsort.hpp | 340 ---------- src/avx512-32bit-qsort.hpp | 307 --------- src/avx512-64bit-common.h | 318 +-------- src/avx512-common-qsort.h | 54 +- src/avx512-zmm-classes.h | 1147 ++++++++++++++++++++++++++++++++ src/avx512fp16-16bit-qsort.hpp | 105 --- 7 files changed, 1150 insertions(+), 1136 deletions(-) create mode 100644 src/avx512-zmm-classes.h diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index 0c819946..cace5449 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -14,21 +14,6 @@ * sorting network (see * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) */ -// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 -static const uint16_t network[6][32] - = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, - 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}, - {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, - 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16}, - {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11, - 20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27}, - {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, - {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, - 24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23}, - {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}; - /* * Assumes zmm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 606f8706..cd9a3903 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -9,346 +9,6 @@ #include "avx512-16bit-common.h" -struct float16 { - uint16_t val; -}; - -template <> -struct zmm_vector { - using type_t = uint16_t; - using zmm_t = __m512i; - using ymm_t = __m256i; - using opmask_t = __mmask32; - static const uint8_t numlanes = 32; - - static zmm_t get_network(int index) - { - return _mm512_loadu_si512(&network[index - 1][0]); - } - static type_t type_max() - { - return X86_SIMD_SORT_INFINITYH; - } - static type_t type_min() - { - return X86_SIMD_SORT_NEGINFINITYH; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi16(type_max()); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask32(x); - } - - static opmask_t ge(zmm_t x, zmm_t y) - { - zmm_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000)); - zmm_t sign_y = _mm512_and_si512(y, _mm512_set1_epi16(0x8000)); - zmm_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00)); - zmm_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00)); - zmm_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff)); - zmm_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff)); - - __mmask32 mask_ge = _mm512_cmp_epu16_mask( - sign_x, sign_y, _MM_CMPINT_LT); // only greater than - __mmask32 sign_eq = _mm512_cmpeq_epu16_mask(sign_x, sign_y); - __mmask32 neg = _mm512_mask_cmpeq_epu16_mask( - sign_eq, - sign_x, - _mm512_set1_epi16(0x8000)); // both numbers are -ve - - // compare exponents only if signs are equal: - mask_ge = mask_ge - | 
_mm512_mask_cmp_epu16_mask( - sign_eq, exp_x, exp_y, _MM_CMPINT_NLE); - // get mask for elements for which both sign and exponents are equal: - __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask(sign_eq, exp_x, exp_y); - - // compare mantissa for elements for which both sign and expponent are equal: - mask_ge = mask_ge - | _mm512_mask_cmp_epu16_mask( - exp_eq, mant_x, mant_y, _MM_CMPINT_NLT); - return _kxor_mask32(mask_ge, neg); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_mask_mov_epi16(y, ge(x, y), x); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - // AVX512_VBMI2 - return _mm512_mask_compressstoreu_epi16(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - // AVX512BW - return _mm512_mask_loadu_epi16(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi16(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi16(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_mask_mov_epi16(x, ge(x, y), y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi16(idx, zmm); - } - // Apparently this is a terrible for perf, npy_half_to_float seems to work - // better - //static float uint16_to_float(uint16_t val) - //{ - // // Ideally use _mm_loadu_si16, but its only gcc > 11.x - // // TODO: use inline ASM? https://godbolt.org/z/aGYvh7fMM - // __m128i xmm = _mm_maskz_loadu_epi16(0x01, &val); - // __m128 xmm2 = _mm_cvtph_ps(xmm); - // return _mm_cvtss_f32(xmm2); - //} - static type_t float_to_uint16(float val) - { - __m128 xmm = _mm_load_ss(&val); - __m128i xmm2 = _mm_cvtps_ph(xmm, _MM_FROUND_NO_EXC); - return _mm_extract_epi16(xmm2, 0); - } - static type_t reducemax(zmm_t v) - { - __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0)); - __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1)); - float lo_max = _mm512_reduce_max_ps(lo); - float hi_max = _mm512_reduce_max_ps(hi); - return float_to_uint16(std::max(lo_max, hi_max)); - } - static type_t reducemin(zmm_t v) - { - __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0)); - __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1)); - float lo_max = _mm512_reduce_min_ps(lo); - float hi_max = _mm512_reduce_min_ps(hi); - return float_to_uint16(std::min(lo_max, hi_max)); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi16(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); - return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; - -template <> -struct zmm_vector { - using type_t = int16_t; - using zmm_t = __m512i; - using ymm_t = __m256i; - using opmask_t = __mmask32; - static const uint8_t numlanes = 32; - - static zmm_t get_network(int index) - { - return _mm512_loadu_si512(&network[index - 1][0]); - } - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT16; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT16; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi16(type_max()); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask32(x); - } - - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT); - } - static zmm_t loadu(void const *mem) - { - 
return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epi16(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - // AVX512_VBMI2 - return _mm512_mask_compressstoreu_epi16(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - // AVX512BW - return _mm512_mask_loadu_epi16(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi16(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi16(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epi16(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi16(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); - type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo); - type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi); - return std::max(lo_max, hi_max); - } - static type_t reducemin(zmm_t v) - { - zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); - type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo); - type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi); - return std::min(lo_min, hi_min); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi16(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); - return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; -template <> -struct zmm_vector { - using type_t = uint16_t; - using zmm_t = __m512i; - using ymm_t = __m256i; - using opmask_t = __mmask32; - static const uint8_t numlanes = 32; - - static zmm_t get_network(int index) - { - return _mm512_loadu_si512(&network[index - 1][0]); - } - static type_t type_max() - { - return X86_SIMD_SORT_MAX_UINT16; - } - static type_t type_min() - { - return 0; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi16(type_max()); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask32(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epu16(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi16(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi16(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi16(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi16(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epu16(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi16(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); - type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo); - type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi); 
- return std::max(lo_max, hi_max); - } - static type_t reducemin(zmm_t v) - { - zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); - type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo); - type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi); - return std::min(lo_min, hi_min); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi16(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); - return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; - template <> bool comparison_func>(const uint16_t &a, const uint16_t &b) { diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index c4061ddf..a713df63 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -23,313 +23,6 @@ #define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 #define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 -template <> -struct zmm_vector { - using type_t = int32_t; - using zmm_t = __m512i; - using ymm_t = __m256i; - using opmask_t = __mmask16; - static const uint8_t numlanes = 16; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT32; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT32; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi32(type_max()); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _mm512_knot(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); - } - template - static ymm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi32(index, base, scale); - } - static zmm_t merge(ymm_t y1, ymm_t y2) - { - zmm_t z1 = _mm512_castsi256_si512(y1); - return _mm512_inserti32x8(z1, y2, 1); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi32(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi32(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi32(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi32(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epi32(x, y); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epi32(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi32(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epi32(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epi32(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi32(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } - - static ymm_t max(ymm_t x, ymm_t y) - { - return _mm256_max_epi32(x, y); - } - static ymm_t min(ymm_t x, ymm_t y) - { - return _mm256_min_epi32(x, y); - } -}; -template <> -struct zmm_vector { - using type_t = uint32_t; - using zmm_t = __m512i; - using ymm_t = __m256i; - using opmask_t = __mmask16; - static const uint8_t numlanes = 16; - - static type_t 
type_max() - { - return X86_SIMD_SORT_MAX_UINT32; - } - static type_t type_min() - { - return 0; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi32(type_max()); - } // TODO: this should broadcast bits as is? - - template - static ymm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi32(index, base, scale); - } - static zmm_t merge(ymm_t y1, ymm_t y2) - { - zmm_t z1 = _mm512_castsi256_si512(y1); - return _mm512_inserti32x8(z1, y2, 1); - } - static opmask_t knot_opmask(opmask_t x) - { - return _mm512_knot(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epu32(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi32(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi32(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi32(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi32(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epu32(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi32(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epu32(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epu32(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi32(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } - - static ymm_t max(ymm_t x, ymm_t y) - { - return _mm256_max_epu32(x, y); - } - static ymm_t min(ymm_t x, ymm_t y) - { - return _mm256_min_epu32(x, y); - } -}; -template <> -struct zmm_vector { - using type_t = float; - using zmm_t = __m512; - using ymm_t = __m256; - using opmask_t = __mmask16; - static const uint8_t numlanes = 16; - - static type_t type_max() - { - return X86_SIMD_SORT_INFINITYF; - } - static type_t type_min() - { - return -X86_SIMD_SORT_INFINITYF; - } - static zmm_t zmm_max() - { - return _mm512_set1_ps(type_max()); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _mm512_knot(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); - } - template - static ymm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_ps(index, base, scale); - } - static zmm_t merge(ymm_t y1, ymm_t y2) - { - zmm_t z1 = _mm512_castsi512_ps( - _mm512_castsi256_si512(_mm256_castps_si256(y1))); - return _mm512_insertf32x8(z1, y2, 1); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_ps(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_ps(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_ps(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_ps(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_ps(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_ps(mem, mask, x); - } - 
static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_ps(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_ps(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_ps(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_ps(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_ps(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_ps(zmm, zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_ps(mem, x); - } - - static ymm_t max(ymm_t x, ymm_t y) - { - return _mm256_max_ps(x, y); - } - static ymm_t min(ymm_t x, ymm_t y) - { - return _mm256_min_ps(x, y); - } -}; /* * Assumes zmm is random and performs a full sorting network defined in diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 7fc8acf3..87d39f15 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -19,322 +19,6 @@ #define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 #define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 -template <> -struct zmm_vector { - using type_t = int64_t; - using zmm_t = __m512i; - using ymm_t = __m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT64; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT64; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi64(type_max()); - } // TODO: this should broadcast bits as is? - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); - } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); - } - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi64(index, base, scale); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epi64(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epi64(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epi64(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epi64(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi64(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; -template <> -struct zmm_vector { - using type_t = uint64_t; - using zmm_t = __m512i; - using ymm_t = 
__m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_UINT64; - } - static type_t type_min() - { - return 0; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi64(type_max()); - } - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi64(index, base, scale); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); - } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epu64(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epu64(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epu64(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epu64(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi64(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; -template <> -struct zmm_vector { - using type_t = double; - using zmm_t = __m512d; - using ymm_t = __m512d; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_INFINITY; - } - static type_t type_min() - { - return -X86_SIMD_SORT_INFINITY; - } - static zmm_t zmm_max() - { - return _mm512_set1_pd(type_max()); - } - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); - } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); - } - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_pd(index, base, scale); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_pd(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_pd(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_pd(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_pd(x, mask, mem); - 
} - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_pd(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_pd(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_pd(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_pd(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_pd(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_pd(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_pd(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_pd(mem, x); - } -}; X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize) { int64_t nan_count = 0; @@ -407,4 +91,4 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, return ((type_t *)&sort)[4]; } -#endif \ No newline at end of file +#endif diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 0e0ad818..00a38732 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -33,58 +33,7 @@ * */ -#include -#include -#include -#include -#include -#include - -#define X86_SIMD_SORT_INFINITY std::numeric_limits::infinity() -#define X86_SIMD_SORT_INFINITYF std::numeric_limits::infinity() -#define X86_SIMD_SORT_INFINITYH 0x7c00 -#define X86_SIMD_SORT_NEGINFINITYH 0xfc00 -#define X86_SIMD_SORT_MAX_UINT16 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT16 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT16 std::numeric_limits::min() -#define X86_SIMD_SORT_MAX_UINT32 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT32 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT32 std::numeric_limits::min() -#define X86_SIMD_SORT_MAX_UINT64 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT64 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT64 std::numeric_limits::min() -#define ZMM_MAX_DOUBLE _mm512_set1_pd(X86_SIMD_SORT_INFINITY) -#define ZMM_MAX_UINT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_UINT64) -#define ZMM_MAX_INT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_INT64) -#define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF) -#define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32) -#define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32) -#define ZMM_MAX_HALF _mm512_set1_epi16(X86_SIMD_SORT_INFINITYH) -#define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH) -#define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16) -#define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16) -#define SHUFFLE_MASK(a, b, c, d) (a << 6) | (b << 4) | (c << 2) | d - -#ifdef _MSC_VER -#define X86_SIMD_SORT_INLINE static inline -#define X86_SIMD_SORT_FINLINE static __forceinline -#elif defined(__CYGWIN__) -/* - * Force inline in cygwin to work around a compiler bug. 
See - * https://github.com/numpy/numpy/pull/22315#issuecomment-1267757584 - */ -#define X86_SIMD_SORT_INLINE static __attribute__((always_inline)) -#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) -#elif defined(__GNUC__) -#define X86_SIMD_SORT_INLINE static inline -#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) -#else -#define X86_SIMD_SORT_INLINE static -#define X86_SIMD_SORT_FINLINE static -#endif - -template -struct zmm_vector; +#include "avx512-zmm-classes.h" template void avx512_qsort(T *arr, int64_t arrsize); @@ -122,6 +71,7 @@ static void COEX(mm_t &a, mm_t &b) a = vtype::min(a, b); b = vtype::max(temp, b); } + template diff --git a/src/avx512-zmm-classes.h b/src/avx512-zmm-classes.h new file mode 100644 index 00000000..45f6cb25 --- /dev/null +++ b/src/avx512-zmm-classes.h @@ -0,0 +1,1147 @@ +#ifndef AVX512_ZMM_CLASSES +#define AVX512_ZMM_CLASSES + +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#define X86_SIMD_SORT_INLINE static inline +#define X86_SIMD_SORT_FINLINE static __forceinline +#elif defined(__CYGWIN__) +/* + * Force inline in cygwin to work around a compiler bug. See + * https://github.com/numpy/numpy/pull/22315#issuecomment-1267757584 + */ +#define X86_SIMD_SORT_INLINE static __attribute__((always_inline)) +#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) +#elif defined(__GNUC__) +#define X86_SIMD_SORT_INLINE static inline +#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) +#else +#define X86_SIMD_SORT_INLINE static +#define X86_SIMD_SORT_FINLINE static +#endif + +#define X86_SIMD_SORT_INFINITY std::numeric_limits::infinity() +#define X86_SIMD_SORT_INFINITYF std::numeric_limits::infinity() +#define X86_SIMD_SORT_INFINITYH 0x7c00 +#define X86_SIMD_SORT_NEGINFINITYH 0xfc00 +#define X86_SIMD_SORT_MAX_UINT16 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT16 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT16 std::numeric_limits::min() +#define X86_SIMD_SORT_MAX_UINT32 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT32 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT32 std::numeric_limits::min() +#define X86_SIMD_SORT_MAX_UINT64 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT64 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT64 std::numeric_limits::min() +#define ZMM_MAX_DOUBLE _mm512_set1_pd(X86_SIMD_SORT_INFINITY) +#define ZMM_MAX_UINT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_UINT64) +#define ZMM_MAX_INT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_INT64) +#define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF) +#define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32) +#define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32) +#define ZMM_MAX_HALF _mm512_set1_epi16(X86_SIMD_SORT_INFINITYH) +#define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH) +#define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16) +#define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16) +#define SHUFFLE_MASK(a, b, c, d) (a << 6) | (b << 4) | (c << 2) | d + +// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 +static const uint16_t network[6][32] + = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}, + {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16}, + {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11, + 20, 
21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27}, + {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23}, + {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}; + +template +struct zmm_vector; + +typedef union { + _Float16 f_; + uint16_t i_; +} Fp16Bits; + +template <> +struct zmm_vector<_Float16> { + using type_t = _Float16; + using zmm_t = __m512h; + using ymm_t = __m256h; + using opmask_t = __mmask32; + static const uint8_t numlanes = 32; + + static __m512i get_network(int index) + { + return _mm512_loadu_si512(&network[index - 1][0]); + } + static type_t type_max() + { + Fp16Bits val; + val.i_ = X86_SIMD_SORT_INFINITYH; + return val.f_; + } + static type_t type_min() + { + Fp16Bits val; + val.i_ = X86_SIMD_SORT_NEGINFINITYH; + return val.f_; + } + static zmm_t zmm_max() + { + return _mm512_set1_ph(type_max()); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask32(x); + } + + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_ph(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_ph(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + __m512i temp = _mm512_castph_si512(x); + // AVX512_VBMI2 + return _mm512_mask_compressstoreu_epi16(mem, mask, temp); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + // AVX512BW + return _mm512_castsi512_ph( + _mm512_mask_loadu_epi16(_mm512_castph_si512(x), mask, mem)); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_castsi512_ph(_mm512_mask_mov_epi16( + _mm512_castph_si512(x), mask, _mm512_castph_si512(y))); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi16(mem, mask, _mm512_castph_si512(x)); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_ph(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_ph(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_ph(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_ph(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_ph(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512i temp = _mm512_shufflehi_epi16(_mm512_castph_si512(zmm), + (_MM_PERM_ENUM)mask); + return _mm512_castsi512_ph( + _mm512_shufflelo_epi16(temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_ph(mem, x); + } +}; + +struct float16 { + uint16_t val; +}; + +template <> +struct zmm_vector { + using type_t = uint16_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask32; + static const uint8_t numlanes = 32; + + static zmm_t get_network(int index) + { + return _mm512_loadu_si512(&network[index - 1][0]); + } + static type_t type_max() + { + return X86_SIMD_SORT_INFINITYH; + } + static type_t type_min() + { + return X86_SIMD_SORT_NEGINFINITYH; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi16(type_max()); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask32(x); + } + + static opmask_t ge(zmm_t x, zmm_t y) + { + zmm_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000)); + 
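+ // The bit masks used here follow the IEEE 754 binary16 layout: bit 15 is the
+ // sign, bits 14:10 are the exponent and bits 9:0 are the mantissa, so 0x8000,
+ // 0x7c00 and 0x3ff isolate the sign, exponent and mantissa fields of each
+ // half-precision value stored as a uint16_t.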
zmm_t sign_y = _mm512_and_si512(y, _mm512_set1_epi16(0x8000)); + zmm_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00)); + zmm_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00)); + zmm_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff)); + zmm_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff)); + + __mmask32 mask_ge = _mm512_cmp_epu16_mask( + sign_x, sign_y, _MM_CMPINT_LT); // only greater than + __mmask32 sign_eq = _mm512_cmpeq_epu16_mask(sign_x, sign_y); + __mmask32 neg = _mm512_mask_cmpeq_epu16_mask( + sign_eq, + sign_x, + _mm512_set1_epi16(0x8000)); // both numbers are -ve + + // compare exponents only if signs are equal: + mask_ge = mask_ge + | _mm512_mask_cmp_epu16_mask( + sign_eq, exp_x, exp_y, _MM_CMPINT_NLE); + // get mask for elements for which both sign and exponents are equal: + __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask(sign_eq, exp_x, exp_y); + + // compare mantissa for elements for which both sign and expponent are equal: + mask_ge = mask_ge + | _mm512_mask_cmp_epu16_mask( + exp_eq, mant_x, mant_y, _MM_CMPINT_NLT); + return _kxor_mask32(mask_ge, neg); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_mask_mov_epi16(y, ge(x, y), x); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + // AVX512_VBMI2 + return _mm512_mask_compressstoreu_epi16(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + // AVX512BW + return _mm512_mask_loadu_epi16(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi16(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi16(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_mask_mov_epi16(x, ge(x, y), y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi16(idx, zmm); + } + // Apparently this is a terrible for perf, npy_half_to_float seems to work + // better + //static float uint16_to_float(uint16_t val) + //{ + // // Ideally use _mm_loadu_si16, but its only gcc > 11.x + // // TODO: use inline ASM? 
https://godbolt.org/z/aGYvh7fMM + // __m128i xmm = _mm_maskz_loadu_epi16(0x01, &val); + // __m128 xmm2 = _mm_cvtph_ps(xmm); + // return _mm_cvtss_f32(xmm2); + //} + static type_t float_to_uint16(float val) + { + __m128 xmm = _mm_load_ss(&val); + __m128i xmm2 = _mm_cvtps_ph(xmm, _MM_FROUND_NO_EXC); + return _mm_extract_epi16(xmm2, 0); + } + static type_t reducemax(zmm_t v) + { + __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0)); + __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1)); + float lo_max = _mm512_reduce_max_ps(lo); + float hi_max = _mm512_reduce_max_ps(hi); + return float_to_uint16(std::max(lo_max, hi_max)); + } + static type_t reducemin(zmm_t v) + { + __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0)); + __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1)); + float lo_max = _mm512_reduce_min_ps(lo); + float hi_max = _mm512_reduce_min_ps(hi); + return float_to_uint16(std::min(lo_max, hi_max)); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi16(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); + return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; + +template <> +struct zmm_vector { + using type_t = int16_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask32; + static const uint8_t numlanes = 32; + + static zmm_t get_network(int index) + { + return _mm512_loadu_si512(&network[index - 1][0]); + } + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT16; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT16; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi16(type_max()); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask32(x); + } + + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epi16(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + // AVX512_VBMI2 + return _mm512_mask_compressstoreu_epi16(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + // AVX512BW + return _mm512_mask_loadu_epi16(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi16(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi16(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epi16(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi16(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); + zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); + type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo); + type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi); + return std::max(lo_max, hi_max); + } + static type_t reducemin(zmm_t v) + { + zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); + zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); + type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo); + type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi); + return std::min(lo_min, hi_min); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi16(v); + } + template + 
static zmm_t shuffle(zmm_t zmm) + { + zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); + return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; + +template <> +struct zmm_vector { + using type_t = uint16_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask32; + static const uint8_t numlanes = 32; + + static zmm_t get_network(int index) + { + return _mm512_loadu_si512(&network[index - 1][0]); + } + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT16; + } + static type_t type_min() + { + return 0; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi16(type_max()); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask32(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epu16(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi16(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi16(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi16(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi16(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epu16(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi16(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); + zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); + type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo); + type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi); + return std::max(lo_max, hi_max); + } + static type_t reducemin(zmm_t v) + { + zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); + zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); + type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo); + type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi); + return std::min(lo_min, hi_min); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi16(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); + return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; + +template <> +struct zmm_vector { + using type_t = int32_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask16; + static const uint8_t numlanes = 16; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT32; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT32; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi32(type_max()); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _mm512_knot(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); + } + template + static ymm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi32(index, base, scale); + } + static zmm_t merge(ymm_t y1, ymm_t y2) + { + zmm_t z1 = _mm512_castsi256_si512(y1); + return _mm512_inserti32x8(z1, y2, 1); + } + static zmm_t 
loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi32(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi32(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi32(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi32(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epi32(x, y); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epi32(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi32(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epi32(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epi32(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi32(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } + + static ymm_t max(ymm_t x, ymm_t y) + { + return _mm256_max_epi32(x, y); + } + static ymm_t min(ymm_t x, ymm_t y) + { + return _mm256_min_epi32(x, y); + } +}; + +template <> +struct zmm_vector { + using type_t = uint32_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask16; + static const uint8_t numlanes = 16; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT32; + } + static type_t type_min() + { + return 0; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi32(type_max()); + } // TODO: this should broadcast bits as is? 
+ + template + static ymm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi32(index, base, scale); + } + static zmm_t merge(ymm_t y1, ymm_t y2) + { + zmm_t z1 = _mm512_castsi256_si512(y1); + return _mm512_inserti32x8(z1, y2, 1); + } + static opmask_t knot_opmask(opmask_t x) + { + return _mm512_knot(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epu32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi32(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi32(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi32(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi32(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epu32(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi32(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epu32(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epu32(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi32(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } + + static ymm_t max(ymm_t x, ymm_t y) + { + return _mm256_max_epu32(x, y); + } + static ymm_t min(ymm_t x, ymm_t y) + { + return _mm256_min_epu32(x, y); + } +}; + +template <> +struct zmm_vector { + using type_t = float; + using zmm_t = __m512; + using ymm_t = __m256; + using opmask_t = __mmask16; + static const uint8_t numlanes = 16; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITYF; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITYF; + } + static zmm_t zmm_max() + { + return _mm512_set1_ps(type_max()); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _mm512_knot(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + } + template + static ymm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_ps(index, base, scale); + } + static zmm_t merge(ymm_t y1, ymm_t y2) + { + zmm_t z1 = _mm512_castsi512_ps( + _mm512_castsi256_si512(_mm256_castps_si256(y1))); + return _mm512_insertf32x8(z1, y2, 1); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_ps(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_ps(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_ps(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_ps(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_ps(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_ps(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_ps(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_ps(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + 
return _mm512_reduce_max_ps(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_ps(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_ps(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_ps(zmm, zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_ps(mem, x); + } + + static ymm_t max(ymm_t x, ymm_t y) + { + return _mm256_max_ps(x, y); + } + static ymm_t min(ymm_t x, ymm_t y) + { + return _mm256_min_ps(x, y); + } +}; + +template <> +struct zmm_vector { + using type_t = int64_t; + using zmm_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT64; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT64; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi64(type_max()); + } // TODO: this should broadcast bits as is? + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi64(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epi64(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epi64(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epi64(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epi64(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi64(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; + +template <> +struct zmm_vector { + using type_t = uint64_t; + using zmm_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT64; + } + static type_t type_min() + { + return 0; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi64(type_max()); + } + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + template + static zmm_t i64gather(__m512i index, void const *base) + { + return 
_mm512_i64gather_epi64(index, base, scale); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epu64(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epu64(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epu64(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epu64(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi64(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; + +template <> +struct zmm_vector { + using type_t = double; + using zmm_t = __m512d; + using ymm_t = __m512d; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITY; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITY; + } + static zmm_t zmm_max() + { + return _mm512_set1_pd(type_max()); + } + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_pd(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_pd(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_pd(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_pd(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_pd(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_pd(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_pd(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_pd(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_pd(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_pd(v); + } + static type_t reducemin(zmm_t v) + { + return 
_mm512_reduce_min_pd(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_pd(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_pd(mem, x); + } +}; + +#endif //AVX512_ZMM_CLASSES diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 8a9a49ed..2bdf9803 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -9,111 +9,6 @@ #include "avx512-16bit-common.h" -typedef union { - _Float16 f_; - uint16_t i_; -} Fp16Bits; - -template <> -struct zmm_vector<_Float16> { - using type_t = _Float16; - using zmm_t = __m512h; - using ymm_t = __m256h; - using opmask_t = __mmask32; - static const uint8_t numlanes = 32; - - static __m512i get_network(int index) - { - return _mm512_loadu_si512(&network[index - 1][0]); - } - static type_t type_max() - { - Fp16Bits val; - val.i_ = X86_SIMD_SORT_INFINITYH; - return val.f_; - } - static type_t type_min() - { - Fp16Bits val; - val.i_ = X86_SIMD_SORT_NEGINFINITYH; - return val.f_; - } - static zmm_t zmm_max() - { - return _mm512_set1_ph(type_max()); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask32(x); - } - - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_ph(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_ph(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - __m512i temp = _mm512_castph_si512(x); - // AVX512_VBMI2 - return _mm512_mask_compressstoreu_epi16(mem, mask, temp); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - // AVX512BW - return _mm512_castsi512_ph( - _mm512_mask_loadu_epi16(_mm512_castph_si512(x), mask, mem)); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_castsi512_ph(_mm512_mask_mov_epi16( - _mm512_castph_si512(x), mask, _mm512_castph_si512(y))); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi16(mem, mask, _mm512_castph_si512(x)); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_ph(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_ph(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_ph(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_ph(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_ph(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512i temp = _mm512_shufflehi_epi16(_mm512_castph_si512(zmm), - (_MM_PERM_ENUM)mask); - return _mm512_castsi512_ph( - _mm512_shufflelo_epi16(temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_ph(mem, x); - } -}; - X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr, int64_t arrsize) { From 4d98f32cac8687c77db367ce23bb5b7b3ca27d9b Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Fri, 14 Apr 2023 13:21:57 -0700 Subject: [PATCH 2/4] Re-write key-value sort using templates --- src/avx512-64bit-keyvaluesort.hpp | 1132 +++++++++++++++-------------- src/avx512-common-keyvaluesort.h | 240 ------ src/avx512-common-qsort.h | 216 ++++++ 3 files changed, 797 insertions(+), 791 deletions(-) delete mode 100644 src/avx512-common-keyvaluesort.h diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp 
index 8ed66e14..7d70a2b4 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -8,95 +8,90 @@ #ifndef AVX512_QSORT_64BIT_KV #define AVX512_QSORT_64BIT_KV -#include "avx512-common-keyvaluesort.h" +#include "avx512-64bit-common.h" -template ::zmm_t> +template X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::template shuffle(key_zmm), + vtype1::template shuffle(key_zmm), index_zmm, - zmm_vector::template shuffle( - index_zmm), + vtype2::template shuffle(index_zmm), 0xAA); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), - index_zmm), + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), index_zmm), 0xCC); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::template shuffle(key_zmm), + vtype1::template shuffle(key_zmm), index_zmm, - zmm_vector::template shuffle( - index_zmm), + vtype2::template shuffle(index_zmm), 0xAA); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::permutexvar(rev_index, key_zmm), + vtype1::permutexvar(rev_index, key_zmm), index_zmm, - zmm_vector::permutexvar(rev_index, index_zmm), + vtype2::permutexvar(rev_index, index_zmm), 0xF0); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), - index_zmm), + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), 0xCC); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::template shuffle(key_zmm), + vtype1::template shuffle(key_zmm), index_zmm, - zmm_vector::template shuffle( - index_zmm), + vtype2::template shuffle(index_zmm), 0xAA); return key_zmm; } // Assumes zmm is bitonic and performs a recursive half cleaner -template ::zmm_t> +template X86_SIMD_SORT_INLINE zmm_t -bitonic_merge_zmm_64bit(zmm_t key_zmm, zmm_vector::zmm_t &index_zmm) +bitonic_merge_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) { // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), - index_zmm), + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), index_zmm), 0xF0); // 2) half_cleaner[4] - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), - index_zmm), + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), 0xCC); // 3) half_cleaner[1] - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::template shuffle(key_zmm), + vtype1::template shuffle(key_zmm), index_zmm, - zmm_vector::template shuffle( - index_zmm), + vtype2::template shuffle(index_zmm), 0xAA); return key_zmm; } // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template ::zmm_t> +template X86_SIMD_SORT_INLINE void 
bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, zmm_t &key_zmm2, index_type &index_zmm1, @@ -104,162 +99,165 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - key_zmm2 = vtype::permutexvar(rev_index, key_zmm2); - index_zmm2 = zmm_vector::permutexvar(rev_index, index_zmm2); + key_zmm2 = vtype1::permutexvar(rev_index, key_zmm2); + index_zmm2 = vtype2::permutexvar(rev_index, index_zmm2); - zmm_t key_zmm3 = vtype::min(key_zmm1, key_zmm2); - zmm_t key_zmm4 = vtype::max(key_zmm1, key_zmm2); + zmm_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); + zmm_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); - index_type index_zmm3 = zmm_vector::mask_mov( - index_zmm2, vtype::eq(key_zmm3, key_zmm1), index_zmm1); - index_type index_zmm4 = zmm_vector::mask_mov( - index_zmm1, vtype::eq(key_zmm3, key_zmm1), index_zmm2); + index_type index_zmm3 = vtype2::mask_mov( + index_zmm2, vtype1::eq(key_zmm3, key_zmm1), index_zmm1); + index_type index_zmm4 = vtype2::mask_mov( + index_zmm1, vtype1::eq(key_zmm3, key_zmm1), index_zmm2); // 2) Recursive half cleaner for each - key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); + key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); index_zmm1 = index_zmm3; index_zmm2 = index_zmm4; } // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive // half cleaner -template ::zmm_t> +template X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, index_type *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); // 1) First step of a merging network - zmm_t key_zmm2r = vtype::permutexvar(rev_index, key_zmm[2]); - zmm_t key_zmm3r = vtype::permutexvar(rev_index, key_zmm[3]); + zmm_t key_zmm2r = vtype1::permutexvar(rev_index, key_zmm[2]); + zmm_t key_zmm3r = vtype1::permutexvar(rev_index, key_zmm[3]); index_type index_zmm2r - = zmm_vector::permutexvar(rev_index, index_zmm[2]); + = vtype2::permutexvar(rev_index, index_zmm[2]); index_type index_zmm3r - = zmm_vector::permutexvar(rev_index, index_zmm[3]); - - zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm3r); - zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm2r); - zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm3r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); - - index_type index_zmm_t1 = zmm_vector::mask_mov( - index_zmm3r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = zmm_vector::mask_mov( - index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); - index_type index_zmm_t2 = zmm_vector::mask_mov( - index_zmm2r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = zmm_vector::mask_mov( - index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); + = vtype2::permutexvar(rev_index, index_zmm[3]); + + zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm3r); + zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm2r); + zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); + zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); + + index_type index_zmm_t1 = vtype2::mask_mov( + index_zmm3r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = vtype2::mask_mov( + index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); + index_type index_zmm_t2 = vtype2::mask_mov( + index_zmm2r, vtype1::eq(key_zmm_t2, 
key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = vtype2::mask_mov( + index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); // 2) Recursive half clearer: 16 - zmm_t key_zmm_t3 = vtype::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t4 = vtype::permutexvar(rev_index, key_zmm_m1); + zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t4 = vtype1::permutexvar(rev_index, key_zmm_m1); index_type index_zmm_t3 - = zmm_vector::permutexvar(rev_index, index_zmm_m2); + = vtype2::permutexvar(rev_index, index_zmm_m2); index_type index_zmm_t4 - = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - zmm_t key_zmm0 = vtype::min(key_zmm_t1, key_zmm_t2); - zmm_t key_zmm1 = vtype::max(key_zmm_t1, key_zmm_t2); - zmm_t key_zmm2 = vtype::min(key_zmm_t3, key_zmm_t4); - zmm_t key_zmm3 = vtype::max(key_zmm_t3, key_zmm_t4); - - index_type index_zmm0 = zmm_vector::mask_mov( - index_zmm_t2, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t1); - index_type index_zmm1 = zmm_vector::mask_mov( - index_zmm_t1, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t2); - index_type index_zmm2 = zmm_vector::mask_mov( - index_zmm_t4, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t3); - index_type index_zmm3 = zmm_vector::mask_mov( - index_zmm_t3, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t4); - - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + = vtype2::permutexvar(rev_index, index_zmm_m1); + + zmm_t key_zmm0 = vtype1::min(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm1 = vtype1::max(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm2 = vtype1::min(key_zmm_t3, key_zmm_t4); + zmm_t key_zmm3 = vtype1::max(key_zmm_t3, key_zmm_t4); + + index_type index_zmm0 = vtype2::mask_mov( + index_zmm_t2, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t1); + index_type index_zmm1 = vtype2::mask_mov( + index_zmm_t1, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t2); + index_type index_zmm2 = vtype2::mask_mov( + index_zmm_t4, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t3); + index_type index_zmm3 = vtype2::mask_mov( + index_zmm_t3, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t4); + + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); index_zmm[0] = index_zmm0; index_zmm[1] = index_zmm1; index_zmm[2] = index_zmm2; index_zmm[3] = index_zmm3; } -template ::zmm_t> + +template X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, index_type *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t key_zmm4r = vtype::permutexvar(rev_index, key_zmm[4]); - zmm_t key_zmm5r = vtype::permutexvar(rev_index, key_zmm[5]); - zmm_t key_zmm6r = vtype::permutexvar(rev_index, key_zmm[6]); - zmm_t key_zmm7r = vtype::permutexvar(rev_index, key_zmm[7]); + zmm_t key_zmm4r = vtype1::permutexvar(rev_index, key_zmm[4]); + zmm_t key_zmm5r = vtype1::permutexvar(rev_index, key_zmm[5]); + zmm_t key_zmm6r = vtype1::permutexvar(rev_index, key_zmm[6]); + zmm_t key_zmm7r = vtype1::permutexvar(rev_index, key_zmm[7]); index_type index_zmm4r - = zmm_vector::permutexvar(rev_index, index_zmm[4]); + = vtype2::permutexvar(rev_index, index_zmm[4]); index_type index_zmm5r - = zmm_vector::permutexvar(rev_index, index_zmm[5]); + = 
vtype2::permutexvar(rev_index, index_zmm[5]); index_type index_zmm6r - = zmm_vector::permutexvar(rev_index, index_zmm[6]); + = vtype2::permutexvar(rev_index, index_zmm[6]); index_type index_zmm7r - = zmm_vector::permutexvar(rev_index, index_zmm[7]); - - zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm7r); - zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm6r); - zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm5r); - zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm4r); - - zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm7r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm6r); - zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm5r); - zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); - - index_type index_zmm_t1 = zmm_vector::mask_mov( - index_zmm7r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = zmm_vector::mask_mov( - index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); - index_type index_zmm_t2 = zmm_vector::mask_mov( - index_zmm6r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = zmm_vector::mask_mov( - index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); - index_type index_zmm_t3 = zmm_vector::mask_mov( - index_zmm5r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); - index_type index_zmm_m3 = zmm_vector::mask_mov( - index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); - index_type index_zmm_t4 = zmm_vector::mask_mov( - index_zmm4r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); - index_type index_zmm_m4 = zmm_vector::mask_mov( - index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); - - zmm_t key_zmm_t5 = vtype::permutexvar(rev_index, key_zmm_m4); - zmm_t key_zmm_t6 = vtype::permutexvar(rev_index, key_zmm_m3); - zmm_t key_zmm_t7 = vtype::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t8 = vtype::permutexvar(rev_index, key_zmm_m1); + = vtype2::permutexvar(rev_index, index_zmm[7]); + + zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm7r); + zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm6r); + zmm_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm5r); + zmm_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm4r); + + zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); + zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); + zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); + zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); + + index_type index_zmm_t1 = vtype2::mask_mov( + index_zmm7r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = vtype2::mask_mov( + index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); + index_type index_zmm_t2 = vtype2::mask_mov( + index_zmm6r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = vtype2::mask_mov( + index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); + index_type index_zmm_t3 = vtype2::mask_mov( + index_zmm5r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_type index_zmm_m3 = vtype2::mask_mov( + index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); + index_type index_zmm_t4 = vtype2::mask_mov( + index_zmm4r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_type index_zmm_m4 = vtype2::mask_mov( + index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); + + zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t7 = vtype1::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t8 = vtype1::permutexvar(rev_index, key_zmm_m1); 
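
The templated COEX used just below is the same compare-exchange primitive; the following is a sketch adapted from the single-type version this patch deletes from avx512-common-keyvaluesort.h (the real templated definition lands in avx512-common-qsort.h and may differ in detail):

    template <typename vtype1,
              typename vtype2,
              typename zmm_t = typename vtype1::zmm_t,
              typename index_t = typename vtype2::zmm_t>
    static void COEX(zmm_t &key1, zmm_t &key2, index_t &index1, index_t &index2)
    {
        // Order the keys: key1 receives the lane-wise minimum, key2 the maximum.
        zmm_t key_t1 = vtype1::min(key1, key2);
        zmm_t key_t2 = vtype1::max(key1, key2);
        // Move the indexes with the same decision, keyed off where the minimum
        // came from, so each index keeps following its original key.
        index_t index_t1 = vtype2::mask_mov(index2, vtype1::eq(key_t1, key1), index1);
        index_t index_t2 = vtype2::mask_mov(index1, vtype1::eq(key_t1, key1), index2);
        key1 = key_t1;
        key2 = key_t2;
        index1 = index_t1;
        index2 = index_t2;
    }
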
index_type index_zmm_t5 - = zmm_vector::permutexvar(rev_index, index_zmm_m4); + = vtype2::permutexvar(rev_index, index_zmm_m4); index_type index_zmm_t6 - = zmm_vector::permutexvar(rev_index, index_zmm_m3); + = vtype2::permutexvar(rev_index, index_zmm_m3); index_type index_zmm_t7 - = zmm_vector::permutexvar(rev_index, index_zmm_m2); + = vtype2::permutexvar(rev_index, index_zmm_m2); index_type index_zmm_t8 - = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + = vtype2::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); index_zmm[0] = index_zmm_t1; index_zmm[1] = index_zmm_t2; @@ -270,159 +268,161 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, index_zmm[6] = index_zmm_t7; index_zmm[7] = index_zmm_t8; } -template ::zmm_t> + +template X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, index_type *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t key_zmm8r = vtype::permutexvar(rev_index, key_zmm[8]); - zmm_t key_zmm9r = vtype::permutexvar(rev_index, key_zmm[9]); - zmm_t key_zmm10r = vtype::permutexvar(rev_index, key_zmm[10]); - zmm_t key_zmm11r = vtype::permutexvar(rev_index, key_zmm[11]); - zmm_t key_zmm12r = vtype::permutexvar(rev_index, key_zmm[12]); - zmm_t key_zmm13r = vtype::permutexvar(rev_index, key_zmm[13]); - zmm_t key_zmm14r = vtype::permutexvar(rev_index, key_zmm[14]); - zmm_t key_zmm15r = vtype::permutexvar(rev_index, key_zmm[15]); + zmm_t key_zmm8r = vtype1::permutexvar(rev_index, key_zmm[8]); + zmm_t key_zmm9r = 
vtype1::permutexvar(rev_index, key_zmm[9]); + zmm_t key_zmm10r = vtype1::permutexvar(rev_index, key_zmm[10]); + zmm_t key_zmm11r = vtype1::permutexvar(rev_index, key_zmm[11]); + zmm_t key_zmm12r = vtype1::permutexvar(rev_index, key_zmm[12]); + zmm_t key_zmm13r = vtype1::permutexvar(rev_index, key_zmm[13]); + zmm_t key_zmm14r = vtype1::permutexvar(rev_index, key_zmm[14]); + zmm_t key_zmm15r = vtype1::permutexvar(rev_index, key_zmm[15]); index_type index_zmm8r - = zmm_vector::permutexvar(rev_index, index_zmm[8]); + = vtype2::permutexvar(rev_index, index_zmm[8]); index_type index_zmm9r - = zmm_vector::permutexvar(rev_index, index_zmm[9]); + = vtype2::permutexvar(rev_index, index_zmm[9]); index_type index_zmm10r - = zmm_vector::permutexvar(rev_index, index_zmm[10]); + = vtype2::permutexvar(rev_index, index_zmm[10]); index_type index_zmm11r - = zmm_vector::permutexvar(rev_index, index_zmm[11]); + = vtype2::permutexvar(rev_index, index_zmm[11]); index_type index_zmm12r - = zmm_vector::permutexvar(rev_index, index_zmm[12]); + = vtype2::permutexvar(rev_index, index_zmm[12]); index_type index_zmm13r - = zmm_vector::permutexvar(rev_index, index_zmm[13]); + = vtype2::permutexvar(rev_index, index_zmm[13]); index_type index_zmm14r - = zmm_vector::permutexvar(rev_index, index_zmm[14]); + = vtype2::permutexvar(rev_index, index_zmm[14]); index_type index_zmm15r - = zmm_vector::permutexvar(rev_index, index_zmm[15]); - - zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm15r); - zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm14r); - zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm13r); - zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm12r); - zmm_t key_zmm_t5 = vtype::min(key_zmm[4], key_zmm11r); - zmm_t key_zmm_t6 = vtype::min(key_zmm[5], key_zmm10r); - zmm_t key_zmm_t7 = vtype::min(key_zmm[6], key_zmm9r); - zmm_t key_zmm_t8 = vtype::min(key_zmm[7], key_zmm8r); - - zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm15r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm14r); - zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm13r); - zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm12r); - zmm_t key_zmm_m5 = vtype::max(key_zmm[4], key_zmm11r); - zmm_t key_zmm_m6 = vtype::max(key_zmm[5], key_zmm10r); - zmm_t key_zmm_m7 = vtype::max(key_zmm[6], key_zmm9r); - zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); - - index_type index_zmm_t1 = zmm_vector::mask_mov( - index_zmm15r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = zmm_vector::mask_mov( - index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); - index_type index_zmm_t2 = zmm_vector::mask_mov( - index_zmm14r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = zmm_vector::mask_mov( - index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); - index_type index_zmm_t3 = zmm_vector::mask_mov( - index_zmm13r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); - index_type index_zmm_m3 = zmm_vector::mask_mov( - index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); - index_type index_zmm_t4 = zmm_vector::mask_mov( - index_zmm12r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); - index_type index_zmm_m4 = zmm_vector::mask_mov( - index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); - - index_type index_zmm_t5 = zmm_vector::mask_mov( - index_zmm11r, vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); - index_type index_zmm_m5 = zmm_vector::mask_mov( - index_zmm[4], vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); - index_type index_zmm_t6 = zmm_vector::mask_mov( - 
index_zmm10r, vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); - index_type index_zmm_m6 = zmm_vector::mask_mov( - index_zmm[5], vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); - index_type index_zmm_t7 = zmm_vector::mask_mov( - index_zmm9r, vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); - index_type index_zmm_m7 = zmm_vector::mask_mov( - index_zmm[6], vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); - index_type index_zmm_t8 = zmm_vector::mask_mov( - index_zmm8r, vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); - index_type index_zmm_m8 = zmm_vector::mask_mov( - index_zmm[7], vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); - - zmm_t key_zmm_t9 = vtype::permutexvar(rev_index, key_zmm_m8); - zmm_t key_zmm_t10 = vtype::permutexvar(rev_index, key_zmm_m7); - zmm_t key_zmm_t11 = vtype::permutexvar(rev_index, key_zmm_m6); - zmm_t key_zmm_t12 = vtype::permutexvar(rev_index, key_zmm_m5); - zmm_t key_zmm_t13 = vtype::permutexvar(rev_index, key_zmm_m4); - zmm_t key_zmm_t14 = vtype::permutexvar(rev_index, key_zmm_m3); - zmm_t key_zmm_t15 = vtype::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t16 = vtype::permutexvar(rev_index, key_zmm_m1); + = vtype2::permutexvar(rev_index, index_zmm[15]); + + zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm15r); + zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm14r); + zmm_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm13r); + zmm_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm12r); + zmm_t key_zmm_t5 = vtype1::min(key_zmm[4], key_zmm11r); + zmm_t key_zmm_t6 = vtype1::min(key_zmm[5], key_zmm10r); + zmm_t key_zmm_t7 = vtype1::min(key_zmm[6], key_zmm9r); + zmm_t key_zmm_t8 = vtype1::min(key_zmm[7], key_zmm8r); + + zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm15r); + zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm14r); + zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm13r); + zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm12r); + zmm_t key_zmm_m5 = vtype1::max(key_zmm[4], key_zmm11r); + zmm_t key_zmm_m6 = vtype1::max(key_zmm[5], key_zmm10r); + zmm_t key_zmm_m7 = vtype1::max(key_zmm[6], key_zmm9r); + zmm_t key_zmm_m8 = vtype1::max(key_zmm[7], key_zmm8r); + + index_type index_zmm_t1 = vtype2::mask_mov( + index_zmm15r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = vtype2::mask_mov( + index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); + index_type index_zmm_t2 = vtype2::mask_mov( + index_zmm14r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = vtype2::mask_mov( + index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); + index_type index_zmm_t3 = vtype2::mask_mov( + index_zmm13r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_type index_zmm_m3 = vtype2::mask_mov( + index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); + index_type index_zmm_t4 = vtype2::mask_mov( + index_zmm12r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_type index_zmm_m4 = vtype2::mask_mov( + index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); + + index_type index_zmm_t5 = vtype2::mask_mov( + index_zmm11r, vtype1::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); + index_type index_zmm_m5 = vtype2::mask_mov( + index_zmm[4], vtype1::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); + index_type index_zmm_t6 = vtype2::mask_mov( + index_zmm10r, vtype1::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); + index_type index_zmm_m6 = vtype2::mask_mov( + index_zmm[5], vtype1::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); + index_type index_zmm_t7 = vtype2::mask_mov( + 
index_zmm9r, vtype1::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); + index_type index_zmm_m7 = vtype2::mask_mov( + index_zmm[6], vtype1::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); + index_type index_zmm_t8 = vtype2::mask_mov( + index_zmm8r, vtype1::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); + index_type index_zmm_m8 = vtype2::mask_mov( + index_zmm[7], vtype1::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); + + zmm_t key_zmm_t9 = vtype1::permutexvar(rev_index, key_zmm_m8); + zmm_t key_zmm_t10 = vtype1::permutexvar(rev_index, key_zmm_m7); + zmm_t key_zmm_t11 = vtype1::permutexvar(rev_index, key_zmm_m6); + zmm_t key_zmm_t12 = vtype1::permutexvar(rev_index, key_zmm_m5); + zmm_t key_zmm_t13 = vtype1::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t14 = vtype1::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t15 = vtype1::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t16 = vtype1::permutexvar(rev_index, key_zmm_m1); index_type index_zmm_t9 - = zmm_vector::permutexvar(rev_index, index_zmm_m8); + = vtype2::permutexvar(rev_index, index_zmm_m8); index_type index_zmm_t10 - = zmm_vector::permutexvar(rev_index, index_zmm_m7); + = vtype2::permutexvar(rev_index, index_zmm_m7); index_type index_zmm_t11 - = zmm_vector::permutexvar(rev_index, index_zmm_m6); + = vtype2::permutexvar(rev_index, index_zmm_m6); index_type index_zmm_t12 - = zmm_vector::permutexvar(rev_index, index_zmm_m5); + = vtype2::permutexvar(rev_index, index_zmm_m5); index_type index_zmm_t13 - = zmm_vector::permutexvar(rev_index, index_zmm_m4); + = vtype2::permutexvar(rev_index, index_zmm_m4); index_type index_zmm_t14 - = zmm_vector::permutexvar(rev_index, index_zmm_m3); + = vtype2::permutexvar(rev_index, index_zmm_m3); index_type index_zmm_t15 - = zmm_vector::permutexvar(rev_index, index_zmm_m2); + = vtype2::permutexvar(rev_index, index_zmm_m2); index_type index_zmm_t16 - = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); - COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); - COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); - COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); - COEX(key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); - COEX(key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); - COEX(key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); - COEX(key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); - COEX(key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); - COEX(key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); - COEX(key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); + = vtype2::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); + COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, 
index_zmm_t6); + COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); + COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); + COEX(key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); + COEX(key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); + COEX(key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); + COEX(key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); + COEX(key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); + COEX(key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); + COEX(key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); // - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); - key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, index_zmm_t10); - key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, index_zmm_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, index_zmm_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, index_zmm_t13); - key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, index_zmm_t14); - key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, index_zmm_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, index_zmm_t16); + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, index_zmm_t10); + key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, index_zmm_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, index_zmm_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, index_zmm_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, index_zmm_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, index_zmm_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, index_zmm_t16); index_zmm[0] = 
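
The three COEX rounds above exchange register pairs at distance 4, then 2, then 1 within each group of eight, before each register is cleaned individually. A scalar analogue of that schedule, illustrative only and operating on plain ints rather than key/index register pairs:

    #include <algorithm>
    #include <array>

    // Applies the distance-4/2/1 exchange pattern to sixteen values: the same
    // schedule the COEX rounds above apply to sixteen ZMM registers.
    void half_cleaner_16(std::array<int, 16> &v)
    {
        auto coex = [&](int i, int j) {
            if (v[j] < v[i]) std::swap(v[i], v[j]);
        };
        const int distances[3] = {4, 2, 1};
        for (int d : distances) {
            for (int i = 0; i < 16; ++i) {
                if (i % (2 * d) < d) coex(i, i + d);
            }
        }
    }
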
index_zmm_t1; index_zmm[1] = index_zmm_t2; @@ -441,135 +441,148 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, index_zmm[14] = index_zmm_t15; index_zmm[15] = index_zmm_t16; } -template + +template X86_SIMD_SORT_INLINE void -sort_8_64bit(type_t *keys, uint64_t *indexes, int32_t N) +sort_8_64bit(type1_t *keys, type2_t *indexes, int32_t N) { - typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; - typename vtype::zmm_t key_zmm - = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys); - - zmm_vector::zmm_t index_zmm = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask, indexes); - vtype::mask_storeu( - keys, load_mask, sort_zmm_64bit(key_zmm, index_zmm)); - zmm_vector::mask_storeu(indexes, load_mask, index_zmm); + typename vtype1::opmask_t load_mask = (0x01 << N) - 0x01; + typename vtype1::zmm_t key_zmm + = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys); + + typename vtype2::zmm_t index_zmm = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask, indexes); + vtype1::mask_storeu( + keys, load_mask, sort_zmm_64bit(key_zmm, index_zmm)); + vtype2::mask_storeu(indexes, load_mask, index_zmm); } -template +template X86_SIMD_SORT_INLINE void -sort_16_64bit(type_t *keys, uint64_t *indexes, int32_t N) +sort_16_64bit(type1_t *keys, type2_t *indexes, int32_t N) { if (N <= 8) { - sort_8_64bit(keys, indexes, N); + sort_8_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype::zmm_t; - using index_type = zmm_vector::zmm_t; + using zmm_t = typename vtype1::zmm_t; + using index_type = typename vtype2::zmm_t; - typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; + typename vtype1::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - zmm_t key_zmm1 = vtype::loadu(keys); - zmm_t key_zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys + 8); + zmm_t key_zmm1 = vtype1::loadu(keys); + zmm_t key_zmm2 = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys + 8); - index_type index_zmm1 = zmm_vector::loadu(indexes); - index_type index_zmm2 = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask, indexes + 8); + index_type index_zmm1 = vtype2::loadu(indexes); + index_type index_zmm2 = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask, indexes + 8); - key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); - key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); - bitonic_merge_two_zmm_64bit( + key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); + key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); + bitonic_merge_two_zmm_64bit( key_zmm1, key_zmm2, index_zmm1, index_zmm2); - zmm_vector::storeu(indexes, index_zmm1); - zmm_vector::mask_storeu(indexes + 8, load_mask, index_zmm2); + vtype2::storeu(indexes, index_zmm1); + vtype2::mask_storeu(indexes + 8, load_mask, index_zmm2); - vtype::storeu(keys, key_zmm1); - vtype::mask_storeu(keys + 8, load_mask, key_zmm2); + vtype1::storeu(keys, key_zmm1); + vtype1::mask_storeu(keys + 8, load_mask, key_zmm2); } -template +template X86_SIMD_SORT_INLINE void -sort_32_64bit(type_t *keys, uint64_t *indexes, int32_t N) +sort_32_64bit(type1_t *keys, type2_t *indexes, int32_t N) { if (N <= 16) { - sort_16_64bit(keys, indexes, N); + sort_16_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - using index_type = zmm_vector::zmm_t; + using zmm_t = typename vtype1::zmm_t; + using opmask_t = typename vtype2::opmask_t; + using index_type = typename vtype2::zmm_t; zmm_t key_zmm[4]; index_type index_zmm[4]; - key_zmm[0] = vtype::loadu(keys); - key_zmm[1] = vtype::loadu(keys + 
8); + key_zmm[0] = vtype1::loadu(keys); + key_zmm[1] = vtype1::loadu(keys + 8); - index_zmm[0] = zmm_vector::loadu(indexes); - index_zmm[1] = zmm_vector::loadu(indexes + 8); + index_zmm[0] = vtype2::loadu(indexes); + index_zmm[1] = vtype2::loadu(indexes + 8); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; load_mask1 = (combined_mask)&0xFF; load_mask2 = (combined_mask >> 8) & 0xFF; - key_zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 16); - key_zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 24); + key_zmm[2] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 16); + key_zmm[3] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 24); - index_zmm[2] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask1, indexes + 16); - index_zmm[3] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask2, indexes + 24); + index_zmm[2] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask1, indexes + 16); + index_zmm[3] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask2, indexes + 24); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - zmm_vector::storeu(indexes, index_zmm[0]); - zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); - zmm_vector::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); + vtype2::storeu(indexes, index_zmm[0]); + vtype2::storeu(indexes + 8, index_zmm[1]); + vtype2::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); + vtype2::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); - vtype::storeu(keys, key_zmm[0]); - vtype::storeu(keys + 8, key_zmm[1]); - vtype::mask_storeu(keys + 16, load_mask1, key_zmm[2]); - vtype::mask_storeu(keys + 24, load_mask2, key_zmm[3]); + vtype1::storeu(keys, key_zmm[0]); + vtype1::storeu(keys + 8, key_zmm[1]); + vtype1::mask_storeu(keys + 16, load_mask1, key_zmm[2]); + vtype1::mask_storeu(keys + 24, load_mask2, key_zmm[3]); } -template +template X86_SIMD_SORT_INLINE void -sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) +sort_64_64bit(type1_t *keys, type2_t *indexes, int32_t N) { if (N <= 32) { - sort_32_64bit(keys, indexes, N); + sort_32_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - using index_type = zmm_vector::zmm_t; + using zmm_t = typename vtype1::zmm_t; + using opmask_t = typename vtype1::opmask_t; + using index_type = typename vtype2::zmm_t; zmm_t key_zmm[8]; index_type index_zmm[8]; - key_zmm[0] = vtype::loadu(keys); - key_zmm[1] = vtype::loadu(keys + 8); - key_zmm[2] = vtype::loadu(keys + 16); - key_zmm[3] = vtype::loadu(keys + 24); + key_zmm[0] = vtype1::loadu(keys); + key_zmm[1] = vtype1::loadu(keys + 8); + key_zmm[2] = vtype1::loadu(keys + 16); + key_zmm[3] 
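
A small worked example of the partial-load masks computed above: for N between 17 and 32, the first two registers are loaded in full and the tail is split across two masked loads whose inactive lanes are filled from vtype1::zmm_max(), so the padding sorts to the end and is never written back.

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        int32_t N = 27; // e.g. 27 key/index pairs: 16 full lanes + 8 masked + 3 masked
        uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; // 0x7ff
        uint8_t load_mask1 = combined_mask & 0xFF;        // 0xff: all 8 lanes of key_zmm[2]
        uint8_t load_mask2 = (combined_mask >> 8) & 0xFF; // 0x07: 3 lanes of key_zmm[3]
        std::printf("combined=0x%llx mask1=0x%02x mask2=0x%02x\n",
                    (unsigned long long)combined_mask,
                    (unsigned)load_mask1,
                    (unsigned)load_mask2);
        return 0;
    }
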
= vtype1::loadu(keys + 24); - index_zmm[0] = zmm_vector::loadu(indexes); - index_zmm[1] = zmm_vector::loadu(indexes + 8); - index_zmm[2] = zmm_vector::loadu(indexes + 16); - index_zmm[3] = zmm_vector::loadu(indexes + 24); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + index_zmm[0] = vtype2::loadu(indexes); + index_zmm[1] = vtype2::loadu(indexes + 8); + index_zmm[2] = vtype2::loadu(indexes + 16); + index_zmm[3] = vtype2::loadu(indexes + 24); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; @@ -579,94 +592,97 @@ sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) load_mask2 = (combined_mask >> 8) & 0xFF; load_mask3 = (combined_mask >> 16) & 0xFF; load_mask4 = (combined_mask >> 24) & 0xFF; - key_zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 32); - key_zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 40); - key_zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 48); - key_zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 56); - - index_zmm[4] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask1, indexes + 32); - index_zmm[5] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask2, indexes + 40); - index_zmm[6] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask3, indexes + 48); - index_zmm[7] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask4, indexes + 56); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - - bitonic_merge_two_zmm_64bit( + key_zmm[4] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 32); + key_zmm[5] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 40); + key_zmm[6] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 48); + key_zmm[7] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 56); + + index_zmm[4] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask1, indexes + 32); + index_zmm[5] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask2, indexes + 40); + index_zmm[6] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask3, indexes + 48); + index_zmm[7] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask4, indexes + 56); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + + bitonic_merge_two_zmm_64bit( key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm, 
index_zmm); - - zmm_vector::storeu(indexes, index_zmm[0]); - zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::storeu(indexes + 16, index_zmm[2]); - zmm_vector::storeu(indexes + 24, index_zmm[3]); - zmm_vector::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); - zmm_vector::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); - zmm_vector::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); - zmm_vector::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); - - vtype::storeu(keys, key_zmm[0]); - vtype::storeu(keys + 8, key_zmm[1]); - vtype::storeu(keys + 16, key_zmm[2]); - vtype::storeu(keys + 24, key_zmm[3]); - vtype::mask_storeu(keys + 32, load_mask1, key_zmm[4]); - vtype::mask_storeu(keys + 40, load_mask2, key_zmm[5]); - vtype::mask_storeu(keys + 48, load_mask3, key_zmm[6]); - vtype::mask_storeu(keys + 56, load_mask4, key_zmm[7]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + + vtype2::storeu(indexes, index_zmm[0]); + vtype2::storeu(indexes + 8, index_zmm[1]); + vtype2::storeu(indexes + 16, index_zmm[2]); + vtype2::storeu(indexes + 24, index_zmm[3]); + vtype2::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); + vtype2::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); + vtype2::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); + vtype2::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); + + vtype1::storeu(keys, key_zmm[0]); + vtype1::storeu(keys + 8, key_zmm[1]); + vtype1::storeu(keys + 16, key_zmm[2]); + vtype1::storeu(keys + 24, key_zmm[3]); + vtype1::mask_storeu(keys + 32, load_mask1, key_zmm[4]); + vtype1::mask_storeu(keys + 40, load_mask2, key_zmm[5]); + vtype1::mask_storeu(keys + 48, load_mask3, key_zmm[6]); + vtype1::mask_storeu(keys + 56, load_mask4, key_zmm[7]); } -template +template X86_SIMD_SORT_INLINE void -sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) +sort_128_64bit(type1_t *keys, type2_t *indexes, int32_t N) { if (N <= 64) { - sort_64_64bit(keys, indexes, N); + sort_64_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype::zmm_t; - using index_type = zmm_vector::zmm_t; - using opmask_t = typename vtype::opmask_t; + using zmm_t = typename vtype1::zmm_t; + using index_type = typename vtype2::zmm_t; + using opmask_t = typename vtype1::opmask_t; zmm_t key_zmm[16]; index_type index_zmm[16]; - key_zmm[0] = vtype::loadu(keys); - key_zmm[1] = vtype::loadu(keys + 8); - key_zmm[2] = vtype::loadu(keys + 16); - key_zmm[3] = vtype::loadu(keys + 24); - key_zmm[4] = vtype::loadu(keys + 32); - key_zmm[5] = vtype::loadu(keys + 40); - key_zmm[6] = vtype::loadu(keys + 48); - key_zmm[7] = vtype::loadu(keys + 56); - - index_zmm[0] = zmm_vector::loadu(indexes); - index_zmm[1] = zmm_vector::loadu(indexes + 8); - index_zmm[2] = zmm_vector::loadu(indexes + 16); - index_zmm[3] = zmm_vector::loadu(indexes + 24); - index_zmm[4] = zmm_vector::loadu(indexes + 32); - index_zmm[5] = zmm_vector::loadu(indexes + 40); - index_zmm[6] = zmm_vector::loadu(indexes + 48); - index_zmm[7] = zmm_vector::loadu(indexes + 56); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = 
sort_zmm_64bit(key_zmm[7], index_zmm[7]); + key_zmm[0] = vtype1::loadu(keys); + key_zmm[1] = vtype1::loadu(keys + 8); + key_zmm[2] = vtype1::loadu(keys + 16); + key_zmm[3] = vtype1::loadu(keys + 24); + key_zmm[4] = vtype1::loadu(keys + 32); + key_zmm[5] = vtype1::loadu(keys + 40); + key_zmm[6] = vtype1::loadu(keys + 48); + key_zmm[7] = vtype1::loadu(keys + 56); + + index_zmm[0] = vtype2::loadu(indexes); + index_zmm[1] = vtype2::loadu(indexes + 8); + index_zmm[2] = vtype2::loadu(indexes + 16); + index_zmm[3] = vtype2::loadu(indexes + 24); + index_zmm[4] = vtype2::loadu(indexes + 32); + index_zmm[5] = vtype2::loadu(indexes + 40); + index_zmm[6] = vtype2::loadu(indexes + 48); + index_zmm[7] = vtype2::loadu(indexes + 56); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; @@ -683,100 +699,103 @@ sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) load_mask7 = (combined_mask >> 48) & 0xFF; load_mask8 = (combined_mask >> 56) & 0xFF; } - key_zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 64); - key_zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 72); - key_zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 80); - key_zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 88); - key_zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, keys + 96); - key_zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, keys + 104); - key_zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, keys + 112); - key_zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, keys + 120); - - index_zmm[8] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask1, indexes + 64); - index_zmm[9] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask2, indexes + 72); - index_zmm[10] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask3, indexes + 80); - index_zmm[11] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask4, indexes + 88); - index_zmm[12] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask5, indexes + 96); - index_zmm[13] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask6, indexes + 104); - index_zmm[14] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask7, indexes + 112); - index_zmm[15] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask8, indexes + 120); - key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); - key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); - key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); - key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); - key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); - key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); - key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); - key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); - - bitonic_merge_two_zmm_64bit( + key_zmm[8] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 64); + key_zmm[9] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 72); + key_zmm[10] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 80); + 
key_zmm[11] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 88); + key_zmm[12] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask5, keys + 96); + key_zmm[13] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask6, keys + 104); + key_zmm[14] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask7, keys + 112); + key_zmm[15] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask8, keys + 120); + + index_zmm[8] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask1, indexes + 64); + index_zmm[9] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask2, indexes + 72); + index_zmm[10] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask3, indexes + 80); + index_zmm[11] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask4, indexes + 88); + index_zmm[12] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask5, indexes + 96); + index_zmm[13] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask6, indexes + 104); + index_zmm[14] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask7, indexes + 112); + index_zmm[15] = vtype2::mask_loadu( + vtype2::zmm_max(), load_mask8, indexes + 120); + key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); + key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); + key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); + key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); + key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); + key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); + key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); + key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); + + bitonic_merge_two_zmm_64bit( key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); - zmm_vector::storeu(indexes, index_zmm[0]); - zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::storeu(indexes + 16, index_zmm[2]); - zmm_vector::storeu(indexes + 24, index_zmm[3]); - zmm_vector::storeu(indexes + 32, index_zmm[4]); - zmm_vector::storeu(indexes + 40, index_zmm[5]); - zmm_vector::storeu(indexes + 48, index_zmm[6]); - zmm_vector::storeu(indexes + 56, index_zmm[7]); - zmm_vector::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); - zmm_vector::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); - zmm_vector::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); - zmm_vector::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); - 
zmm_vector::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); - zmm_vector::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); - zmm_vector::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); - zmm_vector::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); - - vtype::storeu(keys, key_zmm[0]); - vtype::storeu(keys + 8, key_zmm[1]); - vtype::storeu(keys + 16, key_zmm[2]); - vtype::storeu(keys + 24, key_zmm[3]); - vtype::storeu(keys + 32, key_zmm[4]); - vtype::storeu(keys + 40, key_zmm[5]); - vtype::storeu(keys + 48, key_zmm[6]); - vtype::storeu(keys + 56, key_zmm[7]); - vtype::mask_storeu(keys + 64, load_mask1, key_zmm[8]); - vtype::mask_storeu(keys + 72, load_mask2, key_zmm[9]); - vtype::mask_storeu(keys + 80, load_mask3, key_zmm[10]); - vtype::mask_storeu(keys + 88, load_mask4, key_zmm[11]); - vtype::mask_storeu(keys + 96, load_mask5, key_zmm[12]); - vtype::mask_storeu(keys + 104, load_mask6, key_zmm[13]); - vtype::mask_storeu(keys + 112, load_mask7, key_zmm[14]); - vtype::mask_storeu(keys + 120, load_mask8, key_zmm[15]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); + vtype2::storeu(indexes, index_zmm[0]); + vtype2::storeu(indexes + 8, index_zmm[1]); + vtype2::storeu(indexes + 16, index_zmm[2]); + vtype2::storeu(indexes + 24, index_zmm[3]); + vtype2::storeu(indexes + 32, index_zmm[4]); + vtype2::storeu(indexes + 40, index_zmm[5]); + vtype2::storeu(indexes + 48, index_zmm[6]); + vtype2::storeu(indexes + 56, index_zmm[7]); + vtype2::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); + vtype2::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); + vtype2::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); + vtype2::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); + vtype2::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); + vtype2::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); + vtype2::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); + vtype2::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); + + vtype1::storeu(keys, key_zmm[0]); + vtype1::storeu(keys + 8, key_zmm[1]); + vtype1::storeu(keys + 16, key_zmm[2]); + vtype1::storeu(keys + 24, key_zmm[3]); + vtype1::storeu(keys + 32, key_zmm[4]); + vtype1::storeu(keys + 40, key_zmm[5]); + vtype1::storeu(keys + 48, key_zmm[6]); + vtype1::storeu(keys + 56, key_zmm[7]); + vtype1::mask_storeu(keys + 64, load_mask1, key_zmm[8]); + vtype1::mask_storeu(keys + 72, load_mask2, key_zmm[9]); + vtype1::mask_storeu(keys + 80, load_mask3, key_zmm[10]); + vtype1::mask_storeu(keys + 88, load_mask4, key_zmm[11]); + vtype1::mask_storeu(keys + 96, load_mask5, key_zmm[12]); + vtype1::mask_storeu(keys + 104, load_mask6, key_zmm[13]); + vtype1::mask_storeu(keys + 112, load_mask7, key_zmm[14]); + vtype1::mask_storeu(keys + 120, load_mask8, key_zmm[15]); } -template -void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) +template +void heapify(type1_t *keys, type2_t *indexes, int64_t idx, int64_t size) { int64_t i = idx; while (true) { @@ -790,22 +809,28 @@ void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) i = j; } } -template -void heap_sort(type_t *keys, uint64_t *indexes, int64_t size) +template +void heap_sort(type1_t *keys, 
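
heapify and heap_sort are the scalar fallback used once the recursion budget is exhausted; keys and indexes are always swapped together so the pairing survives. A self-contained scalar sketch of that fallback (the names here are illustrative and the real heapify may differ in detail):

    #include <cstdint>
    #include <utility>

    template <typename type1_t, typename type2_t>
    void kv_heapify(type1_t *keys, type2_t *indexes, int64_t idx, int64_t size)
    {
        int64_t i = idx;
        while (true) {
            int64_t j = 2 * i + 1;                          // left child
            if (j >= size) break;
            if (j + 1 < size && keys[j + 1] > keys[j]) j++; // pick the larger child
            if (keys[i] >= keys[j]) break;                  // heap property holds
            std::swap(keys[i], keys[j]);                    // move key and index together
            std::swap(indexes[i], indexes[j]);
            i = j;
        }
    }

    template <typename type1_t, typename type2_t>
    void kv_heap_sort(type1_t *keys, type2_t *indexes, int64_t size)
    {
        for (int64_t i = size / 2 - 1; i >= 0; i--)
            kv_heapify(keys, indexes, i, size);
        for (int64_t i = size - 1; i > 0; i--) {
            std::swap(keys[0], keys[i]);
            std::swap(indexes[0], indexes[i]);
            kv_heapify(keys, indexes, 0, i);
        }
    }
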
type2_t *indexes, int64_t size) { for (int64_t i = size / 2 - 1; i >= 0; i--) { - heapify(keys, indexes, i, size); + heapify(keys, indexes, i, size); } for (int64_t i = size - 1; i > 0; i--) { std::swap(keys[0], keys[i]); std::swap(indexes[0], indexes[i]); - heapify(keys, indexes, 0, i); + heapify(keys, indexes, 0, i); } } -template -void qsort_64bit_(type_t *keys, - uint64_t *indexes, +template +void qsort_64bit_(type1_t *keys, + type2_t *indexes, int64_t left, int64_t right, int64_t max_iters) @@ -815,7 +840,7 @@ void qsort_64bit_(type_t *keys, */ if (max_iters <= 0) { //std::sort(keys+left,keys+right+1); - heap_sort(keys + left, indexes + left, right - left + 1); + heap_sort(keys + left, indexes + left, right - left + 1); return; } /* @@ -823,30 +848,33 @@ void qsort_64bit_(type_t *keys, */ if (right + 1 - left <= 128) { - sort_128_64bit( + sort_128_64bit( keys + left, indexes + left, (int32_t)(right + 1 - left)); return; } - type_t pivot = get_pivot_64bit(keys, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512( + type1_t pivot = get_pivot_64bit(keys, left, right); + type1_t smallest = vtype1::type_max(); + type1_t biggest = vtype1::type_min(); + int64_t pivot_index = partition_avx512( keys, indexes, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) { - qsort_64bit_( + qsort_64bit_( keys, indexes, left, pivot_index - 1, max_iters - 1); } if (pivot != biggest) { - qsort_64bit_(keys, indexes, pivot_index, right, max_iters - 1); + qsort_64bit_( + keys, indexes, pivot_index, right, max_iters - 1); } } template <> -void avx512_qsort_kv(int64_t *keys, uint64_t *indexes, int64_t arrsize) +void avx512_qsort_kv(int64_t *keys, + uint64_t *indexes, + int64_t arrsize) { if (arrsize > 1) { - qsort_64bit_, int64_t>( + qsort_64bit_, zmm_vector>( keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } @@ -857,17 +885,19 @@ void avx512_qsort_kv(uint64_t *keys, int64_t arrsize) { if (arrsize > 1) { - qsort_64bit_, uint64_t>( + qsort_64bit_, zmm_vector>( keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } template <> -void avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize) +void avx512_qsort_kv(double *keys, + uint64_t *indexes, + int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf(keys, arrsize); - qsort_64bit_, double>( + qsort_64bit_, zmm_vector>( keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); replace_inf_with_nan(keys, arrsize, nan_count); } diff --git a/src/avx512-common-keyvaluesort.h b/src/avx512-common-keyvaluesort.h deleted file mode 100644 index f2821072..00000000 --- a/src/avx512-common-keyvaluesort.h +++ /dev/null @@ -1,240 +0,0 @@ -/******************************************************************* - * Copyright (C) 2022 Intel Corporation - * Copyright (C) 2021 Serge Sans Paille - * SPDX-License-Identifier: BSD-3-Clause - * Authors: Liu Zhuan - * Tang Xi - * ****************************************************************/ - -#ifndef AVX512_QSORT_COMMON_KV -#define AVX512_QSORT_COMMON_KV - -/* - * Quicksort using AVX-512. The ideas and code are based on these two research - * papers [1] and [2]. On a high level, the idea is to vectorize quicksort - * partitioning using AVX-512 compressstore instructions. If the array size is - * < 128, then use Bitonic sorting network implemented on 512-bit registers. 
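
For reference, a minimal usage sketch of the avx512_qsort_kv specializations above. The include and build flags are assumptions (an AVX-512 capable target is required); the sort runs at most 2*log2(n) quicksort levels before falling back to heap_sort, and for double keys NaNs are parked at +inf during the sort and restored afterwards:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>
    #include "avx512-64bit-keyvaluesort.hpp" // assumed include path

    int main()
    {
        std::vector<double> keys = {3.5, 1.25, 2.0, 0.5};
        std::vector<uint64_t> idx = {0, 1, 2, 3};
        // Sorts keys ascending and applies the same permutation to idx.
        avx512_qsort_kv<double>(keys.data(), idx.data(), (int64_t)keys.size());
        for (std::size_t i = 0; i < keys.size(); ++i)
            std::printf("%g came from position %llu\n",
                        keys[i], (unsigned long long)idx[i]);
        return 0;
    }
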
- * The precise network definitions depend on the dtype and are defined in - * separate files: avx512-16bit-qsort.hpp, avx512-32bit-qsort.hpp and - * avx512-64bit-qsort.hpp. Article [4] is a good resource for bitonic sorting - * network. The core implementations of the vectorized qsort functions - * avx512_qsort(T*, int64_t) are modified versions of avx2 quicksort - * presented in the paper [2] and source code associated with that paper [3]. - * - * [1] Fast and Robust Vectorized In-Place Sorting of Primitive Types - * https://drops.dagstuhl.de/opus/volltexte/2021/13775/ - * - * [2] A Novel Hybrid Quicksort Algorithm Vectorized using AVX-512 on Intel - * Skylake https://arxiv.org/pdf/1704.08579.pdf - * - * [3] https://github.com/simd-sorting/fast-and-robust: SPDX-License-Identifier: MIT - * - * [4] http://mitp-content-server.mit.edu:18180/books/content/sectbyfn?collid=books_pres_0&fn=Chapter%2027.pdf&id=8030 - * - */ - -#include "avx512-64bit-common.h" - -template -void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); - -using index_t = __m512i; - -template > -static void COEX(mm_t &key1, mm_t &key2, index_t &index1, index_t &index2) -{ - mm_t key_t1 = vtype::min(key1, key2); - mm_t key_t2 = vtype::max(key1, key2); - - index_t index_t1 - = index_type::mask_mov(index2, vtype::eq(key_t1, key1), index1); - index_t index_t2 - = index_type::mask_mov(index1, vtype::eq(key_t1, key1), index2); - - key1 = key_t1; - key2 = key_t2; - index1 = index_t1; - index2 = index_t2; -} -template > -static inline zmm_t cmp_merge(zmm_t in1, - zmm_t in2, - index_t &indexes1, - index_t indexes2, - opmask_t mask) -{ - zmm_t tmp_keys = cmp_merge(in1, in2, mask); - indexes1 = index_type::mask_mov( - indexes2, vtype::eq(tmp_keys, in1), indexes1); - return tmp_keys; // 0 -> min, 1 -> max -} -/* - * Parition one ZMM register based on the pivot and returns the index of the - * last element that is less than equal to the pivot. - */ -template > -static inline int32_t partition_vec(type_t *keys, - uint64_t *indexes, - int64_t left, - int64_t right, - const zmm_t keys_vec, - const index_t indexes_vec, - const zmm_t pivot_vec, - zmm_t *smallest_vec, - zmm_t *biggest_vec) -{ - /* which elements are larger than the pivot */ - typename vtype::opmask_t gt_mask = vtype::ge(keys_vec, pivot_vec); - int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); - vtype::mask_compressstoreu( - keys + left, vtype::knot_opmask(gt_mask), keys_vec); - vtype::mask_compressstoreu( - keys + right - amount_gt_pivot, gt_mask, keys_vec); - index_type::mask_compressstoreu( - indexes + left, index_type::knot_opmask(gt_mask), indexes_vec); - index_type::mask_compressstoreu( - indexes + right - amount_gt_pivot, gt_mask, indexes_vec); - *smallest_vec = vtype::min(keys_vec, *smallest_vec); - *biggest_vec = vtype::max(keys_vec, *biggest_vec); - return amount_gt_pivot; -} -/* - * Parition an array based on the pivot and returns the index of the - * last element that is less than equal to the pivot. 
- */ -template > -static inline int64_t partition_avx512(type_t *keys, - uint64_t *indexes, - int64_t left, - int64_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { - *smallest = std::min(*smallest, keys[left]); - *biggest = std::max(*biggest, keys[left]); - if (keys[left] > pivot) { - right--; - std::swap(keys[left], keys[right]); - std::swap(indexes[left], indexes[right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using zmm_t = typename vtype::zmm_t; - zmm_t pivot_vec = vtype::set1(pivot); - zmm_t min_vec = vtype::set1(*smallest); - zmm_t max_vec = vtype::set1(*biggest); - - if (right - left == vtype::numlanes) { - zmm_t keys_vec = vtype::loadu(keys + left); - int32_t amount_gt_pivot; - - index_t indexes_vec = index_type::loadu(indexes + left); - amount_gt_pivot = partition_vec(keys, - indexes, - left, - left + vtype::numlanes, - keys_vec, - indexes_vec, - pivot_vec, - &min_vec, - &max_vec); - - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return left + (vtype::numlanes - amount_gt_pivot); - } - - // first and last vtype::numlanes values are partitioned at the end - zmm_t keys_vec_left = vtype::loadu(keys + left); - zmm_t keys_vec_right = vtype::loadu(keys + (right - vtype::numlanes)); - index_t indexes_vec_left; - index_t indexes_vec_right; - indexes_vec_left = index_type::loadu(indexes + left); - indexes_vec_right = index_type::loadu(indexes + (right - vtype::numlanes)); - - // store points of the vectors - int64_t r_store = right - vtype::numlanes; - int64_t l_store = left; - // indices for loading the elements - left += vtype::numlanes; - right -= vtype::numlanes; - while (right - left != 0) { - zmm_t keys_vec; - index_t indexes_vec; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= vtype::numlanes; - keys_vec = vtype::loadu(keys + right); - indexes_vec = index_type::loadu(indexes + right); - } - else { - keys_vec = vtype::loadu(keys + left); - indexes_vec = index_type::loadu(indexes + left); - left += vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - int32_t amount_gt_pivot; - - amount_gt_pivot = partition_vec(keys, - indexes, - l_store, - r_store + vtype::numlanes, - keys_vec, - indexes_vec, - pivot_vec, - &min_vec, - &max_vec); - r_store -= amount_gt_pivot; - l_store += (vtype::numlanes - amount_gt_pivot); - } - - /* partition and save vec_left and vec_right */ - int32_t amount_gt_pivot; - amount_gt_pivot = partition_vec(keys, - indexes, - l_store, - r_store + vtype::numlanes, - keys_vec_left, - indexes_vec_left, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec(keys, - indexes, - l_store, - l_store + vtype::numlanes, - keys_vec_right, - indexes_vec_right, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} -#endif // AVX512_QSORT_COMMON_KV diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 00a38732..9ac0ccdd 100644 --- a/src/avx512-common-qsort.h 
+++ b/src/avx512-common-qsort.h @@ -4,6 +4,8 @@ * SPDX-License-Identifier: BSD-3-Clause * Authors: Raghuveer Devulapalli * Serge Sans Paille + * Liu Zhuan + * Tang Xi * ****************************************************************/ #ifndef AVX512_QSORT_COMMON @@ -35,6 +37,7 @@ #include "avx512-zmm-classes.h" +// Regular quicksort routines: template void avx512_qsort(T *arr, int64_t arrsize); void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize); @@ -55,6 +58,10 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize) avx512_qsort_fp16(arr, k - 1); } +// key-value sort routines +template +void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); + template bool comparison_func(const T &a, const T &b) { @@ -329,4 +336,213 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, *biggest = vtype::reducemax(max_vec); return l_store; } + +// Key-value sort helper functions + +template +static void COEX(zmm_t1 &key1, zmm_t1 &key2, zmm_t2 &index1, zmm_t2 &index2) +{ + zmm_t1 key_t1 = vtype1::min(key1, key2); + zmm_t1 key_t2 = vtype1::max(key1, key2); + + zmm_t2 index_t1 + = vtype2::mask_mov(index2, vtype1::eq(key_t1, key1), index1); + zmm_t2 index_t2 + = vtype2::mask_mov(index1, vtype1::eq(key_t1, key1), index2); + + key1 = key_t1; + key2 = key_t2; + index1 = index_t1; + index2 = index_t2; +} +template +static inline zmm_t1 cmp_merge(zmm_t1 in1, + zmm_t1 in2, + zmm_t2 &indexes1, + zmm_t2 indexes2, + opmask_t mask) +{ + zmm_t1 tmp_keys = cmp_merge(in1, in2, mask); + indexes1 = vtype2::mask_mov(indexes2, vtype1::eq(tmp_keys, in1), indexes1); + return tmp_keys; // 0 -> min, 1 -> max +} + +/* + * Parition one ZMM register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +static inline int32_t partition_vec(type_t1 *keys, + type_t2 *indexes, + int64_t left, + int64_t right, + const zmm_t1 keys_vec, + const zmm_t2 indexes_vec, + const zmm_t1 pivot_vec, + zmm_t1 *smallest_vec, + zmm_t1 *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype1::opmask_t gt_mask = vtype1::ge(keys_vec, pivot_vec); + int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); + vtype1::mask_compressstoreu( + keys + left, vtype1::knot_opmask(gt_mask), keys_vec); + vtype1::mask_compressstoreu( + keys + right - amount_gt_pivot, gt_mask, keys_vec); + vtype2::mask_compressstoreu( + indexes + left, vtype2::knot_opmask(gt_mask), indexes_vec); + vtype2::mask_compressstoreu( + indexes + right - amount_gt_pivot, gt_mask, indexes_vec); + *smallest_vec = vtype1::min(keys_vec, *smallest_vec); + *biggest_vec = vtype1::max(keys_vec, *biggest_vec); + return amount_gt_pivot; +} +/* + * Parition an array based on the pivot and returns the index of the + * last element that is less than equal to the pivot. 
+ */ +template +static inline int64_t partition_avx512(type_t1 *keys, + type_t2 *indexes, + int64_t left, + int64_t right, + type_t1 pivot, + type_t1 *smallest, + type_t1 *biggest) +{ + /* make array length divisible by vtype1::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype1::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, keys[left]); + *biggest = std::max(*biggest, keys[left]); + if (keys[left] > pivot) { + right--; + std::swap(keys[left], keys[right]); + std::swap(indexes[left], indexes[right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype1::numlanes elements in the array */ + + zmm_t1 pivot_vec = vtype1::set1(pivot); + zmm_t1 min_vec = vtype1::set1(*smallest); + zmm_t1 max_vec = vtype1::set1(*biggest); + + if (right - left == vtype1::numlanes) { + zmm_t1 keys_vec = vtype1::loadu(keys + left); + int32_t amount_gt_pivot; + + zmm_t2 indexes_vec = vtype2::loadu(indexes + left); + amount_gt_pivot = partition_vec(keys, + indexes, + left, + left + vtype1::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + + *smallest = vtype1::reducemin(min_vec); + *biggest = vtype1::reducemax(max_vec); + return left + (vtype1::numlanes - amount_gt_pivot); + } + + // first and last vtype1::numlanes values are partitioned at the end + zmm_t1 keys_vec_left = vtype1::loadu(keys + left); + zmm_t1 keys_vec_right = vtype1::loadu(keys + (right - vtype1::numlanes)); + zmm_t2 indexes_vec_left; + zmm_t2 indexes_vec_right; + indexes_vec_left = vtype2::loadu(indexes + left); + indexes_vec_right = vtype2::loadu(indexes + (right - vtype1::numlanes)); + + // store points of the vectors + int64_t r_store = right - vtype1::numlanes; + int64_t l_store = left; + // indices for loading the elements + left += vtype1::numlanes; + right -= vtype1::numlanes; + while (right - left != 0) { + zmm_t1 keys_vec; + zmm_t2 indexes_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype1::numlanes) - right < left - l_store) { + right -= vtype1::numlanes; + keys_vec = vtype1::loadu(keys + right); + indexes_vec = vtype2::loadu(indexes + right); + } + else { + keys_vec = vtype1::loadu(keys + left); + indexes_vec = vtype2::loadu(indexes + left); + left += vtype1::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_gt_pivot; + + amount_gt_pivot = partition_vec( + keys, + indexes, + l_store, + r_store + vtype1::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + r_store -= amount_gt_pivot; + l_store += (vtype1::numlanes - amount_gt_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_gt_pivot; + amount_gt_pivot = partition_vec( + keys, + indexes, + l_store, + r_store + vtype1::numlanes, + keys_vec_left, + indexes_vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype1::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec( + keys, + indexes, + l_store, + l_store + vtype1::numlanes, + keys_vec_right, + indexes_vec_right, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype1::numlanes - amount_gt_pivot); + *smallest = vtype1::reducemin(min_vec); + *biggest = vtype1::reducemax(max_vec); + return l_store; +} #endif // AVX512_QSORT_COMMON From 8c2066aadad6585cb61d82dfbf938b41df22742b Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Fri, 14 Apr 2023 13:25:14 -0700 Subject: 
[PATCH 3/4] Fix formatting --- src/avx512-32bit-qsort.hpp | 1 - src/avx512-64bit-keyvaluesort.hpp | 453 +++++++++++++++--------------- src/avx512-common-qsort.h | 58 ++-- 3 files changed, 256 insertions(+), 256 deletions(-) diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index a713df63..0f3b85a1 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -23,7 +23,6 @@ #define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 #define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 - /* * Assumes zmm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 7d70a2b4..4c75c481 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -60,8 +60,8 @@ template -X86_SIMD_SORT_INLINE zmm_t -bitonic_merge_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) +X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm, + index_type &index_zmm) { // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 @@ -129,10 +129,8 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, // 1) First step of a merging network zmm_t key_zmm2r = vtype1::permutexvar(rev_index, key_zmm[2]); zmm_t key_zmm3r = vtype1::permutexvar(rev_index, key_zmm[3]); - index_type index_zmm2r - = vtype2::permutexvar(rev_index, index_zmm[2]); - index_type index_zmm3r - = vtype2::permutexvar(rev_index, index_zmm[3]); + index_type index_zmm2r = vtype2::permutexvar(rev_index, index_zmm[2]); + index_type index_zmm3r = vtype2::permutexvar(rev_index, index_zmm[3]); zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm3r); zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm2r); @@ -151,10 +149,8 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, // 2) Recursive half clearer: 16 zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t4 = vtype1::permutexvar(rev_index, key_zmm_m1); - index_type index_zmm_t3 - = vtype2::permutexvar(rev_index, index_zmm_m2); - index_type index_zmm_t4 - = vtype2::permutexvar(rev_index, index_zmm_m1); + index_type index_zmm_t3 = vtype2::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t4 = vtype2::permutexvar(rev_index, index_zmm_m1); zmm_t key_zmm0 = vtype1::min(key_zmm_t1, key_zmm_t2); zmm_t key_zmm1 = vtype1::max(key_zmm_t1, key_zmm_t2); @@ -193,14 +189,10 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, zmm_t key_zmm5r = vtype1::permutexvar(rev_index, key_zmm[5]); zmm_t key_zmm6r = vtype1::permutexvar(rev_index, key_zmm[6]); zmm_t key_zmm7r = vtype1::permutexvar(rev_index, key_zmm[7]); - index_type index_zmm4r - = vtype2::permutexvar(rev_index, index_zmm[4]); - index_type index_zmm5r - = vtype2::permutexvar(rev_index, index_zmm[5]); - index_type index_zmm6r - = vtype2::permutexvar(rev_index, index_zmm[6]); - index_type index_zmm7r - = vtype2::permutexvar(rev_index, index_zmm[7]); + index_type index_zmm4r = vtype2::permutexvar(rev_index, index_zmm[4]); + index_type index_zmm5r = vtype2::permutexvar(rev_index, index_zmm[5]); + index_type index_zmm6r = vtype2::permutexvar(rev_index, index_zmm[6]); + index_type index_zmm7r = vtype2::permutexvar(rev_index, index_zmm[7]); zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm7r); zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm6r); @@ -233,31 +225,35 @@ X86_SIMD_SORT_INLINE void 
bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index, key_zmm_m3); zmm_t key_zmm_t7 = vtype1::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t8 = vtype1::permutexvar(rev_index, key_zmm_m1); - index_type index_zmm_t5 - = vtype2::permutexvar(rev_index, index_zmm_m4); - index_type index_zmm_t6 - = vtype2::permutexvar(rev_index, index_zmm_m3); - index_type index_zmm_t7 - = vtype2::permutexvar(rev_index, index_zmm_m2); - index_type index_zmm_t8 - = vtype2::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + index_type index_zmm_t5 = vtype2::permutexvar(rev_index, index_zmm_m4); + index_type index_zmm_t6 = vtype2::permutexvar(rev_index, index_zmm_m3); + index_type index_zmm_t7 = vtype2::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t8 = vtype2::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + key_zmm[0] + = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] + = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] + = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] + = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] + = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] + = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] + = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] + = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); index_zmm[0] = index_zmm_t1; index_zmm[1] = index_zmm_t2; @@ -286,22 +282,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, zmm_t key_zmm14r = vtype1::permutexvar(rev_index, key_zmm[14]); zmm_t key_zmm15r = vtype1::permutexvar(rev_index, key_zmm[15]); - index_type index_zmm8r - = vtype2::permutexvar(rev_index, index_zmm[8]); - index_type index_zmm9r - = vtype2::permutexvar(rev_index, index_zmm[9]); - index_type index_zmm10r - = vtype2::permutexvar(rev_index, index_zmm[10]); - index_type index_zmm11r - = vtype2::permutexvar(rev_index, index_zmm[11]); - index_type index_zmm12r - = vtype2::permutexvar(rev_index, index_zmm[12]); - index_type 
index_zmm13r - = vtype2::permutexvar(rev_index, index_zmm[13]); - index_type index_zmm14r - = vtype2::permutexvar(rev_index, index_zmm[14]); - index_type index_zmm15r - = vtype2::permutexvar(rev_index, index_zmm[15]); + index_type index_zmm8r = vtype2::permutexvar(rev_index, index_zmm[8]); + index_type index_zmm9r = vtype2::permutexvar(rev_index, index_zmm[9]); + index_type index_zmm10r = vtype2::permutexvar(rev_index, index_zmm[10]); + index_type index_zmm11r = vtype2::permutexvar(rev_index, index_zmm[11]); + index_type index_zmm12r = vtype2::permutexvar(rev_index, index_zmm[12]); + index_type index_zmm13r = vtype2::permutexvar(rev_index, index_zmm[13]); + index_type index_zmm14r = vtype2::permutexvar(rev_index, index_zmm[14]); + index_type index_zmm15r = vtype2::permutexvar(rev_index, index_zmm[15]); zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm15r); zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm14r); @@ -363,66 +351,83 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, zmm_t key_zmm_t14 = vtype1::permutexvar(rev_index, key_zmm_m3); zmm_t key_zmm_t15 = vtype1::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t16 = vtype1::permutexvar(rev_index, key_zmm_m1); - index_type index_zmm_t9 - = vtype2::permutexvar(rev_index, index_zmm_m8); - index_type index_zmm_t10 - = vtype2::permutexvar(rev_index, index_zmm_m7); - index_type index_zmm_t11 - = vtype2::permutexvar(rev_index, index_zmm_m6); - index_type index_zmm_t12 - = vtype2::permutexvar(rev_index, index_zmm_m5); - index_type index_zmm_t13 - = vtype2::permutexvar(rev_index, index_zmm_m4); - index_type index_zmm_t14 - = vtype2::permutexvar(rev_index, index_zmm_m3); - index_type index_zmm_t15 - = vtype2::permutexvar(rev_index, index_zmm_m2); - index_type index_zmm_t16 - = vtype2::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); - COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); - COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); - COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); - COEX(key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); - COEX(key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); - COEX(key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); - COEX(key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); - COEX(key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); - COEX(key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); - COEX(key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); + index_type index_zmm_t9 = vtype2::permutexvar(rev_index, index_zmm_m8); + index_type index_zmm_t10 = vtype2::permutexvar(rev_index, index_zmm_m7); + index_type index_zmm_t11 = vtype2::permutexvar(rev_index, index_zmm_m6); + index_type 
index_zmm_t12 = vtype2::permutexvar(rev_index, index_zmm_m5); + index_type index_zmm_t13 = vtype2::permutexvar(rev_index, index_zmm_m4); + index_type index_zmm_t14 = vtype2::permutexvar(rev_index, index_zmm_m3); + index_type index_zmm_t15 = vtype2::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t16 = vtype2::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); + COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); + COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); + COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); + COEX( + key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); + COEX( + key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); + COEX( + key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); + COEX( + key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); + COEX( + key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); + COEX( + key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); + COEX( + key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); + COEX( + key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); + COEX( + key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); // - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); - key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, index_zmm_t10); - key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, index_zmm_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, index_zmm_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, index_zmm_t13); - key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, index_zmm_t14); - key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, index_zmm_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, index_zmm_t16); + key_zmm[0] + = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] + = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] + = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] + = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] + = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] + = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] + = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] + = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + key_zmm[8] + = bitonic_merge_zmm_64bit(key_zmm_t9, 
index_zmm_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, + index_zmm_t10); + key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, + index_zmm_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, + index_zmm_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, + index_zmm_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, + index_zmm_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, + index_zmm_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, + index_zmm_t16); index_zmm[0] = index_zmm_t1; index_zmm[1] = index_zmm_t2; @@ -453,10 +458,11 @@ sort_8_64bit(type1_t *keys, type2_t *indexes, int32_t N) typename vtype1::zmm_t key_zmm = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys); - typename vtype2::zmm_t index_zmm = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask, indexes); - vtype1::mask_storeu( - keys, load_mask, sort_zmm_64bit(key_zmm, index_zmm)); + typename vtype2::zmm_t index_zmm + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes); + vtype1::mask_storeu(keys, + load_mask, + sort_zmm_64bit(key_zmm, index_zmm)); vtype2::mask_storeu(indexes, load_mask, index_zmm); } @@ -480,12 +486,12 @@ sort_16_64bit(type1_t *keys, type2_t *indexes, int32_t N) zmm_t key_zmm2 = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys + 8); index_type index_zmm1 = vtype2::loadu(indexes); - index_type index_zmm2 = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask, indexes + 8); + index_type index_zmm2 + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes + 8); - key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); - key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); - bitonic_merge_two_zmm_64bit( + key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); + key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); + bitonic_merge_two_zmm_64bit( key_zmm1, key_zmm2, index_zmm1, index_zmm2); vtype2::storeu(indexes, index_zmm1); @@ -518,8 +524,8 @@ sort_32_64bit(type1_t *keys, type2_t *indexes, int32_t N) index_zmm[0] = vtype2::loadu(indexes); index_zmm[1] = vtype2::loadu(indexes + 8); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; @@ -528,19 +534,19 @@ sort_32_64bit(type1_t *keys, type2_t *indexes, int32_t N) key_zmm[2] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 16); key_zmm[3] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 24); - index_zmm[2] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask1, indexes + 16); - index_zmm[3] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask2, indexes + 24); + index_zmm[2] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 16); + index_zmm[3] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 24); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); vtype2::storeu(indexes, index_zmm[0]); 
vtype2::storeu(indexes + 8, index_zmm[1]); @@ -561,7 +567,7 @@ X86_SIMD_SORT_INLINE void sort_64_64bit(type1_t *keys, type2_t *indexes, int32_t N) { if (N <= 32) { - sort_32_64bit(keys, indexes, N); + sort_32_64bit(keys, indexes, N); return; } using zmm_t = typename vtype1::zmm_t; @@ -579,10 +585,10 @@ sort_64_64bit(type1_t *keys, type2_t *indexes, int32_t N) index_zmm[1] = vtype2::loadu(indexes + 8); index_zmm[2] = vtype2::loadu(indexes + 16); index_zmm[3] = vtype2::loadu(indexes + 24); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; @@ -597,30 +603,30 @@ sort_64_64bit(type1_t *keys, type2_t *indexes, int32_t N) key_zmm[6] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 48); key_zmm[7] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 56); - index_zmm[4] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask1, indexes + 32); - index_zmm[5] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask2, indexes + 40); - index_zmm[6] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask3, indexes + 48); - index_zmm[7] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask4, indexes + 56); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - - bitonic_merge_two_zmm_64bit( + index_zmm[4] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 32); + index_zmm[5] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 40); + index_zmm[6] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 48); + index_zmm[7] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 56); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + + bitonic_merge_two_zmm_64bit( key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); vtype2::storeu(indexes, index_zmm[0]); vtype2::storeu(indexes + 8, index_zmm[1]); @@ -649,7 +655,7 @@ X86_SIMD_SORT_INLINE void sort_128_64bit(type1_t *keys, type2_t *indexes, int32_t N) { if (N <= 64) { - sort_64_64bit(keys, indexes, N); + sort_64_64bit(keys, indexes, N); return; } using zmm_t = typename vtype1::zmm_t; @@ -675,14 +681,14 @@ 
sort_128_64bit(type1_t *keys, type2_t *indexes, int32_t N) index_zmm[5] = vtype2::loadu(indexes + 40); index_zmm[6] = vtype2::loadu(indexes + 48); index_zmm[7] = vtype2::loadu(indexes + 56); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; @@ -708,54 +714,54 @@ sort_128_64bit(type1_t *keys, type2_t *indexes, int32_t N) key_zmm[14] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask7, keys + 112); key_zmm[15] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask8, keys + 120); - index_zmm[8] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask1, indexes + 64); - index_zmm[9] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask2, indexes + 72); - index_zmm[10] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask3, indexes + 80); - index_zmm[11] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask4, indexes + 88); - index_zmm[12] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask5, indexes + 96); - index_zmm[13] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask6, indexes + 104); - index_zmm[14] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask7, indexes + 112); - index_zmm[15] = vtype2::mask_loadu( - vtype2::zmm_max(), load_mask8, indexes + 120); - key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); - key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); - key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); - key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); - key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); - key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); - key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); - key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); - - bitonic_merge_two_zmm_64bit( + index_zmm[8] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 64); + index_zmm[9] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 72); + index_zmm[10] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 80); + index_zmm[11] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 88); + index_zmm[12] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask5, indexes + 96); + index_zmm[13] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask6, indexes + 104); + index_zmm[14] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask7, indexes + 112); + index_zmm[15] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask8, indexes + 120); + key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); + key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); + key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); + key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); + 
key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); + key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); + key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); + key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); + + bitonic_merge_two_zmm_64bit( key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); vtype2::storeu(indexes, index_zmm[0]); vtype2::storeu(indexes + 8, index_zmm[1]); vtype2::storeu(indexes + 16, index_zmm[2]); @@ -816,12 +822,12 @@ template = 0; i--) { - heapify(keys, indexes, i, size); + heapify(keys, indexes, i, size); } for (int64_t i = size - 1; i > 0; i--) { std::swap(keys[0], keys[i]); std::swap(indexes[0], indexes[i]); - heapify(keys, indexes, 0, i); + heapify(keys, indexes, 0, i); } } @@ -840,7 +846,8 @@ void qsort_64bit_(type1_t *keys, */ if (max_iters <= 0) { //std::sort(keys+left,keys+right+1); - heap_sort(keys + left, indexes + left, right - left + 1); + heap_sort( + keys + left, indexes + left, right - left + 1); return; } /* @@ -848,7 +855,7 @@ void qsort_64bit_(type1_t *keys, */ if (right + 1 - left <= 128) { - sort_128_64bit( + sort_128_64bit( keys + left, indexes + left, (int32_t)(right + 1 - left)); return; } @@ -856,22 +863,20 @@ void qsort_64bit_(type1_t *keys, type1_t pivot = get_pivot_64bit(keys, left, right); type1_t smallest = vtype1::type_max(); type1_t biggest = vtype1::type_min(); - int64_t pivot_index = partition_avx512( + int64_t pivot_index = partition_avx512( keys, indexes, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) { - qsort_64bit_( + qsort_64bit_( keys, indexes, left, pivot_index - 1, max_iters - 1); } if (pivot != biggest) { - qsort_64bit_( + qsort_64bit_( keys, indexes, pivot_index, right, max_iters - 1); } } template <> -void avx512_qsort_kv(int64_t *keys, - uint64_t *indexes, - int64_t arrsize) +void avx512_qsort_kv(int64_t *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_64bit_, 
zmm_vector>( @@ -891,9 +896,7 @@ void avx512_qsort_kv(uint64_t *keys, } template <> -void avx512_qsort_kv(double *keys, - uint64_t *indexes, - int64_t arrsize) +void avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf(keys, arrsize); diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 9ac0ccdd..6a9a8583 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -503,43 +503,41 @@ static inline int64_t partition_avx512(type_t1 *keys, // partition the current vector and save it on both sides of the array int32_t amount_gt_pivot; - amount_gt_pivot = partition_vec( - keys, - indexes, - l_store, - r_store + vtype1::numlanes, - keys_vec, - indexes_vec, - pivot_vec, - &min_vec, - &max_vec); + amount_gt_pivot + = partition_vec(keys, + indexes, + l_store, + r_store + vtype1::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); r_store -= amount_gt_pivot; l_store += (vtype1::numlanes - amount_gt_pivot); } /* partition and save vec_left and vec_right */ int32_t amount_gt_pivot; - amount_gt_pivot = partition_vec( - keys, - indexes, - l_store, - r_store + vtype1::numlanes, - keys_vec_left, - indexes_vec_left, - pivot_vec, - &min_vec, - &max_vec); + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + r_store + vtype1::numlanes, + keys_vec_left, + indexes_vec_left, + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype1::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec( - keys, - indexes, - l_store, - l_store + vtype1::numlanes, - keys_vec_right, - indexes_vec_right, - pivot_vec, - &min_vec, - &max_vec); + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + l_store + vtype1::numlanes, + keys_vec_right, + indexes_vec_right, + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype1::numlanes - amount_gt_pivot); *smallest = vtype1::reducemin(min_vec); *biggest = vtype1::reducemax(max_vec); From 723dc587cab952972e7c7dece1dc9e4a54081e9f Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Wed, 26 Apr 2023 22:08:14 -0700 Subject: [PATCH 4/4] Revert "Move classes to a separate header file" This reverts commit 66ec396f3d1eba78fcd1e52118d17a343b611b39. 
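Note for readers following the series: the generic routines in the hunks above (COEX, partition_vec, partition_avx512, the bitonic merge networks) are written only against the small traits interface that each zmm_vector<> specialization exposes (type_t, numlanes, loadu/storeu, min/max, and so on). The sketch below is illustrative only and is not part of any patch in this series; scalar_vector and coex are hypothetical names invented for the sketch, and the real specializations wrap __m512i/__m512 registers and AVX-512 intrinsics rather than scalar code. Under those assumptions it shows why a routine templated on the traits does not care where the specializations live, which is what lets this revert move them back into their original headers without touching the sort logic.

    #include <algorithm>
    #include <cstdint>
    #include <limits>

    // Hypothetical scalar stand-in for the zmm_vector<> traits interface.
    // Real specializations carry 512-bit registers and AVX-512 calls.
    template <typename T>
    struct scalar_vector {
        using type_t = T;
        static constexpr int numlanes = 1;
        static type_t type_max() { return std::numeric_limits<T>::max(); }
        static type_t type_min() { return std::numeric_limits<T>::lowest(); }
        static type_t loadu(const type_t *mem) { return *mem; }
        static void storeu(type_t *mem, type_t v) { *mem = v; }
        static type_t min(type_t x, type_t y) { return std::min(x, y); }
        static type_t max(type_t x, type_t y) { return std::max(x, y); }
    };

    // Compare-exchange written purely against the traits, in the spirit of
    // the COEX helpers used by the key-value sort routines above.
    template <typename vtype, typename reg_t = typename vtype::type_t>
    void coex(reg_t &a, reg_t &b)
    {
        reg_t lo = vtype::min(a, b);
        reg_t hi = vtype::max(a, b);
        a = lo;
        b = hi;
    }

    int main()
    {
        int64_t x = 7, y = 3;
        coex<scalar_vector<int64_t>>(x, y); // after the exchange: x == 3, y == 7
        return (x == 3 && y == 7) ? 0 : 1;
    }

Because the sort routines see only this interface, relocating the concrete zmm_vector<> definitions (as patch 1 did, and as this revert undoes) changes include structure but not behavior.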
--- src/avx512-16bit-common.h | 15 + src/avx512-16bit-qsort.hpp | 340 ++++++++++ src/avx512-32bit-qsort.hpp | 308 +++++++++ src/avx512-64bit-common.h | 316 +++++++++ src/avx512-common-qsort.h | 54 +- src/avx512-zmm-classes.h | 1147 -------------------------------- src/avx512fp16-16bit-qsort.hpp | 105 +++ 7 files changed, 1136 insertions(+), 1149 deletions(-) delete mode 100644 src/avx512-zmm-classes.h diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index cace5449..0c819946 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -14,6 +14,21 @@ * sorting network (see * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) */ +// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 +static const uint16_t network[6][32] + = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}, + {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16}, + {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11, + 20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27}, + {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23}, + {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}; + /* * Assumes zmm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index cd9a3903..606f8706 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -9,6 +9,346 @@ #include "avx512-16bit-common.h" +struct float16 { + uint16_t val; +}; + +template <> +struct zmm_vector { + using type_t = uint16_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask32; + static const uint8_t numlanes = 32; + + static zmm_t get_network(int index) + { + return _mm512_loadu_si512(&network[index - 1][0]); + } + static type_t type_max() + { + return X86_SIMD_SORT_INFINITYH; + } + static type_t type_min() + { + return X86_SIMD_SORT_NEGINFINITYH; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi16(type_max()); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask32(x); + } + + static opmask_t ge(zmm_t x, zmm_t y) + { + zmm_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000)); + zmm_t sign_y = _mm512_and_si512(y, _mm512_set1_epi16(0x8000)); + zmm_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00)); + zmm_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00)); + zmm_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff)); + zmm_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff)); + + __mmask32 mask_ge = _mm512_cmp_epu16_mask( + sign_x, sign_y, _MM_CMPINT_LT); // only greater than + __mmask32 sign_eq = _mm512_cmpeq_epu16_mask(sign_x, sign_y); + __mmask32 neg = _mm512_mask_cmpeq_epu16_mask( + sign_eq, + sign_x, + _mm512_set1_epi16(0x8000)); // both numbers are -ve + + // compare exponents only if signs are equal: + mask_ge = mask_ge + | _mm512_mask_cmp_epu16_mask( + sign_eq, exp_x, exp_y, _MM_CMPINT_NLE); + // get mask for elements for which both sign and exponents are equal: + __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask(sign_eq, 
exp_x, exp_y); + + // compare mantissa for elements for which both sign and expponent are equal: + mask_ge = mask_ge + | _mm512_mask_cmp_epu16_mask( + exp_eq, mant_x, mant_y, _MM_CMPINT_NLT); + return _kxor_mask32(mask_ge, neg); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_mask_mov_epi16(y, ge(x, y), x); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + // AVX512_VBMI2 + return _mm512_mask_compressstoreu_epi16(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + // AVX512BW + return _mm512_mask_loadu_epi16(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi16(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi16(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_mask_mov_epi16(x, ge(x, y), y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi16(idx, zmm); + } + // Apparently this is a terrible for perf, npy_half_to_float seems to work + // better + //static float uint16_to_float(uint16_t val) + //{ + // // Ideally use _mm_loadu_si16, but its only gcc > 11.x + // // TODO: use inline ASM? https://godbolt.org/z/aGYvh7fMM + // __m128i xmm = _mm_maskz_loadu_epi16(0x01, &val); + // __m128 xmm2 = _mm_cvtph_ps(xmm); + // return _mm_cvtss_f32(xmm2); + //} + static type_t float_to_uint16(float val) + { + __m128 xmm = _mm_load_ss(&val); + __m128i xmm2 = _mm_cvtps_ph(xmm, _MM_FROUND_NO_EXC); + return _mm_extract_epi16(xmm2, 0); + } + static type_t reducemax(zmm_t v) + { + __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0)); + __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1)); + float lo_max = _mm512_reduce_max_ps(lo); + float hi_max = _mm512_reduce_max_ps(hi); + return float_to_uint16(std::max(lo_max, hi_max)); + } + static type_t reducemin(zmm_t v) + { + __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0)); + __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1)); + float lo_max = _mm512_reduce_min_ps(lo); + float hi_max = _mm512_reduce_min_ps(hi); + return float_to_uint16(std::min(lo_max, hi_max)); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi16(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); + return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; + +template <> +struct zmm_vector { + using type_t = int16_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask32; + static const uint8_t numlanes = 32; + + static zmm_t get_network(int index) + { + return _mm512_loadu_si512(&network[index - 1][0]); + } + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT16; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT16; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi16(type_max()); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask32(x); + } + + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epi16(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + // 
AVX512_VBMI2 + return _mm512_mask_compressstoreu_epi16(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + // AVX512BW + return _mm512_mask_loadu_epi16(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi16(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi16(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epi16(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi16(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); + zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); + type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo); + type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi); + return std::max(lo_max, hi_max); + } + static type_t reducemin(zmm_t v) + { + zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); + zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); + type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo); + type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi); + return std::min(lo_min, hi_min); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi16(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); + return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; +template <> +struct zmm_vector { + using type_t = uint16_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask32; + static const uint8_t numlanes = 32; + + static zmm_t get_network(int index) + { + return _mm512_loadu_si512(&network[index - 1][0]); + } + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT16; + } + static type_t type_min() + { + return 0; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi16(type_max()); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask32(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epu16(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi16(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi16(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi16(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi16(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epu16(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi16(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); + zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); + type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo); + type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi); + return std::max(lo_max, hi_max); + } + static type_t reducemin(zmm_t v) + { + zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); + zmm_t hi = 
_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); + type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo); + type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi); + return std::min(lo_min, hi_min); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi16(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); + return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; + template <> bool comparison_func>(const uint16_t &a, const uint16_t &b) { diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 0f3b85a1..c4061ddf 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -23,6 +23,314 @@ #define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 #define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 +template <> +struct zmm_vector { + using type_t = int32_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask16; + static const uint8_t numlanes = 16; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT32; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT32; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi32(type_max()); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _mm512_knot(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); + } + template + static ymm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi32(index, base, scale); + } + static zmm_t merge(ymm_t y1, ymm_t y2) + { + zmm_t z1 = _mm512_castsi256_si512(y1); + return _mm512_inserti32x8(z1, y2, 1); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi32(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi32(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi32(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi32(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epi32(x, y); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epi32(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi32(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epi32(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epi32(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi32(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } + + static ymm_t max(ymm_t x, ymm_t y) + { + return _mm256_max_epi32(x, y); + } + static ymm_t min(ymm_t x, ymm_t y) + { + return _mm256_min_epi32(x, y); + } +}; +template <> +struct zmm_vector { + using type_t = uint32_t; + using zmm_t = __m512i; + using ymm_t = __m256i; + using opmask_t = __mmask16; + static const uint8_t numlanes = 16; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT32; + } + static type_t type_min() + { + return 0; + } + static zmm_t zmm_max() + { + return 
_mm512_set1_epi32(type_max()); + } // TODO: this should broadcast bits as is? + + template + static ymm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi32(index, base, scale); + } + static zmm_t merge(ymm_t y1, ymm_t y2) + { + zmm_t z1 = _mm512_castsi256_si512(y1); + return _mm512_inserti32x8(z1, y2, 1); + } + static opmask_t knot_opmask(opmask_t x) + { + return _mm512_knot(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epu32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi32(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi32(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi32(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi32(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epu32(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi32(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epu32(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epu32(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi32(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } + + static ymm_t max(ymm_t x, ymm_t y) + { + return _mm256_max_epu32(x, y); + } + static ymm_t min(ymm_t x, ymm_t y) + { + return _mm256_min_epu32(x, y); + } +}; +template <> +struct zmm_vector { + using type_t = float; + using zmm_t = __m512; + using ymm_t = __m256; + using opmask_t = __mmask16; + static const uint8_t numlanes = 16; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITYF; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITYF; + } + static zmm_t zmm_max() + { + return _mm512_set1_ps(type_max()); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _mm512_knot(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + } + template + static ymm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_ps(index, base, scale); + } + static zmm_t merge(ymm_t y1, ymm_t y2) + { + zmm_t z1 = _mm512_castsi512_ps( + _mm512_castsi256_si512(_mm256_castps_si256(y1))); + return _mm512_insertf32x8(z1, y2, 1); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_ps(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_ps(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_ps(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_ps(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_ps(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_ps(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_ps(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return 
_mm512_permutexvar_ps(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_ps(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_ps(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_ps(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_ps(zmm, zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_ps(mem, x); + } + + static ymm_t max(ymm_t x, ymm_t y) + { + return _mm256_max_ps(x, y); + } + static ymm_t min(ymm_t x, ymm_t y) + { + return _mm256_min_ps(x, y); + } +}; + /* * Assumes zmm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 87d39f15..75ae7fb1 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -19,6 +19,322 @@ #define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 #define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 +template <> +struct zmm_vector { + using type_t = int64_t; + using zmm_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT64; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT64; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi64(type_max()); + } // TODO: this should broadcast bits as is? + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi64(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epi64(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epi64(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epi64(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epi64(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi64(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; +template <> +struct zmm_vector { + using type_t = uint64_t; + using zmm_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t 
numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT64; + } + static type_t type_min() + { + return 0; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi64(type_max()); + } + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi64(index, base, scale); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epu64(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epu64(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epu64(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epu64(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi64(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; +template <> +struct zmm_vector { + using type_t = double; + using zmm_t = __m512d; + using ymm_t = __m512d; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITY; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITY; + } + static zmm_t zmm_max() + { + return _mm512_set1_pd(type_max()); + } + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_pd(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_pd(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_pd(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_pd(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_pd(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + 
{ + return _mm512_mask_mov_pd(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_pd(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_pd(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_pd(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_pd(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_pd(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_pd(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_pd(mem, x); + } +}; X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize) { int64_t nan_count = 0; diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 6a9a8583..b07b34d2 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -35,7 +35,58 @@ * */ -#include "avx512-zmm-classes.h" +#include +#include +#include +#include +#include +#include + +#define X86_SIMD_SORT_INFINITY std::numeric_limits::infinity() +#define X86_SIMD_SORT_INFINITYF std::numeric_limits::infinity() +#define X86_SIMD_SORT_INFINITYH 0x7c00 +#define X86_SIMD_SORT_NEGINFINITYH 0xfc00 +#define X86_SIMD_SORT_MAX_UINT16 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT16 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT16 std::numeric_limits::min() +#define X86_SIMD_SORT_MAX_UINT32 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT32 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT32 std::numeric_limits::min() +#define X86_SIMD_SORT_MAX_UINT64 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT64 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT64 std::numeric_limits::min() +#define ZMM_MAX_DOUBLE _mm512_set1_pd(X86_SIMD_SORT_INFINITY) +#define ZMM_MAX_UINT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_UINT64) +#define ZMM_MAX_INT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_INT64) +#define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF) +#define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32) +#define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32) +#define ZMM_MAX_HALF _mm512_set1_epi16(X86_SIMD_SORT_INFINITYH) +#define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH) +#define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16) +#define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16) +#define SHUFFLE_MASK(a, b, c, d) (a << 6) | (b << 4) | (c << 2) | d + +#ifdef _MSC_VER +#define X86_SIMD_SORT_INLINE static inline +#define X86_SIMD_SORT_FINLINE static __forceinline +#elif defined(__CYGWIN__) +/* + * Force inline in cygwin to work around a compiler bug. 
See + * https://github.com/numpy/numpy/pull/22315#issuecomment-1267757584 + */ +#define X86_SIMD_SORT_INLINE static __attribute__((always_inline)) +#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) +#elif defined(__GNUC__) +#define X86_SIMD_SORT_INLINE static inline +#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) +#else +#define X86_SIMD_SORT_INLINE static +#define X86_SIMD_SORT_FINLINE static +#endif + +template +struct zmm_vector; // Regular quicksort routines: template @@ -78,7 +129,6 @@ static void COEX(mm_t &a, mm_t &b) a = vtype::min(a, b); b = vtype::max(temp, b); } - template diff --git a/src/avx512-zmm-classes.h b/src/avx512-zmm-classes.h deleted file mode 100644 index 45f6cb25..00000000 --- a/src/avx512-zmm-classes.h +++ /dev/null @@ -1,1147 +0,0 @@ -#ifndef AVX512_ZMM_CLASSES -#define AVX512_ZMM_CLASSES - -#include -#include -#include -#include -#include -#include - -#ifdef _MSC_VER -#define X86_SIMD_SORT_INLINE static inline -#define X86_SIMD_SORT_FINLINE static __forceinline -#elif defined(__CYGWIN__) -/* - * Force inline in cygwin to work around a compiler bug. See - * https://github.com/numpy/numpy/pull/22315#issuecomment-1267757584 - */ -#define X86_SIMD_SORT_INLINE static __attribute__((always_inline)) -#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) -#elif defined(__GNUC__) -#define X86_SIMD_SORT_INLINE static inline -#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) -#else -#define X86_SIMD_SORT_INLINE static -#define X86_SIMD_SORT_FINLINE static -#endif - -#define X86_SIMD_SORT_INFINITY std::numeric_limits::infinity() -#define X86_SIMD_SORT_INFINITYF std::numeric_limits::infinity() -#define X86_SIMD_SORT_INFINITYH 0x7c00 -#define X86_SIMD_SORT_NEGINFINITYH 0xfc00 -#define X86_SIMD_SORT_MAX_UINT16 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT16 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT16 std::numeric_limits::min() -#define X86_SIMD_SORT_MAX_UINT32 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT32 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT32 std::numeric_limits::min() -#define X86_SIMD_SORT_MAX_UINT64 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT64 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT64 std::numeric_limits::min() -#define ZMM_MAX_DOUBLE _mm512_set1_pd(X86_SIMD_SORT_INFINITY) -#define ZMM_MAX_UINT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_UINT64) -#define ZMM_MAX_INT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_INT64) -#define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF) -#define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32) -#define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32) -#define ZMM_MAX_HALF _mm512_set1_epi16(X86_SIMD_SORT_INFINITYH) -#define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH) -#define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16) -#define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16) -#define SHUFFLE_MASK(a, b, c, d) (a << 6) | (b << 4) | (c << 2) | d - -// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 -static const uint16_t network[6][32] - = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, - 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}, - {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, - 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16}, - {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11, - 20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 
31, 24, 25, 26, 27}, - {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, - {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, - 24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23}, - {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}; - -template -struct zmm_vector; - -typedef union { - _Float16 f_; - uint16_t i_; -} Fp16Bits; - -template <> -struct zmm_vector<_Float16> { - using type_t = _Float16; - using zmm_t = __m512h; - using ymm_t = __m256h; - using opmask_t = __mmask32; - static const uint8_t numlanes = 32; - - static __m512i get_network(int index) - { - return _mm512_loadu_si512(&network[index - 1][0]); - } - static type_t type_max() - { - Fp16Bits val; - val.i_ = X86_SIMD_SORT_INFINITYH; - return val.f_; - } - static type_t type_min() - { - Fp16Bits val; - val.i_ = X86_SIMD_SORT_NEGINFINITYH; - return val.f_; - } - static zmm_t zmm_max() - { - return _mm512_set1_ph(type_max()); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask32(x); - } - - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_ph(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_ph(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - __m512i temp = _mm512_castph_si512(x); - // AVX512_VBMI2 - return _mm512_mask_compressstoreu_epi16(mem, mask, temp); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - // AVX512BW - return _mm512_castsi512_ph( - _mm512_mask_loadu_epi16(_mm512_castph_si512(x), mask, mem)); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_castsi512_ph(_mm512_mask_mov_epi16( - _mm512_castph_si512(x), mask, _mm512_castph_si512(y))); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi16(mem, mask, _mm512_castph_si512(x)); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_ph(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_ph(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_ph(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_ph(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_ph(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512i temp = _mm512_shufflehi_epi16(_mm512_castph_si512(zmm), - (_MM_PERM_ENUM)mask); - return _mm512_castsi512_ph( - _mm512_shufflelo_epi16(temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_ph(mem, x); - } -}; - -struct float16 { - uint16_t val; -}; - -template <> -struct zmm_vector { - using type_t = uint16_t; - using zmm_t = __m512i; - using ymm_t = __m256i; - using opmask_t = __mmask32; - static const uint8_t numlanes = 32; - - static zmm_t get_network(int index) - { - return _mm512_loadu_si512(&network[index - 1][0]); - } - static type_t type_max() - { - return X86_SIMD_SORT_INFINITYH; - } - static type_t type_min() - { - return X86_SIMD_SORT_NEGINFINITYH; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi16(type_max()); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask32(x); - } - - static opmask_t ge(zmm_t x, zmm_t y) - { - zmm_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000)); - zmm_t sign_y = _mm512_and_si512(y, 
_mm512_set1_epi16(0x8000)); - zmm_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00)); - zmm_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00)); - zmm_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff)); - zmm_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff)); - - __mmask32 mask_ge = _mm512_cmp_epu16_mask( - sign_x, sign_y, _MM_CMPINT_LT); // only greater than - __mmask32 sign_eq = _mm512_cmpeq_epu16_mask(sign_x, sign_y); - __mmask32 neg = _mm512_mask_cmpeq_epu16_mask( - sign_eq, - sign_x, - _mm512_set1_epi16(0x8000)); // both numbers are -ve - - // compare exponents only if signs are equal: - mask_ge = mask_ge - | _mm512_mask_cmp_epu16_mask( - sign_eq, exp_x, exp_y, _MM_CMPINT_NLE); - // get mask for elements for which both sign and exponents are equal: - __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask(sign_eq, exp_x, exp_y); - - // compare mantissa for elements for which both sign and expponent are equal: - mask_ge = mask_ge - | _mm512_mask_cmp_epu16_mask( - exp_eq, mant_x, mant_y, _MM_CMPINT_NLT); - return _kxor_mask32(mask_ge, neg); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_mask_mov_epi16(y, ge(x, y), x); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - // AVX512_VBMI2 - return _mm512_mask_compressstoreu_epi16(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - // AVX512BW - return _mm512_mask_loadu_epi16(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi16(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi16(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_mask_mov_epi16(x, ge(x, y), y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi16(idx, zmm); - } - // Apparently this is a terrible for perf, npy_half_to_float seems to work - // better - //static float uint16_to_float(uint16_t val) - //{ - // // Ideally use _mm_loadu_si16, but its only gcc > 11.x - // // TODO: use inline ASM? 
https://godbolt.org/z/aGYvh7fMM - // __m128i xmm = _mm_maskz_loadu_epi16(0x01, &val); - // __m128 xmm2 = _mm_cvtph_ps(xmm); - // return _mm_cvtss_f32(xmm2); - //} - static type_t float_to_uint16(float val) - { - __m128 xmm = _mm_load_ss(&val); - __m128i xmm2 = _mm_cvtps_ph(xmm, _MM_FROUND_NO_EXC); - return _mm_extract_epi16(xmm2, 0); - } - static type_t reducemax(zmm_t v) - { - __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0)); - __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1)); - float lo_max = _mm512_reduce_max_ps(lo); - float hi_max = _mm512_reduce_max_ps(hi); - return float_to_uint16(std::max(lo_max, hi_max)); - } - static type_t reducemin(zmm_t v) - { - __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0)); - __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1)); - float lo_max = _mm512_reduce_min_ps(lo); - float hi_max = _mm512_reduce_min_ps(hi); - return float_to_uint16(std::min(lo_max, hi_max)); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi16(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); - return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; - -template <> -struct zmm_vector { - using type_t = int16_t; - using zmm_t = __m512i; - using ymm_t = __m256i; - using opmask_t = __mmask32; - static const uint8_t numlanes = 32; - - static zmm_t get_network(int index) - { - return _mm512_loadu_si512(&network[index - 1][0]); - } - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT16; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT16; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi16(type_max()); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask32(x); - } - - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epi16(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - // AVX512_VBMI2 - return _mm512_mask_compressstoreu_epi16(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - // AVX512BW - return _mm512_mask_loadu_epi16(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi16(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi16(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epi16(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi16(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); - type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo); - type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi); - return std::max(lo_max, hi_max); - } - static type_t reducemin(zmm_t v) - { - zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1)); - type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo); - type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi); - return std::min(lo_min, hi_min); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi16(v); - } - template - 
static zmm_t shuffle(zmm_t zmm) - { - zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); - return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; - -template <> -struct zmm_vector { - using type_t = uint16_t; - using zmm_t = __m512i; - using ymm_t = __m256i; - using opmask_t = __mmask32; - static const uint8_t numlanes = 32; - - static zmm_t get_network(int index) - { - return _mm512_loadu_si512(&network[index - 1][0]); - } - static type_t type_max() - { - return X86_SIMD_SORT_MAX_UINT16; - } - static type_t type_min() - { - return 0; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi16(type_max()); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask32(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epu16(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi16(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi16(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi16(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi16(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epu16(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi16(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); - type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo); - type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi); - return std::max(lo_max, hi_max); - } - static type_t reducemin(zmm_t v) - { - zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0)); - zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1)); - type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo); - type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi); - return std::min(lo_min, hi_min); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi16(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask); - return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; - -template <> -struct zmm_vector { - using type_t = int32_t; - using zmm_t = __m512i; - using ymm_t = __m256i; - using opmask_t = __mmask16; - static const uint8_t numlanes = 16; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT32; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT32; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi32(type_max()); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _mm512_knot(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); - } - template - static ymm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi32(index, base, scale); - } - static zmm_t merge(ymm_t y1, ymm_t y2) - { - zmm_t z1 = _mm512_castsi256_si512(y1); - return _mm512_inserti32x8(z1, y2, 1); - } - static zmm_t 
loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi32(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi32(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi32(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi32(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epi32(x, y); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epi32(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi32(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epi32(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epi32(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi32(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } - - static ymm_t max(ymm_t x, ymm_t y) - { - return _mm256_max_epi32(x, y); - } - static ymm_t min(ymm_t x, ymm_t y) - { - return _mm256_min_epi32(x, y); - } -}; - -template <> -struct zmm_vector { - using type_t = uint32_t; - using zmm_t = __m512i; - using ymm_t = __m256i; - using opmask_t = __mmask16; - static const uint8_t numlanes = 16; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_UINT32; - } - static type_t type_min() - { - return 0; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi32(type_max()); - } // TODO: this should broadcast bits as is? 
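These trait members are consumed generically by the sorting kernels (for example the COEX helper above, which calls vtype::min and vtype::max). A minimal usage sketch, assuming an AVX-512 capable target and the zmm_vector<int32_t> specialization defined in this header; the helper name coex_columns is hypothetical and only illustrates how loadu/min/max/storeu compose:

#include <immintrin.h>
#include <cstdint>

// Hypothetical sketch (not part of the patch): lane-wise compare-exchange of
// two blocks of 16 int32_t values using only the zmm_vector trait interface.
template <typename vtype>
static void coex_columns(int32_t *a, int32_t *b)
{
    using zmm_t = typename vtype::zmm_t;
    zmm_t va = vtype::loadu(a); // unaligned 512-bit loads
    zmm_t vb = vtype::loadu(b);
    zmm_t lo = vtype::min(va, vb); // smaller element of each lane
    zmm_t hi = vtype::max(va, vb); // larger element of each lane
    vtype::storeu(a, lo); // afterwards a[i] <= b[i] holds in every lane
    vtype::storeu(b, hi);
}
// e.g. coex_columns<zmm_vector<int32_t>>(ptrA, ptrB);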
- - template - static ymm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi32(index, base, scale); - } - static zmm_t merge(ymm_t y1, ymm_t y2) - { - zmm_t z1 = _mm512_castsi256_si512(y1); - return _mm512_inserti32x8(z1, y2, 1); - } - static opmask_t knot_opmask(opmask_t x) - { - return _mm512_knot(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epu32(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi32(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi32(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi32(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi32(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epu32(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi32(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epu32(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epu32(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi32(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_epi32(zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } - - static ymm_t max(ymm_t x, ymm_t y) - { - return _mm256_max_epu32(x, y); - } - static ymm_t min(ymm_t x, ymm_t y) - { - return _mm256_min_epu32(x, y); - } -}; - -template <> -struct zmm_vector { - using type_t = float; - using zmm_t = __m512; - using ymm_t = __m256; - using opmask_t = __mmask16; - static const uint8_t numlanes = 16; - - static type_t type_max() - { - return X86_SIMD_SORT_INFINITYF; - } - static type_t type_min() - { - return -X86_SIMD_SORT_INFINITYF; - } - static zmm_t zmm_max() - { - return _mm512_set1_ps(type_max()); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _mm512_knot(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); - } - template - static ymm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_ps(index, base, scale); - } - static zmm_t merge(ymm_t y1, ymm_t y2) - { - zmm_t z1 = _mm512_castsi512_ps( - _mm512_castsi256_si512(_mm256_castps_si256(y1))); - return _mm512_insertf32x8(z1, y2, 1); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_ps(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_ps(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_ps(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_ps(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_ps(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_ps(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_ps(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_ps(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - 
return _mm512_reduce_max_ps(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_ps(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_ps(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_ps(zmm, zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_ps(mem, x); - } - - static ymm_t max(ymm_t x, ymm_t y) - { - return _mm256_max_ps(x, y); - } - static ymm_t min(ymm_t x, ymm_t y) - { - return _mm256_min_ps(x, y); - } -}; - -template <> -struct zmm_vector { - using type_t = int64_t; - using zmm_t = __m512i; - using ymm_t = __m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT64; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT64; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi64(type_max()); - } // TODO: this should broadcast bits as is? - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); - } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); - } - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi64(index, base, scale); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epi64(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epi64(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epi64(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epi64(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi64(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; - -template <> -struct zmm_vector { - using type_t = uint64_t; - using zmm_t = __m512i; - using ymm_t = __m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_UINT64; - } - static type_t type_min() - { - return 0; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi64(type_max()); - } - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - template - static zmm_t i64gather(__m512i index, void const *base) - { - return 
_mm512_i64gather_epi64(index, base, scale); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); - } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epu64(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epu64(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epu64(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epu64(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi64(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; - -template <> -struct zmm_vector { - using type_t = double; - using zmm_t = __m512d; - using ymm_t = __m512d; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_INFINITY; - } - static type_t type_min() - { - return -X86_SIMD_SORT_INFINITY; - } - static zmm_t zmm_max() - { - return _mm512_set1_pd(type_max()); - } - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); - } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); - } - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_pd(index, base, scale); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_pd(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_pd(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_pd(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_pd(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_pd(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_pd(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_pd(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_pd(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_pd(v); - } - static type_t reducemin(zmm_t v) - { - return 
_mm512_reduce_min_pd(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_pd(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_pd(mem, x); - } -}; - -#endif //AVX512_ZMM_CLASSES diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 2bdf9803..8a9a49ed 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -9,6 +9,111 @@ #include "avx512-16bit-common.h" +typedef union { + _Float16 f_; + uint16_t i_; +} Fp16Bits; + +template <> +struct zmm_vector<_Float16> { + using type_t = _Float16; + using zmm_t = __m512h; + using ymm_t = __m256h; + using opmask_t = __mmask32; + static const uint8_t numlanes = 32; + + static __m512i get_network(int index) + { + return _mm512_loadu_si512(&network[index - 1][0]); + } + static type_t type_max() + { + Fp16Bits val; + val.i_ = X86_SIMD_SORT_INFINITYH; + return val.f_; + } + static type_t type_min() + { + Fp16Bits val; + val.i_ = X86_SIMD_SORT_NEGINFINITYH; + return val.f_; + } + static zmm_t zmm_max() + { + return _mm512_set1_ph(type_max()); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask32(x); + } + + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_ph(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_ph(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + __m512i temp = _mm512_castph_si512(x); + // AVX512_VBMI2 + return _mm512_mask_compressstoreu_epi16(mem, mask, temp); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + // AVX512BW + return _mm512_castsi512_ph( + _mm512_mask_loadu_epi16(_mm512_castph_si512(x), mask, mem)); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_castsi512_ph(_mm512_mask_mov_epi16( + _mm512_castph_si512(x), mask, _mm512_castph_si512(y))); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi16(mem, mask, _mm512_castph_si512(x)); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_ph(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_ph(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_ph(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_ph(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_ph(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512i temp = _mm512_shufflehi_epi16(_mm512_castph_si512(zmm), + (_MM_PERM_ENUM)mask); + return _mm512_castsi512_ph( + _mm512_shufflelo_epi16(temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_ph(mem, x); + } +}; + X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr, int64_t arrsize) {