diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp
index 81d7d00e..5588cffa 100644
--- a/lib/x86simdsort-avx2.cpp
+++ b/lib/x86simdsort-avx2.cpp
@@ -1,6 +1,8 @@
 // AVX2 specific routines:
 #include "avx2-32bit-qsort.hpp"
 #include "avx2-64bit-qsort.hpp"
+#include "avx2-32bit-half.hpp"
+#include "xss-common-argsort.h"
 #include "x86simdsort-internal.h"
 
 #define DEFINE_ALL_METHODS(type) \
@@ -18,6 +20,17 @@
     void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \
     { \
         avx2_partial_qsort(arr, k, arrsize, hasnan); \
+    }\
+    template <> \
+    std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \
+    { \
+        return avx2_argsort(arr, arrsize, hasnan); \
+    } \
+    template <> \
+    std::vector<size_t> argselect( \
+            type *arr, size_t k, size_t arrsize, bool hasnan) \
+    { \
+        return avx2_argselect(arr, k, arrsize, hasnan); \
     }
 
 namespace xss {
diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp
index 8ebbc6be..f088e4cd 100644
--- a/lib/x86simdsort.cpp
+++ b/lib/x86simdsort.cpp
@@ -189,12 +189,12 @@ DISPATCH_ALL(partial_qsort,
              (ISA_LIST("avx512_skx", "avx2")))
 DISPATCH_ALL(argsort,
              (ISA_LIST("none")),
-             (ISA_LIST("avx512_skx")),
-             (ISA_LIST("avx512_skx")))
+             (ISA_LIST("avx512_skx", "avx2")),
+             (ISA_LIST("avx512_skx", "avx2")))
 DISPATCH_ALL(argselect,
              (ISA_LIST("none")),
-             (ISA_LIST("avx512_skx")),
-             (ISA_LIST("avx512_skx")))
+             (ISA_LIST("avx512_skx", "avx2")),
+             (ISA_LIST("avx512_skx", "avx2")))
 
 #define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \
     DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx")))\
diff --git a/src/avx2-32bit-half.hpp b/src/avx2-32bit-half.hpp
new file mode 100644
index 00000000..5a6ee5b5
--- /dev/null
+++ b/src/avx2-32bit-half.hpp
@@ -0,0 +1,557 @@
+/*******************************************************************
+ * Copyright (C) 2022 Intel Corporation
+ * SPDX-License-Identifier: BSD-3-Clause
+ * Authors: Raghuveer Devulapalli
+ * ****************************************************************/
+
+#ifndef AVX2_HALF_32BIT
+#define AVX2_HALF_32BIT
+
+#include "xss-common-qsort.h"
+#include "avx2-emu-funcs.hpp"
+
+/*
+ * Constants used in sorting 8 elements in ymm registers.
Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ + +// ymm 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 + +/* + * Assumes ymm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit_half(reg_t ymm) +{ + using swizzle = typename vtype::swizzle_ops; + + const typename vtype::opmask_t oxAA = vtype::seti(-1, 0, -1, 0); + const typename vtype::opmask_t oxCC = vtype::seti(-1, -1, 0, 0); + + ymm = cmp_merge(ymm, swizzle::template swap_n(ymm), oxAA); + ymm = cmp_merge(ymm, vtype::reverse(ymm), oxCC); + ymm = cmp_merge(ymm, swizzle::template swap_n(ymm), oxAA); + return ymm; +} + +struct avx2_32bit_half_swizzle_ops; + +template <> +struct avx2_half_vector { + using type_t = int32_t; + using reg_t = __m128i; + using ymmi_t = __m128i; + using opmask_t = __m128i; + static const uint8_t numlanes = 4; + static constexpr simd_type vec_type = simd_type::AVX2; + + using swizzle_ops = avx2_32bit_half_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT32; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT32; + } + static reg_t zmm_max() + { + return _mm_set1_epi32(type_max()); + } // TODO: this should broadcast bits as is? + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_half(mask); + } + static ymmi_t seti(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static reg_t set(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _mm_xor_si128(x, y); + } + static opmask_t ge(reg_t x, reg_t y) + { + opmask_t equal = eq(x, y); + opmask_t greater = _mm_cmpgt_epi32(x, y); + return _mm_castps_si128( + _mm_or_ps(_mm_castsi128_ps(equal), _mm_castsi128_ps(greater))); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm_cmpeq_epi32(x, y); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_epi32( + src, (const int *)base, index, mask, scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) + { + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + } + static reg_t loadu(void const *mem) + { + return _mm_loadu_si128((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm_max_epi32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu32_half(mem, mask, x); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm_maskload_epi32((const int *)mem, mask); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm_maskload_epi32((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(x), + _mm_castsi128_ps(y), + _mm_castsi128_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm_maskstore_epi32((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, 
reg_t y) + { + return _mm_min_epi32(x, y); + } + static reg_t permutexvar(__m128i idx, reg_t ymm) + { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); + } + static reg_t permutevar(reg_t ymm, __m128i idx) + { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); + } + static reg_t reverse(reg_t ymm) + { + const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32_half(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32_half(v); + } + static reg_t set1(type_t v) + { + return _mm_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm_storeu_si128((__m128i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit_half>(x); + } + static reg_t cast_from(__m128i v) + { + return v; + } + static __m128i cast_to(reg_t v) + { + return v; + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32_half( + left_addr, right_addr, k, reg); + } +}; +template <> +struct avx2_half_vector { + using type_t = uint32_t; + using reg_t = __m128i; + using ymmi_t = __m128i; + using opmask_t = __m128i; + static const uint8_t numlanes = 4; + static constexpr simd_type vec_type = simd_type::AVX2; + + using swizzle_ops = avx2_32bit_half_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT32; + } + static type_t type_min() + { + return 0; + } + static reg_t zmm_max() + { + return _mm_set1_epi32(type_max()); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_half(mask); + } + static ymmi_t seti(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static reg_t set(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_epi32( + src, (const int *)base, index, mask, scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) + { + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + } + static opmask_t ge(reg_t x, reg_t y) + { + reg_t maxi = max(x, y); + return eq(maxi, x); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm_cmpeq_epi32(x, y); + } + static reg_t loadu(void const *mem) + { + return _mm_loadu_si128((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm_max_epu32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu32_half(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm_maskload_epi32((const int *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(x), + _mm_castsi128_ps(y), + _mm_castsi128_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm_maskstore_epi32((int *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm_min_epu32(x, y); + } + static reg_t permutexvar(__m128i idx, reg_t ymm) + { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); + } 
+ static reg_t permutevar(reg_t ymm, __m128i idx) + { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); + } + static reg_t reverse(reg_t ymm) + { + const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32_half(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32_half(v); + } + static reg_t set1(type_t v) + { + return _mm_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm_storeu_si128((__m128i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit_half>(x); + } + static reg_t cast_from(__m128i v) + { + return v; + } + static __m128i cast_to(reg_t v) + { + return v; + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32_half( + left_addr, right_addr, k, reg); + } +}; +template <> +struct avx2_half_vector { + using type_t = float; + using reg_t = __m128; + using ymmi_t = __m128i; + using opmask_t = __m128i; + static const uint8_t numlanes = 4; + static constexpr simd_type vec_type = simd_type::AVX2; + + using swizzle_ops = avx2_32bit_half_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITYF; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITYF; + } + static reg_t zmm_max() + { + return _mm_set1_ps(type_max()); + } + + static ymmi_t seti(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static reg_t set(float v1, float v2, float v3, float v4) + { + return _mm_set_ps(v1, v2, v3, v4); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm_maskload_ps((const float *)mem, mask); + } + static opmask_t ge(reg_t x, reg_t y) + { + return _mm_castps_si128(_mm_cmp_ps(x, y, _CMP_GE_OQ)); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm_castps_si128(_mm_cmp_ps(x, y, _CMP_EQ_OQ)); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_half(mask); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return convert_avx2_mask_to_int_half(mask); + } + template + static opmask_t fpclass(reg_t x) + { + if constexpr (type == (0x01 | 0x80)) { + return _mm_castps_si128(_mm_cmp_ps(x, x, _CMP_UNORD_Q)); + } + else { + static_assert(type == (0x01 | 0x80), "should not reach here"); + } + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_ps( + src, (const float *)base, index, _mm_castsi128_ps(mask), scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) + { + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + } + static reg_t loadu(void const *mem) + { + return _mm_loadu_ps((float const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm_max_ps(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu32_half(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm_maskload_ps((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm_blendv_ps(x, y, _mm_castsi128_ps(mask)); + } + static void mask_storeu(void *mem, 
opmask_t mask, reg_t x) + { + return _mm_maskstore_ps((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm_min_ps(x, y); + } + static reg_t permutexvar(__m128i idx, reg_t ymm) + { + return _mm_permutevar_ps(ymm, idx); + } + static reg_t permutevar(reg_t ymm, __m128i idx) + { + return _mm_permutevar_ps(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32_half(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32_half(v); + } + static reg_t set1(type_t v) + { + return _mm_set1_ps(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(ymm), mask)); + } + static void storeu(void *mem, reg_t x) + { + _mm_storeu_ps((float *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit_half>(x); + } + static reg_t cast_from(__m128i v) + { + return _mm_castsi128_ps(v); + } + static __m128i cast_to(reg_t v) + { + return _mm_castps_si128(v); + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32_half( + left_addr, right_addr, k, reg); + } +}; + +struct avx2_32bit_half_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { + __m128i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + __m128 vf = _mm_castsi128_ps(v); + vf = _mm_permute_ps(vf, 0b10110001); + v = _mm_castps_si128(vf); + } + else if constexpr (scale == 4) { + __m128 vf = _mm_castsi128_ps(v); + vf = _mm_permute_ps(vf, 0b01001110); + v = _mm_castps_si128(vf); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { + __m128i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { + return vtype::reverse(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { + __m128i v1 = vtype::cast_to(reg); + __m128i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { v1 = _mm_blend_epi32(v1, v2, 0b0101); } + else if constexpr (scale == 4) { + v1 = _mm_blend_epi32(v1, v2, 0b0011); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; + +#endif // AVX2_HALF_32BIT diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp index 521597cd..cf0fbd55 100644 --- a/src/avx2-32bit-qsort.hpp +++ b/src/avx2-32bit-qsort.hpp @@ -70,6 +70,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; @@ -225,6 +226,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; @@ -369,6 +371,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; + 
static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index 6ffddbde..709d98ef 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -11,15 +11,6 @@ #include "xss-common-qsort.h" #include "avx2-emu-funcs.hpp" -/* - * Constants used in sorting 8 elements in a ymm registers. Based on Bitonic - * sorting network (see - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) - */ -// ymm 3, 2, 1, 0 -#define NETWORK_64BIT_R 0, 1, 2, 3 -#define NETWORK_64BIT_1 1, 0, 3, 2 - /* * Assumes ymm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg @@ -61,6 +52,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -85,6 +77,10 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) + { + return _mm256_set_epi64x(v1, v2, v3, v4); + } static opmask_t kxor_opmask(opmask_t x, opmask_t y) { return _mm256_xor_si256(x, y); @@ -107,12 +103,12 @@ struct avx2_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi64(src, base, index, mask, scale); + return _mm256_mask_i64gather_epi64( + src, (const long long int *)base, index, mask, scale); } - template - static reg_t i64gather(__m256i index, void const *base) + static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return _mm256_i64gather_epi64((int64_t const *)base, index, scale); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -220,6 +216,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -244,17 +241,20 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) + { + return _mm256_set_epi64x(v1, v2, v3, v4); + } template static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi64(src, base, index, mask, scale); + return _mm256_mask_i64gather_epi64( + src, (const long long int *)base, index, mask, scale); } - template - static reg_t i64gather(__m256i index, void const *base) + static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return _mm256_i64gather_epi64( - (long long int const *)base, index, scale); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static opmask_t gt(reg_t x, reg_t y) { @@ -378,6 +378,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -416,7 +417,10 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } - + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) + { + return _mm256_set_pd(v1, v2, v3, v4); + } static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm256_maskload_pd((const double *)mem, mask); @@ -433,14 +437,16 @@ struct avx2_vector { static reg_t 
mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_pd( - src, base, index, _mm256_castsi256_pd(mask), scale); + return _mm256_mask_i64gather_pd(src, + (const type_t *)base, + index, + _mm256_castsi256_pd(mask), + scale); ; } - template - static reg_t i64gather(__m256i index, void const *base) + static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return _mm256_i64gather_pd((double *)base, index, scale); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static reg_t loadu(void const *mem) { diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 9f6229f7..6e40d2a6 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -35,6 +35,21 @@ constexpr auto avx2_mask_helper_lut64 = [] { return lut; }(); +constexpr auto avx2_mask_helper_lut32_half = [] { + std::array, 16> lut {}; + for (int64_t i = 0; i <= 0xF; i++) { + std::array entry {}; + for (int j = 0; j < 4; j++) { + if (((i >> j) & 1) == 1) + entry[j] = 0xFFFFFFFF; + else + entry[j] = 0; + } + lut[i] = entry; + } + return lut; +}(); + constexpr auto avx2_compressstore_lut32_gen = [] { std::array, 256>, 2> lutPair {}; auto &permLut = lutPair[0]; @@ -65,6 +80,38 @@ constexpr auto avx2_compressstore_lut32_gen = [] { constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0]; constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1]; +constexpr auto avx2_compressstore_lut32_half_gen = [] { + std::array, 16>, 2> lutPair {}; + auto &permLut = lutPair[0]; + auto &leftLut = lutPair[1]; + for (int64_t i = 0; i <= 0xF; i++) { + std::array indices {}; + std::array leftEntry = {0, 0, 0, 0}; + int right = 3; + int left = 0; + for (int j = 0; j < 4; j++) { + bool ge = (i >> j) & 1; + if (ge) { + indices[right] = j; + right--; + } + else { + indices[left] = j; + leftEntry[left] = 0xFFFFFFFF; + left++; + } + } + permLut[i] = indices; + leftLut[i] = leftEntry; + } + return lutPair; +}(); + +constexpr auto avx2_compressstore_lut32_half_perm + = avx2_compressstore_lut32_half_gen[0]; +constexpr auto avx2_compressstore_lut32_half_left + = avx2_compressstore_lut32_half_gen[1]; + constexpr auto avx2_compressstore_lut64_gen = [] { std::array, 16> permLut {}; std::array, 16> leftLut {}; @@ -123,6 +170,19 @@ int32_t convert_avx2_mask_to_int_64bit(__m256i m) return _mm256_movemask_pd(_mm256_castsi256_pd(m)); } +X86_SIMD_SORT_INLINE +__m128i convert_int_to_avx2_mask_half(int32_t m) +{ + return _mm_loadu_si128( + (const __m128i *)avx2_mask_helper_lut32_half[m].data()); +} + +X86_SIMD_SORT_INLINE +int32_t convert_avx2_mask_to_int_half(__m128i m) +{ + return _mm_movemask_ps(_mm_castsi128_ps(m)); +} + // Emulators for intrinsics missing from AVX2 compared to AVX512 template T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) @@ -139,6 +199,19 @@ T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) return std::max(arr[0], arr[7]); } +template +T avx2_emu_reduce_max32_half(typename avx2_half_vector::reg_t x) +{ + using vtype = avx2_half_vector; + using reg_t = typename vtype::reg_t; + + reg_t inter1 = vtype::max( + x, vtype::template shuffle(x)); + T arr[vtype::numlanes]; + vtype::storeu(arr, inter1); + return std::max(arr[0], arr[3]); +} + template T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) { @@ -154,6 +227,19 @@ T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) return std::min(arr[0], arr[7]); } +template +T avx2_emu_reduce_min32_half(typename avx2_half_vector::reg_t x) +{ + using vtype = avx2_half_vector; + 
using reg_t = typename vtype::reg_t; + + reg_t inter1 = vtype::min( + x, vtype::template shuffle(x)); + T arr[vtype::numlanes]; + vtype::storeu(arr, inter1); + return std::min(arr[0], arr[3]); +} + template T avx2_emu_reduce_max64(typename avx2_vector::reg_t x) { @@ -196,6 +282,29 @@ void avx2_emu_mask_compressstoreu32(void *base_addr, vtype::mask_storeu(leftStore, left, temp); } +template +void avx2_emu_mask_compressstoreu32_half( + void *base_addr, + typename avx2_half_vector::opmask_t k, + typename avx2_half_vector::reg_t reg) +{ + using vtype = avx2_half_vector; + + T *leftStore = (T *)base_addr; + + int32_t shortMask = convert_avx2_mask_to_int_half(k); + const __m128i &perm = _mm_loadu_si128( + (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask] + .data()); + const __m128i &left = _mm_loadu_si128( + (const __m128i *)avx2_compressstore_lut32_half_left[shortMask] + .data()); + + typename vtype::reg_t temp = vtype::permutevar(reg, perm); + + vtype::mask_storeu(leftStore, left, temp); +} + template void avx2_emu_mask_compressstoreu64(void *base_addr, typename avx2_vector::opmask_t k, @@ -240,6 +349,30 @@ int avx2_double_compressstore32(void *left_addr, return _mm_popcnt_u32(shortMask); } +template +int avx2_double_compressstore32_half(void *left_addr, + void *right_addr, + typename avx2_half_vector::opmask_t k, + typename avx2_half_vector::reg_t reg) +{ + using vtype = avx2_half_vector; + + T *leftStore = (T *)left_addr; + T *rightStore = (T *)right_addr; + + int32_t shortMask = convert_avx2_mask_to_int_half(k); + const __m128i &perm = _mm_loadu_si128( + (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask] + .data()); + + typename vtype::reg_t temp = vtype::permutevar(reg, perm); + + vtype::storeu(leftStore, temp); + vtype::storeu(rightStore, temp); + + return _mm_popcnt_u32(shortMask); +} + template int32_t avx2_double_compressstore64(void *left_addr, void *right_addr, diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index a71281f4..937e6ac0 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -26,6 +26,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; @@ -208,6 +209,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; @@ -343,6 +345,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 2d101b88..f06cfff0 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -41,6 +41,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; @@ -180,6 +181,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; @@ -319,6 +321,7 @@ struct 
zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index c4084c68..3a475da8 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -7,713 +7,7 @@ #ifndef AVX512_ARGSORT_64BIT #define AVX512_ARGSORT_64BIT -#include "xss-common-qsort.h" #include "avx512-64bit-common.h" -#include "xss-network-keyvaluesort.hpp" -#include - -template -X86_SIMD_SORT_INLINE void std_argselect_withnan( - T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) -{ - std::nth_element(arg + left, - arg + k, - arg + right, - [arr](arrsize_t a, arrsize_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { - return arr[a] < arr[b]; - } - else if (std::isnan(arr[a])) { - return false; - } - else { - return true; - } - }); -} - -/* argsort using std::sort */ -template -X86_SIMD_SORT_INLINE void -std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) -{ - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { - return arr[left] < arr[right]; - } - else if (std::isnan(arr[left])) { - return false; - } - else { - return true; - } - }); -} - -/* argsort using std::sort */ -template -X86_SIMD_SORT_INLINE void -std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) -{ - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - // sort indices according to corresponding array element - return arr[left] < arr[right]; - }); -} - -/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of - * undefined template 'zmm_vector'*/ -#ifdef __APPLE__ -using argtype = typename std::conditional, - zmm_vector>::type; -#else -using argtype = typename std::conditional, - zmm_vector>::type; -#endif -using argreg_t = typename argtype::reg_t; - -/* - * Parition one ZMM register based on the pivot and returns the index of the - * last element that is less than equal to the pivot. - */ -template -X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - /* which elements are larger than the pivot */ - typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); - int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); - argtype::mask_compressstoreu( - arg + left, vtype::knot_opmask(gt_mask), arg_vec); - argtype::mask_compressstoreu( - arg + right - amount_gt_pivot, gt_mask, arg_vec); - *smallest_vec = vtype::min(curr_vec, *smallest_vec); - *biggest_vec = vtype::max(curr_vec, *biggest_vec); - return amount_gt_pivot; -} -/* - * Parition an array based on the pivot and returns the index of the - * last element that is less than equal to the pivot. 
- */ -template -X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { - *smallest = std::min(*smallest, arr[arg[left]], comparison_func); - *biggest = std::max(*biggest, arr[arg[left]], comparison_func); - if (!comparison_func(arr[arg[left]], pivot)) { - std::swap(arg[left], arg[--right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using reg_t = typename vtype::reg_t; - reg_t pivot_vec = vtype::set1(pivot); - reg_t min_vec = vtype::set1(*smallest); - reg_t max_vec = vtype::set1(*biggest); - - if (right - left == vtype::numlanes) { - argreg_t argvec = argtype::loadu(arg + left); - reg_t vec = vtype::i64gather(arr, arg + left); - int32_t amount_gt_pivot = partition_vec(arg, - left, - left + vtype::numlanes, - argvec, - vec, - pivot_vec, - &min_vec, - &max_vec); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return left + (vtype::numlanes - amount_gt_pivot); - } - - // first and last vtype::numlanes values are partitioned at the end - argreg_t argvec_left = argtype::loadu(arg + left); - reg_t vec_left = vtype::i64gather(arr, arg + left); - argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); - reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements - left += vtype::numlanes; - right -= vtype::numlanes; - while (right - left != 0) { - argreg_t arg_vec; - reg_t curr_vec; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= vtype::numlanes; - arg_vec = argtype::loadu(arg + right); - curr_vec = vtype::i64gather(arr, arg + right); - } - else { - arg_vec = argtype::loadu(arg + left); - curr_vec = vtype::i64gather(arr, arg + left); - left += vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec, - curr_vec, - pivot_vec, - &min_vec, - &max_vec); - ; - r_store -= amount_gt_pivot; - l_store += (vtype::numlanes - amount_gt_pivot); - } - - /* partition and save vec_left and vec_right */ - int32_t amount_gt_pivot = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left, - vec_left, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec(arg, - l_store, - l_store + vtype::numlanes, - argvec_right, - vec_right, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -template -X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - if (right - left <= 8 * num_unroll * vtype::numlanes) { - return partition_avx512( - arr, arg, left, right, pivot, smallest, biggest); - } - /* make array length 
divisible by vtype::numlanes , shortening the array */ - for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; - --i) { - *smallest = std::min(*smallest, arr[arg[left]], comparison_func); - *biggest = std::max(*biggest, arr[arg[left]], comparison_func); - if (!comparison_func(arr[arg[left]], pivot)) { - std::swap(arg[left], arg[--right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using reg_t = typename vtype::reg_t; - reg_t pivot_vec = vtype::set1(pivot); - reg_t min_vec = vtype::set1(*smallest); - reg_t max_vec = vtype::set1(*biggest); - - // first and last vtype::numlanes values are partitioned at the end - reg_t vec_left[num_unroll], vec_right[num_unroll]; - argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); - vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); - argvec_right[ii] = argtype::loadu( - arg + (right - vtype::numlanes * (num_unroll - ii))); - vec_right[ii] = vtype::i64gather( - arr, arg + (right - vtype::numlanes * (num_unroll - ii))); - } - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements - left += num_unroll * vtype::numlanes; - right -= num_unroll * vtype::numlanes; - while (right - left != 0) { - argreg_t arg_vec[num_unroll]; - reg_t curr_vec[num_unroll]; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= num_unroll * vtype::numlanes; - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] - = argtype::loadu(arg + right + ii * vtype::numlanes); - curr_vec[ii] = vtype::i64gather( - arr, arg + right + ii * vtype::numlanes); - } - } - else { - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); - curr_vec[ii] = vtype::i64gather( - arr, arg + left + ii * vtype::numlanes); - } - left += num_unroll * vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec[ii], - curr_vec[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - } - - /* partition and save vec_left and vec_right */ - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left[ii], - vec_left[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_right[ii], - vec_right[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -template 
-X86_SIMD_SORT_INLINE void -argsort_8_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - using reg_t = typename vtype::reg_t; - typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; - argreg_t argzmm = argtype::maskz_loadu(load_mask, arg); - reg_t arrzmm = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask, argzmm, arr); - arrzmm = sort_zmm_64bit(arrzmm, argzmm); - argtype::mask_storeu(arg, load_mask, argzmm); -} - -template -X86_SIMD_SORT_INLINE void -argsort_16_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - if (N <= 8) { - argsort_8_64bit(arr, arg, N); - return; - } - using reg_t = typename vtype::reg_t; - typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - argreg_t argzmm1 = argtype::loadu(arg); - argreg_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8); - reg_t arrzmm1 = vtype::i64gather(arr, arg); - reg_t arrzmm2 = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask, argzmm2, arr); - arrzmm1 = sort_zmm_64bit(arrzmm1, argzmm1); - arrzmm2 = sort_zmm_64bit(arrzmm2, argzmm2); - bitonic_merge_two_zmm_64bit( - arrzmm1, arrzmm2, argzmm1, argzmm2); - argtype::storeu(arg, argzmm1); - argtype::mask_storeu(arg + 8, load_mask, argzmm2); -} - -template -X86_SIMD_SORT_INLINE void -argsort_32_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - if (N <= 16) { - argsort_16_64bit(arr, arg, N); - return; - } - using reg_t = typename vtype::reg_t; - using opmask_t = typename vtype::opmask_t; - reg_t arrzmm[4]; - argreg_t argzmm[4]; - - X86_SIMD_SORT_UNROLL_LOOP(2) - for (int ii = 0; ii < 2; ++ii) { - argzmm[ii] = argtype::loadu(arg + 8 * ii); - arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii); - arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); - } - - uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; - opmask_t load_mask[2] = {0xFF, 0xFF}; - X86_SIMD_SORT_UNROLL_LOOP(2) - for (int ii = 0; ii < 2; ++ii) { - load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF; - argzmm[ii + 2] = argtype::maskz_loadu(load_mask[ii], arg + 16 + 8 * ii); - arrzmm[ii + 2] = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask[ii], argzmm[ii + 2], arr); - arrzmm[ii + 2] = sort_zmm_64bit(arrzmm[ii + 2], - argzmm[ii + 2]); - } - - bitonic_merge_two_zmm_64bit( - arrzmm[0], arrzmm[1], argzmm[0], argzmm[1]); - bitonic_merge_two_zmm_64bit( - arrzmm[2], arrzmm[3], argzmm[2], argzmm[3]); - bitonic_merge_four_zmm_64bit(arrzmm, argzmm); - - argtype::storeu(arg, argzmm[0]); - argtype::storeu(arg + 8, argzmm[1]); - argtype::mask_storeu(arg + 16, load_mask[0], argzmm[2]); - argtype::mask_storeu(arg + 24, load_mask[1], argzmm[3]); -} - -template -X86_SIMD_SORT_INLINE void -argsort_64_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - if (N <= 32) { - argsort_32_64bit(arr, arg, N); - return; - } - using reg_t = typename vtype::reg_t; - using opmask_t = typename vtype::opmask_t; - reg_t arrzmm[8]; - argreg_t argzmm[8]; - - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - argzmm[ii] = argtype::loadu(arg + 8 * ii); - arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii); - arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); - } - - opmask_t load_mask[4] = {0xFF, 0xFF, 0xFF, 0xFF}; - uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF; - argzmm[ii + 4] = argtype::maskz_loadu(load_mask[ii], arg + 32 + 8 * ii); - arrzmm[ii + 4] = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask[ii], argzmm[ii + 4], arr); - arrzmm[ii + 4] = 
sort_zmm_64bit(arrzmm[ii + 4], - argzmm[ii + 4]); - } - - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 8; ii = ii + 2) { - bitonic_merge_two_zmm_64bit( - arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); - } - bitonic_merge_four_zmm_64bit(arrzmm, argzmm); - bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); - bitonic_merge_eight_zmm_64bit(arrzmm, argzmm); - - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - argtype::storeu(arg + 8 * ii, argzmm[ii]); - } - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - argtype::mask_storeu(arg + 32 + 8 * ii, load_mask[ii], argzmm[ii + 4]); - } -} - -/* arsort 128 doesn't seem to make much of a difference to perf*/ -//template -//X86_SIMD_SORT_INLINE void -//argsort_128_64bit(type_t *arr, arrsize_t *arg, int32_t N) -//{ -// if (N <= 64) { -// argsort_64_64bit(arr, arg, N); -// return; -// } -// using reg_t = typename vtype::reg_t; -// using opmask_t = typename vtype::opmask_t; -// reg_t arrzmm[16]; -// argreg_t argzmm[16]; -// -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argzmm[ii] = argtype::loadu(arg + 8*ii); -// arrzmm[ii] = vtype::i64gather(argzmm[ii], arr); -// arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); -// } -// -// opmask_t load_mask[8] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; -// if (N != 128) { -// uarrsize_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// load_mask[ii] = (combined_mask >> (ii*8)) & 0xFF; -// } -// } -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argzmm[ii+8] = argtype::maskz_loadu(load_mask[ii], arg + 64 + 8*ii); -// arrzmm[ii+8] = vtype::template mask_i64gather(vtype::zmm_max(), load_mask[ii], argzmm[ii+8], arr); -// arrzmm[ii+8] = sort_zmm_64bit(arrzmm[ii+8], argzmm[ii+8]); -// } -// -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 16; ii = ii + 2) { -// bitonic_merge_two_zmm_64bit(arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); -// } -// bitonic_merge_four_zmm_64bit(arrzmm, argzmm); -// bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); -// bitonic_merge_four_zmm_64bit(arrzmm + 8, argzmm + 8); -// bitonic_merge_four_zmm_64bit(arrzmm + 12, argzmm + 12); -// bitonic_merge_eight_zmm_64bit(arrzmm, argzmm); -// bitonic_merge_eight_zmm_64bit(arrzmm+8, argzmm+8); -// bitonic_merge_sixteen_zmm_64bit(arrzmm, argzmm); -// -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argtype::storeu(arg + 8*ii, argzmm[ii]); -// } -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argtype::mask_storeu(arg + 64 + 8*ii, load_mask[ii], argzmm[ii + 8]); -// } -//} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, - arrsize_t *arg, - const arrsize_t left, - const arrsize_t right) -{ - if (right - left >= vtype::numlanes) { - // median of 8 - arrsize_t size = (right - left) / 8; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[arg[left + size]], - arr[arg[left + 2 * size]], - arr[arg[left + 3 * size]], - arr[arg[left + 4 * size]], - arr[arg[left + 5 * size]], - arr[arg[left + 6 * size]], - arr[arg[left + 7 * size]], - arr[arg[left + 8 * size]]); - // pivot will never be a nan, since there are no nan's! 
- reg_t sort = sort_zmm_64bit(rand_vec); - return ((type_t *)&sort)[4]; - } - else { - return arr[arg[right]]; - } -} - -template -X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - arrsize_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std_argsort(arr, arg, left, right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 64 - */ - if (right + 1 - left <= 64) { - argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); - return; - } - type_t pivot = get_pivot_64bit(arr, arg, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( - arr, arg, left, right + 1, pivot, &smallest, &biggest); - if (pivot != smallest) - argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) - argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); -} - -template -X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, - arrsize_t *arg, - arrsize_t pos, - arrsize_t left, - arrsize_t right, - arrsize_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std_argsort(arr, arg, left, right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 64 - */ - if (right + 1 - left <= 64) { - argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); - return; - } - type_t pivot = get_pivot_64bit(arr, arg, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( - arr, arg, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - argselect_64bit_( - arr, arg, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - argselect_64bit_( - arr, arg, pos, pivot_index, right, max_iters - 1); -} - -/* argsort methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) -{ - using vectype = typename std::conditional, - zmm_vector>::type; - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argsort_withnan(arr, arg, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argsort_64bit_( - arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx512_argsort(arr, indices.data(), arrsize, hasnan); - return indices; -} - -/* argselect methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, - arrsize_t *arg, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) -{ - using vectype = typename std::conditional, - zmm_vector>::type; - - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan(arr, arg, k, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argselect_64bit_( - arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) -{ - std::vector indices(arrsize); - 
std::iota(indices.begin(), indices.end(), 0); - avx512_argselect(arr, indices.data(), k, arrsize, hasnan); - return indices; -} - -/* To maintain compatibility with NumPy build */ -template -X86_SIMD_SORT_INLINE void -avx512_argselect(T *arr, int64_t *arg, arrsize_t k, arrsize_t arrsize) -{ - avx512_argselect(arr, reinterpret_cast(arg), k, arrsize); -} - -template -X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, int64_t *arg, arrsize_t arrsize) -{ - avx512_argsort(arr, reinterpret_cast(arg), arrsize); -} +#include "xss-common-argsort.h" #endif // AVX512_ARGSORT_64BIT diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 909f3b2b..65ee85db 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -8,6 +8,7 @@ #define AVX512_64BIT_COMMON #include "xss-common-includes.h" +#include "avx2-32bit-qsort.hpp" /* * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic @@ -32,6 +33,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -85,6 +87,10 @@ struct ymm_vector { { return ((0x1ull << num_to_read) - 0x1ull); } + static int32_t convert_mask_to_int(opmask_t mask) + { + return mask; + } template static opmask_t fpclass(reg_t x) { @@ -194,6 +200,19 @@ struct ymm_vector { { _mm256_storeu_ps((float *)mem, x); } + static reg_t cast_from(__m256i v) + { + return _mm256_castsi256_ps(v); + } + static __m256i cast_to(reg_t v) + { + return _mm256_castps_si256(v); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } }; template <> struct ymm_vector { @@ -202,6 +221,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -354,6 +374,19 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } }; template <> struct ymm_vector { @@ -362,6 +395,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -514,6 +548,19 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } }; template <> struct zmm_vector { @@ -529,6 +576,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; @@ -707,6 +755,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; @@ -877,6 +926,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static 
constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 55f79bb1..1f446c68 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -182,356 +182,6 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t1 *keys, return l_store; } -template -X86_SIMD_SORT_INLINE void -sort_8_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - typename vtype1::opmask_t load_mask = (0x01 << N) - 0x01; - typename vtype1::reg_t key_zmm - = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys); - - typename vtype2::reg_t index_zmm - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes); - vtype1::mask_storeu(keys, - load_mask, - sort_zmm_64bit(key_zmm, index_zmm)); - vtype2::mask_storeu(indexes, load_mask, index_zmm); -} - -template -X86_SIMD_SORT_INLINE void -sort_16_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 8) { - sort_8_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using index_type = typename vtype2::reg_t; - - typename vtype1::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - - reg_t key_zmm1 = vtype1::loadu(keys); - reg_t key_zmm2 = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys + 8); - - index_type index_zmm1 = vtype2::loadu(indexes); - index_type index_zmm2 - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes + 8); - - key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); - key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); - bitonic_merge_two_zmm_64bit( - key_zmm1, key_zmm2, index_zmm1, index_zmm2); - - vtype2::storeu(indexes, index_zmm1); - vtype2::mask_storeu(indexes + 8, load_mask, index_zmm2); - - vtype1::storeu(keys, key_zmm1); - vtype1::mask_storeu(keys + 8, load_mask, key_zmm2); -} - -template -X86_SIMD_SORT_INLINE void -sort_32_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 16) { - sort_16_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using opmask_t = typename vtype2::opmask_t; - using index_type = typename vtype2::reg_t; - reg_t key_zmm[4]; - index_type index_zmm[4]; - - key_zmm[0] = vtype1::loadu(keys); - key_zmm[1] = vtype1::loadu(keys + 8); - - index_zmm[0] = vtype2::loadu(indexes); - index_zmm[1] = vtype2::loadu(indexes + 8); - - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - key_zmm[2] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 16); - key_zmm[3] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 24); - - index_zmm[2] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 16); - index_zmm[3] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 24); - - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - - vtype2::storeu(indexes, index_zmm[0]); - vtype2::storeu(indexes + 8, index_zmm[1]); - vtype2::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); - vtype2::mask_storeu(indexes 
+ 24, load_mask2, index_zmm[3]); - - vtype1::storeu(keys, key_zmm[0]); - vtype1::storeu(keys + 8, key_zmm[1]); - vtype1::mask_storeu(keys + 16, load_mask1, key_zmm[2]); - vtype1::mask_storeu(keys + 24, load_mask2, key_zmm[3]); -} - -template -X86_SIMD_SORT_INLINE void -sort_64_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 32) { - sort_32_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using opmask_t = typename vtype1::opmask_t; - using index_type = typename vtype2::reg_t; - reg_t key_zmm[8]; - index_type index_zmm[8]; - - key_zmm[0] = vtype1::loadu(keys); - key_zmm[1] = vtype1::loadu(keys + 8); - key_zmm[2] = vtype1::loadu(keys + 16); - key_zmm[3] = vtype1::loadu(keys + 24); - - index_zmm[0] = vtype2::loadu(indexes); - index_zmm[1] = vtype2::loadu(indexes + 8); - index_zmm[2] = vtype2::loadu(indexes + 16); - index_zmm[3] = vtype2::loadu(indexes + 24); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - // N-32 >= 1 - uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - key_zmm[4] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 32); - key_zmm[5] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 40); - key_zmm[6] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 48); - key_zmm[7] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 56); - - index_zmm[4] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 32); - index_zmm[5] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 40); - index_zmm[6] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 48); - index_zmm[7] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 56); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( - key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( - key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - - vtype2::storeu(indexes, index_zmm[0]); - vtype2::storeu(indexes + 8, index_zmm[1]); - vtype2::storeu(indexes + 16, index_zmm[2]); - vtype2::storeu(indexes + 24, index_zmm[3]); - vtype2::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); - vtype2::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); - vtype2::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); - vtype2::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); - - vtype1::storeu(keys, key_zmm[0]); - vtype1::storeu(keys + 8, key_zmm[1]); - vtype1::storeu(keys + 16, key_zmm[2]); - vtype1::storeu(keys + 24, key_zmm[3]); - vtype1::mask_storeu(keys + 32, load_mask1, key_zmm[4]); - vtype1::mask_storeu(keys + 40, 
load_mask2, key_zmm[5]); - vtype1::mask_storeu(keys + 48, load_mask3, key_zmm[6]); - vtype1::mask_storeu(keys + 56, load_mask4, key_zmm[7]); -} - -template -X86_SIMD_SORT_INLINE void -sort_128_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 64) { - sort_64_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using index_type = typename vtype2::reg_t; - using opmask_t = typename vtype1::opmask_t; - reg_t key_zmm[16]; - index_type index_zmm[16]; - - key_zmm[0] = vtype1::loadu(keys); - key_zmm[1] = vtype1::loadu(keys + 8); - key_zmm[2] = vtype1::loadu(keys + 16); - key_zmm[3] = vtype1::loadu(keys + 24); - key_zmm[4] = vtype1::loadu(keys + 32); - key_zmm[5] = vtype1::loadu(keys + 40); - key_zmm[6] = vtype1::loadu(keys + 48); - key_zmm[7] = vtype1::loadu(keys + 56); - - index_zmm[0] = vtype2::loadu(indexes); - index_zmm[1] = vtype2::loadu(indexes + 8); - index_zmm[2] = vtype2::loadu(indexes + 16); - index_zmm[3] = vtype2::loadu(indexes + 24); - index_zmm[4] = vtype2::loadu(indexes + 32); - index_zmm[5] = vtype2::loadu(indexes + 40); - index_zmm[6] = vtype2::loadu(indexes + 48); - index_zmm[7] = vtype2::loadu(indexes + 56); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; - opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; - if (N != 128) { - uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - load_mask5 = (combined_mask >> 32) & 0xFF; - load_mask6 = (combined_mask >> 40) & 0xFF; - load_mask7 = (combined_mask >> 48) & 0xFF; - load_mask8 = (combined_mask >> 56) & 0xFF; - } - key_zmm[8] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 64); - key_zmm[9] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 72); - key_zmm[10] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 80); - key_zmm[11] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 88); - key_zmm[12] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask5, keys + 96); - key_zmm[13] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask6, keys + 104); - key_zmm[14] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask7, keys + 112); - key_zmm[15] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask8, keys + 120); - - index_zmm[8] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 64); - index_zmm[9] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 72); - index_zmm[10] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 80); - index_zmm[11] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 88); - index_zmm[12] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask5, indexes + 96); - index_zmm[13] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask6, indexes + 104); - index_zmm[14] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask7, indexes + 112); - index_zmm[15] - = vtype2::mask_loadu(vtype2::zmm_max(), 
load_mask8, indexes + 120); - key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); - key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); - key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); - key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); - key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); - key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); - key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); - key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); - - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( - key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( - key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_two_zmm_64bit( - key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); - bitonic_merge_two_zmm_64bit( - key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); - bitonic_merge_two_zmm_64bit( - key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); - bitonic_merge_two_zmm_64bit( - key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); - vtype2::storeu(indexes, index_zmm[0]); - vtype2::storeu(indexes + 8, index_zmm[1]); - vtype2::storeu(indexes + 16, index_zmm[2]); - vtype2::storeu(indexes + 24, index_zmm[3]); - vtype2::storeu(indexes + 32, index_zmm[4]); - vtype2::storeu(indexes + 40, index_zmm[5]); - vtype2::storeu(indexes + 48, index_zmm[6]); - vtype2::storeu(indexes + 56, index_zmm[7]); - vtype2::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); - vtype2::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); - vtype2::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); - vtype2::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); - vtype2::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); - vtype2::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); - vtype2::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); - vtype2::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); - - vtype1::storeu(keys, key_zmm[0]); - vtype1::storeu(keys + 8, key_zmm[1]); - vtype1::storeu(keys + 16, key_zmm[2]); - vtype1::storeu(keys + 24, key_zmm[3]); - vtype1::storeu(keys + 32, key_zmm[4]); - vtype1::storeu(keys + 40, key_zmm[5]); - vtype1::storeu(keys + 48, key_zmm[6]); - vtype1::storeu(keys + 56, key_zmm[7]); - vtype1::mask_storeu(keys + 64, load_mask1, key_zmm[8]); - vtype1::mask_storeu(keys + 72, load_mask2, key_zmm[9]); - vtype1::mask_storeu(keys + 80, load_mask3, key_zmm[10]); - vtype1::mask_storeu(keys + 88, load_mask4, key_zmm[11]); - vtype1::mask_storeu(keys + 96, load_mask5, key_zmm[12]); - vtype1::mask_storeu(keys + 104, load_mask6, key_zmm[13]); - vtype1::mask_storeu(keys + 112, load_mask7, key_zmm[14]); - vtype1::mask_storeu(keys + 120, load_mask8, key_zmm[15]); -} - template ( + kvsort_n( keys + left, indexes + left, (int32_t)(right + 1 - left)); return; } diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 21958027..f44209fa 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ 
b/src/avx512fp16-16bit-qsort.hpp @@ -23,6 +23,7 @@ struct zmm_vector<_Float16> { static const uint8_t numlanes = 32; static constexpr int network_sort_threshold = 128; static constexpr int partition_unroll_factor = 0; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h new file mode 100644 index 00000000..67aa2002 --- /dev/null +++ b/src/xss-common-argsort.h @@ -0,0 +1,705 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * ****************************************************************/ + +#ifndef XSS_COMMON_ARGSORT +#define XSS_COMMON_ARGSORT + +#include "xss-common-qsort.h" +#include "xss-network-keyvaluesort.hpp" +#include + +template +X86_SIMD_SORT_INLINE void std_argselect_withnan( + T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) +{ + std::nth_element(arg + left, + arg + k, + arg + right, + [arr](arrsize_t a, arrsize_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] < arr[b]; + } + else if (std::isnan(arr[a])) { + return false; + } + else { + return true; + } + }); +} + +/* argsort using std::sort */ +template +X86_SIMD_SORT_INLINE void +std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) +{ + std::sort(arg + left, + arg + right, + [arr](arrsize_t left, arrsize_t right) -> bool { + if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { + return arr[left] < arr[right]; + } + else if (std::isnan(arr[left])) { + return false; + } + else { + return true; + } + }); +} + +/* argsort using std::sort */ +template +X86_SIMD_SORT_INLINE void +std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) +{ + std::sort(arg + left, + arg + right, + [arr](arrsize_t left, arrsize_t right) -> bool { + // sort indices according to corresponding array element + return arr[left] < arr[right]; + }); +} + +/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of + * undefined template 'zmm_vector'*/ +#ifdef __APPLE__ +using argtypeAVX512 = + typename std::conditional, + zmm_vector>::type; +#else +using argtypeAVX512 = + typename std::conditional, + zmm_vector>::type; +#endif + +/* + * Parition one ZMM register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE int32_t partition_vec_avx512(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); + int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); + argtype::mask_compressstoreu( + arg + left, vtype::knot_opmask(gt_mask), arg_vec); + argtype::mask_compressstoreu( + arg + right - amount_gt_pivot, gt_mask, arg_vec); + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + return amount_gt_pivot; +} + +/* + * Parition one AVX2 register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. 
+ */ +template +X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec); + typename argtype::opmask_t ge_mask + = extend_mask(ge_mask_vtype); + + auto l_store = arg + left; + auto r_store = arg + right - vtype::numlanes; + + int amount_ge_pivot + = argtype::double_compressstore(l_store, r_store, ge_mask, arg_vec); + + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + + return amount_ge_pivot; +} + +template +X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + if constexpr (vtype::vec_type == simd_type::AVX512) { + return partition_vec_avx512(arg, + left, + right, + arg_vec, + curr_vec, + pivot_vec, + smallest_vec, + biggest_vec); + } + else if constexpr (vtype::vec_type == simd_type::AVX2) { + return partition_vec_avx2(arg, + left, + right, + arg_vec, + curr_vec, + pivot_vec, + smallest_vec, + biggest_vec); + } + else { + static_assert(sizeof(argreg_t) == 0, "Should not get here"); + } +} + +/* + * Parition an array based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + using argreg_t = typename argtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + argreg_t argvec = argtype::loadu(arg + left); + reg_t vec = vtype::i64gather(arr, arg + left); + int32_t amount_gt_pivot + = partition_vec(arg, + left, + left + vtype::numlanes, + argvec, + vec, + pivot_vec, + &min_vec, + &max_vec); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return left + (vtype::numlanes - amount_gt_pivot); + } + + // first and last vtype::numlanes values are partitioned at the end + argreg_t argvec_left = argtype::loadu(arg + left); + reg_t vec_left = vtype::i64gather(arr, arg + left); + argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); + reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec; + reg_t curr_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded 
from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= vtype::numlanes; + arg_vec = argtype::loadu(arg + right); + curr_vec = vtype::i64gather(arr, arg + right); + } + else { + arg_vec = argtype::loadu(arg + left); + curr_vec = vtype::i64gather(arr, arg + left); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec, + curr_vec, + pivot_vec, + &min_vec, + &max_vec); + ; + r_store -= amount_gt_pivot; + l_store += (vtype::numlanes - amount_gt_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left, + vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec(arg, + l_store, + l_store + vtype::numlanes, + argvec_right, + vec_right, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + if (right - left <= 8 * num_unroll * vtype::numlanes) { + return partition_avx512( + arr, arg, left, right, pivot, smallest, biggest); + } + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; + --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + using argreg_t = typename argtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + // first and last vtype::numlanes values are partitioned at the end + reg_t vec_left[num_unroll], vec_right[num_unroll]; + argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); + vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); + argvec_right[ii] = argtype::loadu( + arg + (right - vtype::numlanes * (num_unroll - ii))); + vec_right[ii] = vtype::i64gather( + arr, arg + (right - vtype::numlanes * (num_unroll - ii))); + } + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec[num_unroll]; + reg_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= num_unroll * vtype::numlanes; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + 
arg_vec[ii] + = argtype::loadu(arg + right + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + right + ii * vtype::numlanes); + } + } + else { + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + left + ii * vtype::numlanes); + } + left += num_unroll * vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec[ii], + curr_vec[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + } + + /* partition and save vec_left and vec_right */ + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left[ii], + vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_right[ii], + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, + arrsize_t *arg, + const arrsize_t left, + const arrsize_t right) +{ + if constexpr (vtype::numlanes == 8) { + if (right - left >= vtype::numlanes) { + // median of 8 + arrsize_t size = (right - left) / 8; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]], + arr[arg[left + 5 * size]], + arr[arg[left + 6 * size]], + arr[arg[left + 7 * size]], + arr[arg[left + 8 * size]]); + // pivot will never be a nan, since there are no nan's! + reg_t sort = sort_zmm_64bit(rand_vec); + return ((type_t *)&sort)[4]; + } + else { + return arr[arg[right]]; + } + } + else if constexpr (vtype::numlanes == 4) { + if (right - left >= vtype::numlanes) { + // median of 4 + arrsize_t size = (right - left) / 4; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]]); + // pivot will never be a nan, since there are no nan's! 
+ reg_t sort = vtype::sort_vec(rand_vec); + return ((type_t *)&sort)[2]; + } + else { + return arr[arg[right]]; + } + } +} + +template +X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + arrsize_t max_iters) +{ + using argtype = typename std::conditional, + zmm_vector>::type; + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 256) { + argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + arrsize_t pivot_index = partition_avx512_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); + if (pivot != smallest) + argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) + argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); +} + +template +X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, + arrsize_t *arg, + arrsize_t pos, + arrsize_t left, + arrsize_t right, + arrsize_t max_iters) +{ + using argtype = typename std::conditional, + zmm_vector>::type; + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 256) { + argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + arrsize_t pivot_index = partition_avx512_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); + if ((pivot != smallest) && (pos < pivot_index)) + argselect_64bit_( + arr, arg, pos, left, pivot_index - 1, max_iters - 1); + else if ((pivot != biggest) && (pos >= pivot_index)) + argselect_64bit_( + arr, arg, pos, pivot_index, right, max_iters - 1); +} + +/* argsort methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void +avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) +{ + using vectype = typename std::conditional, + zmm_vector>::type; + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argsort_withnan(arr, arg, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argsort_64bit_( + arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx512_argsort(arr, indices.data(), arrsize, hasnan); + return indices; +} + +/* argsort methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void +avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) +{ + using vectype = typename std::conditional, + avx2_vector>::type; + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argsort_withnan(arr, arg, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argsort_64bit_( + arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx2_argsort(T *arr, arrsize_t arrsize, bool 
hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx2_argsort(arr, indices.data(), arrsize, hasnan); + return indices; +} + +/* argselect methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + using vectype = typename std::conditional, + zmm_vector>::type; + + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argselect_64bit_( + arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx512_argselect(arr, indices.data(), k, arrsize, hasnan); + return indices; +} + +/* argselect methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + using vectype = typename std::conditional, + avx2_vector>::type; + + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argselect_64bit_( + arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx2_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx2_argselect(arr, indices.data(), k, arrsize, hasnan); + return indices; +} + +/* To maintain compatibility with NumPy build */ +template +X86_SIMD_SORT_INLINE void +avx512_argselect(T *arr, int64_t *arg, arrsize_t k, arrsize_t arrsize) +{ + avx512_argselect(arr, reinterpret_cast(arg), k, arrsize); +} + +template +X86_SIMD_SORT_INLINE void +avx512_argsort(T *arr, int64_t *arg, arrsize_t arrsize) +{ + avx512_argsort(arr, reinterpret_cast(arg), arrsize); +} + +#endif // XSS_COMMON_ARGSORT diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index c373ba54..9f793e37 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -82,4 +82,9 @@ struct ymm_vector; template struct avx2_vector; +template +struct avx2_half_vector; + +enum class simd_type : int { AVX2, AVX512 }; + #endif // XSS_COMMON_INCLUDES diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index e76d9f6a..097efceb 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -87,7 +87,8 @@ X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size) else { in = vtype::loadu(arr + ii); } - opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); + auto nanmask = vtype::convert_mask_to_int( + vtype::template fpclass<0x01 | 0x80>(in)); if (nanmask != 0x00) { found_nan = true; break; diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index cec1cb7a..334cb560 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -1,5 +1,34 @@ -#ifndef AVX512_KEYVALUE_NETWORKS -#define AVX512_KEYVALUE_NETWORKS +#ifndef XSS_KEYVALUE_NETWORKS +#define XSS_KEYVALUE_NETWORKS + +#include "xss-common-includes.h" + +template +struct index_64bit_vector_type; +template <> +struct 
index_64bit_vector_type<8> { + using type = zmm_vector; +}; +template <> +struct index_64bit_vector_type<4> { + using type = avx2_vector; +}; + +template +typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask) +{ + if constexpr (keyType::vec_type == simd_type::AVX512) { return mask; } + else if constexpr (keyType::vec_type == simd_type::AVX2) { + if constexpr (sizeof(mask) == 32) { return mask; } + else { + return _mm256_cvtepi32_epi64(mask); + } + } + else { + static_assert(keyType::vec_type == simd_type::AVX512, + "Should not reach here"); + } +} template (vtype1::eq(key_t1, key1)); + + reg_t2 index_t1 = vtype2::mask_mov(index2, eqMask, index1); + reg_t2 index_t2 = vtype2::mask_mov(index1, eqMask, index2); key1 = key_t1; key2 = key_t2; @@ -34,10 +63,24 @@ X86_SIMD_SORT_INLINE reg_t1 cmp_merge(reg_t1 in1, opmask_t mask) { reg_t1 tmp_keys = cmp_merge(in1, in2, mask); - indexes1 = vtype2::mask_mov(indexes2, vtype1::eq(tmp_keys, in1), indexes1); + indexes1 = vtype2::mask_mov( + indexes2, + extend_mask(vtype1::eq(tmp_keys, in1)), + indexes1); return tmp_keys; // 0 -> min, 1 -> max } +/* + * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ +// ZMM 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 + template +X86_SIMD_SORT_INLINE reg_t sort_ymm_64bit(reg_t key_zmm, index_type &index_zmm) +{ + using key_swizzle = typename vtype1::swizzle_ops; + using index_swizzle = typename vtype2::swizzle_ops; + + const typename vtype1::opmask_t oxAA = vtype1::seti(-1, 0, -1, 0); + const typename vtype1::opmask_t oxCC = vtype1::seti(-1, -1, 0, 0); + + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxAA); + key_zmm = cmp_merge(key_zmm, + vtype1::reverse(key_zmm), + index_zmm, + vtype2::reverse(index_zmm), + oxCC); + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxAA); + return key_zmm; +} + // Assumes zmm is bitonic and performs a recursive half cleaner template -X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(reg_t &key_zmm1, - reg_t &key_zmm2, - index_type &index_zmm1, - index_type &index_zmm2) +X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_64bit(reg_t key_zmm, + index_type &index_zmm) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - key_zmm2 = vtype1::permutexvar(rev_index1, key_zmm2); - index_zmm2 = vtype2::permutexvar(rev_index2, index_zmm2); + using key_swizzle = typename vtype1::swizzle_ops; + using index_swizzle = typename vtype2::swizzle_ops; + + const typename vtype1::opmask_t oxAA = vtype1::seti(-1, 0, -1, 0); + const typename vtype1::opmask_t oxCC = vtype1::seti(-1, -1, 0, 0); + + // 2) half_cleaner[4] + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxCC); + // 3) half_cleaner[1] + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxAA); + return key_zmm; +} + +template 
+X86_SIMD_SORT_INLINE void +bitonic_merge_dispatch(typename keyType::reg_t &key, + typename valueType::reg_t &value) +{ + constexpr int numlanes = keyType::numlanes; + if constexpr (numlanes == 8) { + key = bitonic_merge_zmm_64bit(key, value); + } + else if constexpr (numlanes == 4) { + key = bitonic_merge_ymm_64bit(key, value); + } + else { + static_assert(numlanes == -1, "should not reach here"); + UNUSED(key); + UNUSED(value); + } +} - reg_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); - reg_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); +template +X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, + typename valueType::reg_t &value) +{ + constexpr int numlanes = keyType::numlanes; + if constexpr (numlanes == 8) { + key = sort_zmm_64bit(key, value); + } + else if constexpr (numlanes == 4) { + key = sort_ymm_64bit(key, value); + } + else { + static_assert(numlanes == -1, "should not reach here"); + UNUSED(key); + UNUSED(value); + } +} - typename vtype1::opmask_t movmask = vtype1::eq(key_zmm3, key_zmm1); +template +X86_SIMD_SORT_INLINE void bitonic_clean_n_vec(typename keyType::reg_t *keys, + typename valueType::reg_t *values) +{ + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int num = numVecs / 2; num >= 2; num /= 2) { + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int j = 0; j < numVecs; j += num) { + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < num / 2; i++) { + arrsize_t index1 = i + j; + arrsize_t index2 = i + j + num / 2; + COEX(keys[index1], + keys[index2], + values[index1], + values[index2]); + } + } + } +} - index_type index_zmm3 = vtype2::mask_mov(index_zmm2, movmask, index_zmm1); - index_type index_zmm4 = vtype2::mask_mov(index_zmm1, movmask, index_zmm2); +template +X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(typename keyType::reg_t *keys, + typename valueType::reg_t *values) +{ + // Do the reverse part + if constexpr (numVecs == 2) { + keys[1] = keyType::reverse(keys[1]); + values[1] = valueType::reverse(values[1]); + COEX(keys[0], keys[1], values[0], values[1]); + keys[1] = keyType::reverse(keys[1]); + values[1] = valueType::reverse(values[1]); + } + else if constexpr (numVecs > 2) { + // Reverse upper half + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + keys[numVecs - i - 1] = keyType::reverse(keys[numVecs - i - 1]); + values[numVecs - i - 1] + = valueType::reverse(values[numVecs - i - 1]); + + COEX(keys[i], + keys[numVecs - i - 1], + values[i], + values[numVecs - i - 1]); + + keys[numVecs - i - 1] = keyType::reverse(keys[numVecs - i - 1]); + values[numVecs - i - 1] + = valueType::reverse(values[numVecs - i - 1]); + } + } + + // Call cleaner + bitonic_clean_n_vec(keys, values); + + // Now do bitonic_merge + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + bitonic_merge_dispatch(keys[i], values[i]); + } +} - /* need to reverse the lower registers to keep the correct order */ - key_zmm4 = vtype1::permutexvar(rev_index1, key_zmm4); - index_zmm4 = vtype2::permutexvar(rev_index2, index_zmm4); +template +X86_SIMD_SORT_INLINE void +bitonic_fullmerge_n_vec(typename keyType::reg_t *keys, + typename valueType::reg_t *values) +{ + if constexpr (numPer > numVecs) { + UNUSED(keys); + UNUSED(values); + return; + } + else { + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / numPer; i++) { + bitonic_merge_n_vec( + keys + i * numPer, values + i * numPer); + } + bitonic_fullmerge_n_vec( + keys, values); + } +} - // 2) Recursive half cleaner for each - key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - 
key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); - index_zmm1 = index_zmm3; - index_zmm2 = index_zmm4; +template +X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, + typename indexType::type_t *indices, + int N) +{ + using kreg_t = typename keyType::reg_t; + using ireg_t = typename indexType::reg_t; + + static_assert(numVecs > 0, "numVecs should be > 0"); + if constexpr (numVecs > 1) { + if (N * 2 <= numVecs * keyType::numlanes) { + argsort_n_vec(keys, indices, N); + return; + } + } + + kreg_t keyVecs[numVecs]; + ireg_t indexVecs[numVecs]; + + // Generate masks for loading and storing + typename keyType::opmask_t ioMasks[numVecs - numVecs / 2]; + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), + (uint64_t)keyType::numlanes); + ioMasks[j] = keyType::get_partial_loadmask(num_to_read); + } + + // Unmasked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + indexVecs[i] = indexType::loadu(indices + i * indexType::numlanes); + keyVecs[i] + = keyType::i64gather(keys, indices + i * indexType::numlanes); + } + // Masked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2; i < numVecs; i++) { + indexVecs[i] = indexType::mask_loadu( + indexType::zmm_max(), + extend_mask(ioMasks[i - numVecs / 2]), + indices + i * indexType::numlanes); + + keyVecs[i] = keyType::template mask_i64gather(keyType::zmm_max(), + ioMasks[i - numVecs / 2], + indexVecs[i], + keys); + } + + // Sort each loaded vector + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + sort_vec_dispatch(keyVecs[i], indexVecs[i]); + } + + // Run the full merger + bitonic_fullmerge_n_vec(keyVecs, indexVecs); + + // Unmasked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + indexType::storeu(indices + i * indexType::numlanes, indexVecs[i]); + } + // Masked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + indexType::mask_storeu( + indices + i * indexType::numlanes, + extend_mask(ioMasks[i - numVecs / 2]), + indexVecs[i]); + } } -// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive -// half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(reg_t *key_zmm, - index_type *index_zmm) + +template +X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, + typename valueType::type_t *values, + int N) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - // 1) First step of a merging network - reg_t key_zmm2r = vtype1::permutexvar(rev_index1, key_zmm[2]); - reg_t key_zmm3r = vtype1::permutexvar(rev_index1, key_zmm[3]); - index_type index_zmm2r = vtype2::permutexvar(rev_index2, index_zmm[2]); - index_type index_zmm3r = vtype2::permutexvar(rev_index2, index_zmm[3]); - - reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm3r); - reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm2r); - reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); - reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); - - typename vtype1::opmask_t movmask1 = vtype1::eq(key_reg_t1, key_zmm[0]); - typename vtype1::opmask_t movmask2 = vtype1::eq(key_reg_t2, key_zmm[1]); - - index_type index_reg_t1 - = vtype2::mask_mov(index_zmm3r, movmask1, index_zmm[0]); - index_type index_zmm_m1 - 
= vtype2::mask_mov(index_zmm[0], movmask1, index_zmm3r); - index_type index_reg_t2 - = vtype2::mask_mov(index_zmm2r, movmask2, index_zmm[1]); - index_type index_zmm_m2 - = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm2r); - - // 2) Recursive half clearer: 16 - reg_t key_reg_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_reg_t4 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_reg_t3 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_reg_t4 = vtype2::permutexvar(rev_index2, index_zmm_m1); - - reg_t key_zmm0 = vtype1::min(key_reg_t1, key_reg_t2); - reg_t key_zmm1 = vtype1::max(key_reg_t1, key_reg_t2); - reg_t key_zmm2 = vtype1::min(key_reg_t3, key_reg_t4); - reg_t key_zmm3 = vtype1::max(key_reg_t3, key_reg_t4); - - movmask1 = vtype1::eq(key_zmm0, key_reg_t1); - movmask2 = vtype1::eq(key_zmm2, key_reg_t3); - - index_type index_zmm0 - = vtype2::mask_mov(index_reg_t2, movmask1, index_reg_t1); - index_type index_zmm1 - = vtype2::mask_mov(index_reg_t1, movmask1, index_reg_t2); - index_type index_zmm2 - = vtype2::mask_mov(index_reg_t4, movmask2, index_reg_t3); - index_type index_zmm3 - = vtype2::mask_mov(index_reg_t3, movmask2, index_reg_t4); - - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - - index_zmm[0] = index_zmm0; - index_zmm[1] = index_zmm1; - index_zmm[2] = index_zmm2; - index_zmm[3] = index_zmm3; + using kreg_t = typename keyType::reg_t; + using vreg_t = typename valueType::reg_t; + + static_assert(numVecs > 0, "numVecs should be > 0"); + if constexpr (numVecs > 1) { + if (N * 2 <= numVecs * keyType::numlanes) { + kvsort_n_vec(keys, values, N); + return; + } + } + + kreg_t keyVecs[numVecs]; + vreg_t valueVecs[numVecs]; + + // Generate masks for loading and storing + typename keyType::opmask_t ioMasks[numVecs - numVecs / 2]; + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), + (uint64_t)keyType::numlanes); + ioMasks[j] = keyType::get_partial_loadmask(num_to_read); + } + + // Unmasked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + keyVecs[i] = keyType::loadu(keys + i * keyType::numlanes); + valueVecs[i] = valueType::loadu(values + i * valueType::numlanes); + } + // Masked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + keyVecs[i] = keyType::mask_loadu( + keyType::zmm_max(), ioMasks[j], keys + i * keyType::numlanes); + valueVecs[i] = valueType::mask_loadu(valueType::zmm_max(), + ioMasks[j], + values + i * valueType::numlanes); + } + + // Sort each loaded vector + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + sort_vec_dispatch(keyVecs[i], valueVecs[i]); + } + + // Run the full merger + bitonic_fullmerge_n_vec(keyVecs, valueVecs); + + // Unmasked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + keyType::storeu(keys + i * keyType::numlanes, keyVecs[i]); + valueType::storeu(values + i * valueType::numlanes, valueVecs[i]); + } + // Masked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + keyType::mask_storeu( + keys + i * keyType::numlanes, ioMasks[j], keyVecs[i]); + 
valueType::mask_storeu( + values + i * valueType::numlanes, ioMasks[j], valueVecs[i]); + } } -template -X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(reg_t *key_zmm, - index_type *index_zmm) +template +X86_SIMD_SORT_INLINE void +argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - reg_t key_zmm4r = vtype1::permutexvar(rev_index1, key_zmm[4]); - reg_t key_zmm5r = vtype1::permutexvar(rev_index1, key_zmm[5]); - reg_t key_zmm6r = vtype1::permutexvar(rev_index1, key_zmm[6]); - reg_t key_zmm7r = vtype1::permutexvar(rev_index1, key_zmm[7]); - index_type index_zmm4r = vtype2::permutexvar(rev_index2, index_zmm[4]); - index_type index_zmm5r = vtype2::permutexvar(rev_index2, index_zmm[5]); - index_type index_zmm6r = vtype2::permutexvar(rev_index2, index_zmm[6]); - index_type index_zmm7r = vtype2::permutexvar(rev_index2, index_zmm[7]); - - reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm7r); - reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm6r); - reg_t key_reg_t3 = vtype1::min(key_zmm[2], key_zmm5r); - reg_t key_reg_t4 = vtype1::min(key_zmm[3], key_zmm4r); - - reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); - reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); - reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); - reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); - - typename vtype1::opmask_t movmask1 = vtype1::eq(key_reg_t1, key_zmm[0]); - typename vtype1::opmask_t movmask2 = vtype1::eq(key_reg_t2, key_zmm[1]); - typename vtype1::opmask_t movmask3 = vtype1::eq(key_reg_t3, key_zmm[2]); - typename vtype1::opmask_t movmask4 = vtype1::eq(key_reg_t4, key_zmm[3]); - - index_type index_reg_t1 - = vtype2::mask_mov(index_zmm7r, movmask1, index_zmm[0]); - index_type index_zmm_m1 - = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm7r); - index_type index_reg_t2 - = vtype2::mask_mov(index_zmm6r, movmask2, index_zmm[1]); - index_type index_zmm_m2 - = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm6r); - index_type index_reg_t3 - = vtype2::mask_mov(index_zmm5r, movmask3, index_zmm[2]); - index_type index_zmm_m3 - = vtype2::mask_mov(index_zmm[2], movmask3, index_zmm5r); - index_type index_reg_t4 - = vtype2::mask_mov(index_zmm4r, movmask4, index_zmm[3]); - index_type index_zmm_m4 - = vtype2::mask_mov(index_zmm[3], movmask4, index_zmm4r); - - reg_t key_reg_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4); - reg_t key_reg_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3); - reg_t key_reg_t7 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_reg_t8 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_reg_t5 = vtype2::permutexvar(rev_index2, index_zmm_m4); - index_type index_reg_t6 = vtype2::permutexvar(rev_index2, index_zmm_m3); - index_type index_reg_t7 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_reg_t8 = vtype2::permutexvar(rev_index2, index_zmm_m1); - - COEX(key_reg_t1, key_reg_t3, index_reg_t1, index_reg_t3); - COEX(key_reg_t2, key_reg_t4, index_reg_t2, index_reg_t4); - COEX(key_reg_t5, key_reg_t7, index_reg_t5, index_reg_t7); - COEX(key_reg_t6, key_reg_t8, index_reg_t6, index_reg_t8); - COEX(key_reg_t1, key_reg_t2, index_reg_t1, index_reg_t2); - COEX(key_reg_t3, key_reg_t4, index_reg_t3, index_reg_t4); - COEX(key_reg_t5, key_reg_t6, index_reg_t5, index_reg_t6); - COEX(key_reg_t7, key_reg_t8, index_reg_t7, index_reg_t8); - key_zmm[0] - = 
bitonic_merge_zmm_64bit(key_reg_t1, index_reg_t1); - key_zmm[1] - = bitonic_merge_zmm_64bit(key_reg_t2, index_reg_t2); - key_zmm[2] - = bitonic_merge_zmm_64bit(key_reg_t3, index_reg_t3); - key_zmm[3] - = bitonic_merge_zmm_64bit(key_reg_t4, index_reg_t4); - key_zmm[4] - = bitonic_merge_zmm_64bit(key_reg_t5, index_reg_t5); - key_zmm[5] - = bitonic_merge_zmm_64bit(key_reg_t6, index_reg_t6); - key_zmm[6] - = bitonic_merge_zmm_64bit(key_reg_t7, index_reg_t7); - key_zmm[7] - = bitonic_merge_zmm_64bit(key_reg_t8, index_reg_t8); - - index_zmm[0] = index_reg_t1; - index_zmm[1] = index_reg_t2; - index_zmm[2] = index_reg_t3; - index_zmm[3] = index_reg_t4; - index_zmm[4] = index_reg_t5; - index_zmm[5] = index_reg_t6; - index_zmm[6] = index_reg_t7; - index_zmm[7] = index_reg_t8; + using indexType = typename index_64bit_vector_type::type; + + static_assert(keyType::numlanes == indexType::numlanes, + "invalid pairing of value/index types"); + constexpr int numVecs = maxN / keyType::numlanes; + constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); + constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); + static_assert(powerOfTwo == true && isMultiple == true, + "maxN must be keyType::numlanes times a power of 2"); + + argsort_n_vec(keys, indices, N); } -template -X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(reg_t *key_zmm, - index_type *index_zmm) +template +X86_SIMD_SORT_INLINE void kvsort_n(typename keyType::type_t *keys, + typename valueType::type_t *values, + int N) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - reg_t key_zmm8r = vtype1::permutexvar(rev_index1, key_zmm[8]); - reg_t key_zmm9r = vtype1::permutexvar(rev_index1, key_zmm[9]); - reg_t key_zmm10r = vtype1::permutexvar(rev_index1, key_zmm[10]); - reg_t key_zmm11r = vtype1::permutexvar(rev_index1, key_zmm[11]); - reg_t key_zmm12r = vtype1::permutexvar(rev_index1, key_zmm[12]); - reg_t key_zmm13r = vtype1::permutexvar(rev_index1, key_zmm[13]); - reg_t key_zmm14r = vtype1::permutexvar(rev_index1, key_zmm[14]); - reg_t key_zmm15r = vtype1::permutexvar(rev_index1, key_zmm[15]); - - index_type index_zmm8r = vtype2::permutexvar(rev_index2, index_zmm[8]); - index_type index_zmm9r = vtype2::permutexvar(rev_index2, index_zmm[9]); - index_type index_zmm10r = vtype2::permutexvar(rev_index2, index_zmm[10]); - index_type index_zmm11r = vtype2::permutexvar(rev_index2, index_zmm[11]); - index_type index_zmm12r = vtype2::permutexvar(rev_index2, index_zmm[12]); - index_type index_zmm13r = vtype2::permutexvar(rev_index2, index_zmm[13]); - index_type index_zmm14r = vtype2::permutexvar(rev_index2, index_zmm[14]); - index_type index_zmm15r = vtype2::permutexvar(rev_index2, index_zmm[15]); - - reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm15r); - reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm14r); - reg_t key_reg_t3 = vtype1::min(key_zmm[2], key_zmm13r); - reg_t key_reg_t4 = vtype1::min(key_zmm[3], key_zmm12r); - reg_t key_reg_t5 = vtype1::min(key_zmm[4], key_zmm11r); - reg_t key_reg_t6 = vtype1::min(key_zmm[5], key_zmm10r); - reg_t key_reg_t7 = vtype1::min(key_zmm[6], key_zmm9r); - reg_t key_reg_t8 = vtype1::min(key_zmm[7], key_zmm8r); - - reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm15r); - reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm14r); - reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm13r); - reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm12r); - reg_t key_zmm_m5 = 
vtype1::max(key_zmm[4], key_zmm11r); - reg_t key_zmm_m6 = vtype1::max(key_zmm[5], key_zmm10r); - reg_t key_zmm_m7 = vtype1::max(key_zmm[6], key_zmm9r); - reg_t key_zmm_m8 = vtype1::max(key_zmm[7], key_zmm8r); - - index_type index_reg_t1 = vtype2::mask_mov( - index_zmm15r, vtype1::eq(key_reg_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = vtype2::mask_mov( - index_zmm[0], vtype1::eq(key_reg_t1, key_zmm[0]), index_zmm15r); - index_type index_reg_t2 = vtype2::mask_mov( - index_zmm14r, vtype1::eq(key_reg_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = vtype2::mask_mov( - index_zmm[1], vtype1::eq(key_reg_t2, key_zmm[1]), index_zmm14r); - index_type index_reg_t3 = vtype2::mask_mov( - index_zmm13r, vtype1::eq(key_reg_t3, key_zmm[2]), index_zmm[2]); - index_type index_zmm_m3 = vtype2::mask_mov( - index_zmm[2], vtype1::eq(key_reg_t3, key_zmm[2]), index_zmm13r); - index_type index_reg_t4 = vtype2::mask_mov( - index_zmm12r, vtype1::eq(key_reg_t4, key_zmm[3]), index_zmm[3]); - index_type index_zmm_m4 = vtype2::mask_mov( - index_zmm[3], vtype1::eq(key_reg_t4, key_zmm[3]), index_zmm12r); - - index_type index_reg_t5 = vtype2::mask_mov( - index_zmm11r, vtype1::eq(key_reg_t5, key_zmm[4]), index_zmm[4]); - index_type index_zmm_m5 = vtype2::mask_mov( - index_zmm[4], vtype1::eq(key_reg_t5, key_zmm[4]), index_zmm11r); - index_type index_reg_t6 = vtype2::mask_mov( - index_zmm10r, vtype1::eq(key_reg_t6, key_zmm[5]), index_zmm[5]); - index_type index_zmm_m6 = vtype2::mask_mov( - index_zmm[5], vtype1::eq(key_reg_t6, key_zmm[5]), index_zmm10r); - index_type index_reg_t7 = vtype2::mask_mov( - index_zmm9r, vtype1::eq(key_reg_t7, key_zmm[6]), index_zmm[6]); - index_type index_zmm_m7 = vtype2::mask_mov( - index_zmm[6], vtype1::eq(key_reg_t7, key_zmm[6]), index_zmm9r); - index_type index_reg_t8 = vtype2::mask_mov( - index_zmm8r, vtype1::eq(key_reg_t8, key_zmm[7]), index_zmm[7]); - index_type index_zmm_m8 = vtype2::mask_mov( - index_zmm[7], vtype1::eq(key_reg_t8, key_zmm[7]), index_zmm8r); - - reg_t key_reg_t9 = vtype1::permutexvar(rev_index1, key_zmm_m8); - reg_t key_reg_t10 = vtype1::permutexvar(rev_index1, key_zmm_m7); - reg_t key_reg_t11 = vtype1::permutexvar(rev_index1, key_zmm_m6); - reg_t key_reg_t12 = vtype1::permutexvar(rev_index1, key_zmm_m5); - reg_t key_reg_t13 = vtype1::permutexvar(rev_index1, key_zmm_m4); - reg_t key_reg_t14 = vtype1::permutexvar(rev_index1, key_zmm_m3); - reg_t key_reg_t15 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_reg_t16 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_reg_t9 = vtype2::permutexvar(rev_index2, index_zmm_m8); - index_type index_reg_t10 = vtype2::permutexvar(rev_index2, index_zmm_m7); - index_type index_reg_t11 = vtype2::permutexvar(rev_index2, index_zmm_m6); - index_type index_reg_t12 = vtype2::permutexvar(rev_index2, index_zmm_m5); - index_type index_reg_t13 = vtype2::permutexvar(rev_index2, index_zmm_m4); - index_type index_reg_t14 = vtype2::permutexvar(rev_index2, index_zmm_m3); - index_type index_reg_t15 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_reg_t16 = vtype2::permutexvar(rev_index2, index_zmm_m1); - - COEX(key_reg_t1, key_reg_t5, index_reg_t1, index_reg_t5); - COEX(key_reg_t2, key_reg_t6, index_reg_t2, index_reg_t6); - COEX(key_reg_t3, key_reg_t7, index_reg_t3, index_reg_t7); - COEX(key_reg_t4, key_reg_t8, index_reg_t4, index_reg_t8); - COEX(key_reg_t9, key_reg_t13, index_reg_t9, index_reg_t13); - COEX( - key_reg_t10, key_reg_t14, index_reg_t10, index_reg_t14); - COEX( - key_reg_t11, 
key_reg_t15, index_reg_t11, index_reg_t15); - COEX( - key_reg_t12, key_reg_t16, index_reg_t12, index_reg_t16); - - COEX(key_reg_t1, key_reg_t3, index_reg_t1, index_reg_t3); - COEX(key_reg_t2, key_reg_t4, index_reg_t2, index_reg_t4); - COEX(key_reg_t5, key_reg_t7, index_reg_t5, index_reg_t7); - COEX(key_reg_t6, key_reg_t8, index_reg_t6, index_reg_t8); - COEX(key_reg_t9, key_reg_t11, index_reg_t9, index_reg_t11); - COEX( - key_reg_t10, key_reg_t12, index_reg_t10, index_reg_t12); - COEX( - key_reg_t13, key_reg_t15, index_reg_t13, index_reg_t15); - COEX( - key_reg_t14, key_reg_t16, index_reg_t14, index_reg_t16); - - COEX(key_reg_t1, key_reg_t2, index_reg_t1, index_reg_t2); - COEX(key_reg_t3, key_reg_t4, index_reg_t3, index_reg_t4); - COEX(key_reg_t5, key_reg_t6, index_reg_t5, index_reg_t6); - COEX(key_reg_t7, key_reg_t8, index_reg_t7, index_reg_t8); - COEX(key_reg_t9, key_reg_t10, index_reg_t9, index_reg_t10); - COEX( - key_reg_t11, key_reg_t12, index_reg_t11, index_reg_t12); - COEX( - key_reg_t13, key_reg_t14, index_reg_t13, index_reg_t14); - COEX( - key_reg_t15, key_reg_t16, index_reg_t15, index_reg_t16); - // - key_zmm[0] - = bitonic_merge_zmm_64bit(key_reg_t1, index_reg_t1); - key_zmm[1] - = bitonic_merge_zmm_64bit(key_reg_t2, index_reg_t2); - key_zmm[2] - = bitonic_merge_zmm_64bit(key_reg_t3, index_reg_t3); - key_zmm[3] - = bitonic_merge_zmm_64bit(key_reg_t4, index_reg_t4); - key_zmm[4] - = bitonic_merge_zmm_64bit(key_reg_t5, index_reg_t5); - key_zmm[5] - = bitonic_merge_zmm_64bit(key_reg_t6, index_reg_t6); - key_zmm[6] - = bitonic_merge_zmm_64bit(key_reg_t7, index_reg_t7); - key_zmm[7] - = bitonic_merge_zmm_64bit(key_reg_t8, index_reg_t8); - key_zmm[8] - = bitonic_merge_zmm_64bit(key_reg_t9, index_reg_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_reg_t10, - index_reg_t10); - key_zmm[10] = bitonic_merge_zmm_64bit(key_reg_t11, - index_reg_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_reg_t12, - index_reg_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_reg_t13, - index_reg_t13); - key_zmm[13] = bitonic_merge_zmm_64bit(key_reg_t14, - index_reg_t14); - key_zmm[14] = bitonic_merge_zmm_64bit(key_reg_t15, - index_reg_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_reg_t16, - index_reg_t16); - - index_zmm[0] = index_reg_t1; - index_zmm[1] = index_reg_t2; - index_zmm[2] = index_reg_t3; - index_zmm[3] = index_reg_t4; - index_zmm[4] = index_reg_t5; - index_zmm[5] = index_reg_t6; - index_zmm[6] = index_reg_t7; - index_zmm[7] = index_reg_t8; - index_zmm[8] = index_reg_t9; - index_zmm[9] = index_reg_t10; - index_zmm[10] = index_reg_t11; - index_zmm[11] = index_reg_t12; - index_zmm[12] = index_reg_t13; - index_zmm[13] = index_reg_t14; - index_zmm[14] = index_reg_t15; - index_zmm[15] = index_reg_t16; + static_assert(keyType::numlanes == valueType::numlanes, + "invalid pairing of key/value types"); + + constexpr int numVecs = maxN / keyType::numlanes; + constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); + constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); + static_assert(powerOfTwo == true && isMultiple == true, + "maxN must be keyType::numlanes times a power of 2"); + + kvsort_n_vec(keys, values, N); } -#endif // AVX512_KEYVALUE_NETWORKS + +#endif \ No newline at end of file
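// ---------------------------------------------------------------------------
// Editor's note (not part of the patch): a minimal, self-contained scalar
// sketch of the NaN-aware fallback comparator that std_argsort_withnan in
// xss-common-argsort.h relies on, assuming only the C++ standard library.
// Indices of NaN elements never compare "less than", so they sink to the end
// of the argsort order.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

static std::vector<std::size_t> argsort_nan_last(const double *arr,
                                                 std::size_t n)
{
    std::vector<std::size_t> idx(n);
    std::iota(idx.begin(), idx.end(), 0);
    std::sort(idx.begin(), idx.end(), [arr](std::size_t a, std::size_t b) {
        if (!std::isnan(arr[a]) && !std::isnan(arr[b])) {
            return arr[a] < arr[b]; // both comparable: order by value
        }
        return !std::isnan(arr[a]); // a precedes b only if a is not the NaN
    });
    return idx;
}

int main()
{
    const double data[] = {3.0, NAN, 1.0, 2.0};
    for (std::size_t i : argsort_nan_last(data, 4)) {
        std::printf("%zu ", i); // prints: 2 3 0 1 (the NaN's index comes last)
    }
    std::printf("\n");
    return 0;
}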
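// ---------------------------------------------------------------------------
// Editor's note (not part of the patch): a scalar model of what a single
// partition_vec call does to one block of numlanes indices, assuming the
// caller manages the l_store / r_store cursors exactly as partition_avx512
// does. partition_block_scalar is a hypothetical helper for illustration;
// `block` must be a copy of the indices (mirroring the register load), since
// the stores may overwrite that region of `arg`.
#include <cstddef>

static int partition_block_scalar(const double *keys,
                                  std::size_t *arg,          // index array
                                  std::size_t left,          // left store
                                  std::size_t right,         // one past right store
                                  const std::size_t *block,  // numlanes indices
                                  int numlanes,
                                  double pivot)
{
    int amount_ge_pivot = 0;
    for (int i = 0; i < numlanes; ++i) {
        if (keys[block[i]] >= pivot) { ++amount_ge_pivot; }
    }
    // keys <  pivot: compressed to the left store, keeping relative order.
    // keys >= pivot: compressed just below the right store, keeping order.
    std::size_t lo = left, hi = right - amount_ge_pivot;
    for (int i = 0; i < numlanes; ++i) {
        if (keys[block[i]] < pivot) { arg[lo++] = block[i]; }
        else { arg[hi++] = block[i]; }
    }
    // Caller then does: l_store += numlanes - amount; r_store -= amount.
    return amount_ge_pivot;
}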
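// ---------------------------------------------------------------------------
// Editor's note (not part of the patch): get_pivot_64bit samples 8 (or 4 on
// AVX2) equally spaced keys through the index array and takes their median as
// the pivot. A scalar sketch of the same idea, with std::nth_element standing
// in for the in-register sorting network; as in the original it assumes
// right - left >= samples and no NaNs among the sampled keys.
#include <algorithm>
#include <array>
#include <cstddef>

template <typename T>
static T median_of_samples_pivot(const T *arr, const std::size_t *arg,
                                 std::size_t left, std::size_t right)
{
    constexpr std::size_t samples = 8;
    const std::size_t stride = (right - left) / samples;
    std::array<T, samples> s;
    for (std::size_t i = 0; i < samples; ++i) {
        s[i] = arr[arg[left + (i + 1) * stride]]; // keys at left + k*stride
    }
    std::nth_element(s.begin(), s.begin() + samples / 2, s.end());
    return s[samples / 2]; // median of the sampled keys
}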
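// ---------------------------------------------------------------------------
// Editor's note (not part of the patch): example usage of the vector-returning
// overloads added in xss-common-argsort.h, assuming the x86-simd-sort sources
// are on the include path and the target supports AVX2. A sketch, not a test.
#include "xss-common-argsort.h"
#include <vector>

void argsort_demo(double *data, arrsize_t n)
{
    // Indices that order `data` ascending; with hasnan=true, arrays containing
    // NaNs fall back to std::sort with NaN indices ordered last.
    std::vector<arrsize_t> order = avx2_argsort(data, n, /*hasnan=*/true);

    // Partial ordering: afterwards, element k of `part` is the index of the
    // k-th smallest key, with indices of smaller keys placed before it.
    std::vector<arrsize_t> part = avx2_argselect(data, /*k=*/n / 2, n);

    (void)order;
    (void)part;
}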