diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..669f6732 --- /dev/null +++ b/.gitignore @@ -0,0 +1,35 @@ +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +**/.vscode + diff --git a/Makefile b/Makefile index 938dbe5b..07c7818d 100644 --- a/Makefile +++ b/Makefile @@ -24,4 +24,4 @@ bench: $(BENCHDIR)/main.cpp $(SRCS) $(CXX) $(BENCHDIR)/main.cpp $(CXXFLAGS) -march=icelake-client -O3 -o benchexe clean: - rm -f $(TESTDIR)/*.o testexe benchexe + rm -f $(TESTDIR)/*.o testexe benchexe \ No newline at end of file diff --git a/benchmarks/bench-tgl.out b/benchmarks/bench-tgl.out index a1e0bcf2..1bb03936 100644 --- a/benchmarks/bench-tgl.out +++ b/benchmarks/bench-tgl.out @@ -25,4 +25,4 @@ | uniform random | int16_t | 10000 | 84703 | 1547726 | 18.3 | | uniform random | int16_t | 100000 | 1442726 | 19705242 | 13.7 | | uniform random | int16_t | 1000000 | 20210224 | 212137465 | 10.5 | -|-----------------+-------------+------------+-----------------+-----------+----------| +|-----------------+-------------+------------+-----------------+-----------+----------| \ No newline at end of file diff --git a/benchmarks/bench.hpp b/benchmarks/bench.hpp index 0837a6ca..d54f61ac 100644 --- a/benchmarks/bench.hpp +++ b/benchmarks/bench.hpp @@ -5,12 +5,19 @@ #include "avx512-16bit-qsort.hpp" #include "avx512-32bit-qsort.hpp" +#include "avx512-64bit-keyvaluesort.hpp" #include "avx512-64bit-qsort.hpp" #include #include #include #include +template +struct sorted_t { + K key; + V value; +}; + static inline uint64_t cycles_start(void) { unsigned a, d; @@ -72,3 +79,50 @@ std::tuple bench_sort(const std::vector arr, / lastfew; return std::make_tuple(avx_sort, std_sort); } + +template +std::tuple +bench_sort_kv(const std::vector keys, + const std::vector values, + const std::vector> sortedaar, + const uint64_t iters, + const uint64_t lastfew) +{ + + std::vector keys_bckup = keys; + std::vector values_bckup = values; + std::vector> sortedaar_bckup = sortedaar; + + std::vector runtimes1, runtimes2; + uint64_t start(0), end(0); + for (uint64_t ii = 0; ii < iters; ++ii) { + start = cycles_start(); + avx512_qsort_kv( + keys_bckup.data(), values_bckup.data(), keys_bckup.size()); + end = cycles_end(); + runtimes1.emplace_back(end - start); + keys_bckup = keys; + values_bckup = values; + } + uint64_t avx_sort = std::accumulate(runtimes1.end() - lastfew, + runtimes1.end(), + (uint64_t)0) + / lastfew; + + for (uint64_t ii = 0; ii < iters; ++ii) { + start = cycles_start(); + std::sort(sortedaar_bckup.begin(), + sortedaar_bckup.end(), + [](sorted_t a, sorted_t b) { + return a.key < b.key; + }); + end = cycles_end(); + runtimes2.emplace_back(end - start); + sortedaar_bckup = sortedaar; + } + uint64_t std_sort = std::accumulate(runtimes2.end() - lastfew, + runtimes2.end(), + (uint64_t)0) + / lastfew; + return std::make_tuple(avx_sort, std_sort); +} diff --git a/benchmarks/main.cpp b/benchmarks/main.cpp index 5340b881..b8cf95bb 100644 --- a/benchmarks/main.cpp +++ b/benchmarks/main.cpp @@ -22,7 +22,7 @@ template +void run_bench_kv(const std::string datatype) +{ + std::streamsize ss = std::cout.precision(); + std::cout << std::fixed; + std::cout << std::setprecision(1); + std::vector array_sizes = {10000, 100000, 1000000}; + for (auto size : array_sizes) { + 
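        // For each size: build keys in the requested distribution, attach
        // uniform-random values, and mirror both into sorted_t<K, V> records so the
        // std::sort baseline inside bench_sort_kv() sorts exactly the same data.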
std::vector keys; + std::vector values; + std::vector> sortedarr; + + if (datatype.find("kv_uniform") != std::string::npos) { + keys = get_uniform_rand_array(size); + } + else if (datatype.find("kv_reverse") != std::string::npos) { + for (int ii = 0; ii < size; ++ii) { + //arr.emplace_back((T)(size - ii)); + keys.emplace_back((K)(size - ii)); + } + } + else if (datatype.find("kv_ordered") != std::string::npos) { + for (int ii = 0; ii < size; ++ii) { + keys.emplace_back((ii)); + } + } + else if (datatype.find("kv_limited") != std::string::npos) { + keys = get_uniform_rand_array(size, (K)10, (K)0); + } + else { + std::cout << "Skipping unrecognized array type: " << datatype + << std::endl; + return; + } + values = get_uniform_rand_array(size); + for (size_t i = 0; i < keys.size(); i++) { + sorted_t tmp_s; + tmp_s.key = keys[i]; + tmp_s.value = values[i]; + sortedarr.emplace_back(tmp_s); + } + + auto out = bench_sort_kv(keys, values, sortedarr, 20, 10); + printLine(' ', + datatype, + typeid(K).name(), + sizeof(K), + size, + std::get<0>(out), + std::get<1>(out), + (float)std::get<1>(out) / std::get<0>(out)); + } + std::cout << std::setprecision(ss); +} void bench_all(const std::string datatype) { if (cpu_has_avx512bw()) { @@ -97,7 +151,15 @@ void bench_all(const std::string datatype) } } } +void bench_all_kv(const std::string datatype) +{ + if (cpu_has_avx512bw()) { + run_bench_kv(datatype); + run_bench_kv(datatype); + run_bench_kv(datatype); + } +} int main(/*int argc, char *argv[]*/) { printLine(' ', @@ -113,6 +175,11 @@ int main(/*int argc, char *argv[]*/) bench_all("reverse"); bench_all("ordered"); bench_all("limitedrange"); + + bench_all_kv("kv_uniform random"); + bench_all_kv("kv_reverse"); + bench_all_kv("kv_ordered"); + bench_all_kv("kv_limitedrange"); printLine('-', "", "", "", "", "", "", ""); return 0; } diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h new file mode 100644 index 00000000..32a4731e --- /dev/null +++ b/src/avx512-64bit-common.h @@ -0,0 +1,404 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * ****************************************************************/ + +#ifndef AVX512_64BIT_COMMOM +#define AVX512_64BIT_COMMOM +#include "avx512-common-qsort.h" + +#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 + +template <> +struct zmm_vector { + using type_t = int64_t; + using zmm_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT64; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT64; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi64(type_max()); + } // TODO: this should broadcast bits as is? 
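// Illustrative sketch, not part of the patch: the zmm_vector<> trait exposes the
// AVX-512 intrinsics behind a uniform static interface, which is what keeps the
// sorting networks below type-agnostic. A typical pattern is to pad a partial load
// with type_max() so the unused lanes cannot win a min-reduction. partial_min8 is a
// hypothetical helper name; assumes an AVX-512 capable target and this header.
template <typename vtype, typename type_t = typename vtype::type_t>
type_t partial_min8(const type_t *src, int n) // requires 1 <= n <= 8
{
    typename vtype::opmask_t load_mask = (0x01 << n) - 0x01;
    typename vtype::zmm_t v = vtype::mask_loadu(vtype::zmm_max(), load_mask, src);
    return vtype::reducemin(v);
}
// e.g. partial_min8<zmm_vector<int64_t>>(arr, 5) ignores arr[5..7] entirely.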
+ + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi64(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epi64(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epi64(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epi64(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epi64(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi64(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; +template <> +struct zmm_vector { + using type_t = uint64_t; + using zmm_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT64; + } + static type_t type_min() + { + return 0; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi64(type_max()); + } + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi64(index, base, scale); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epu64(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t 
min(zmm_t x, zmm_t y) + { + return _mm512_min_epu64(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epu64(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epu64(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi64(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; +template <> +struct zmm_vector { + using type_t = double; + using zmm_t = __m512d; + using ymm_t = __m512d; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITY; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITY; + } + static zmm_t zmm_max() + { + return _mm512_set1_pd(type_max()); + } + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_pd(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_pd(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_pd(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_pd(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_pd(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_pd(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_pd(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_pd(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_pd(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_pd(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_pd(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_pd(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_pd(mem, x); + } +}; +X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize) +{ + int64_t nan_count = 0; + __mmask8 loadmask = 0xFF; + while (arrsize > 0) { + if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } + __m512d in_zmm = _mm512_maskz_loadu_pd(loadmask, arr); + __mmask8 nanmask = _mm512_cmp_pd_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); + nan_count += _mm_popcnt_u32((int32_t)nanmask); + _mm512_mask_storeu_pd(arr, nanmask, ZMM_MAX_DOUBLE); + arr += 8; + arrsize -= 8; + } + return nan_count; +} + +X86_SIMD_SORT_INLINE void +replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) +{ + for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { + arr[ii] = std::nan("1"); + 
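// Scalar sketch, illustrative only, of the NaN strategy used by replace_nan_with_inf()
// above and this loop: NaNs are parked as +inf before sorting so every comparison
// stays total, and afterwards the last nan_count slots are rewritten as NaN.
// scalar_replace_nan_with_inf is a hypothetical name; assumes <cmath> and <limits>.
static int64_t scalar_replace_nan_with_inf(double *arr, int64_t size)
{
    int64_t nan_count = 0;
    for (int64_t ii = 0; ii < size; ++ii) {
        if (std::isnan(arr[ii])) {
            arr[ii] = std::numeric_limits<double>::infinity();
            nan_count += 1;
        }
    }
    return nan_count; // feed this into replace_inf_with_nan() after the sort
}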
nan_count -= 1; + } +} +/* + * Assumes zmm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + zmm = cmp_merge( + zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), zmm), + 0xCC); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + zmm = cmp_merge(zmm, vtype::permutexvar(rev_index, zmm), 0xF0); + zmm = cmp_merge( + zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), + 0xCC); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + return zmm; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, + const int64_t left, + const int64_t right) +{ + // median of 8 + int64_t size = (right - left) / 8; + using zmm_t = typename vtype::zmm_t; + __m512i rand_index = _mm512_set_epi64(left + size, + left + 2 * size, + left + 3 * size, + left + 4 * size, + left + 5 * size, + left + 6 * size, + left + 7 * size, + left + 8 * size); + zmm_t rand_vec = vtype::template i64gather(rand_index, arr); + // pivot will never be a nan, since there are no nan's! + zmm_t sort = sort_zmm_64bit(rand_vec); + return ((type_t *)&sort)[4]; +} + +#endif \ No newline at end of file diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp new file mode 100644 index 00000000..8140be97 --- /dev/null +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -0,0 +1,880 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Liu Zhuan + * Tang Xi + * ****************************************************************/ + +#ifndef AVX512_QSORT_64BIT_KV +#define AVX512_QSORT_64BIT_KV + +#include "avx512-common-keyvaluesort.h" + +template ::zmm_t> +X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + key_zmm = cmp_merge( + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), + 0xAA); + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), + index_zmm, + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), + index_zmm), + 0xCC); + key_zmm = cmp_merge( + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), + 0xAA); + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(rev_index, key_zmm), + index_zmm, + zmm_vector::permutexvar(rev_index, index_zmm), + 0xF0); + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + index_zmm, + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), + index_zmm), + 0xCC); + key_zmm = cmp_merge( + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), + 0xAA); + return key_zmm; +} +// Assumes zmm is bitonic and performs a recursive half cleaner +template ::zmm_t> +X86_SIMD_SORT_INLINE zmm_t +bitonic_merge_zmm_64bit(zmm_t key_zmm, zmm_vector::zmm_t &index_zmm) +{ + + // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), + index_zmm, + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), + index_zmm), + 
0xF0); + // 2) half_cleaner[4] + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + index_zmm, + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), + index_zmm), + 0xCC); + // 3) half_cleaner[1] + key_zmm = cmp_merge( + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), + 0xAA); + return key_zmm; +} +// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, + zmm_t &key_zmm2, + index_type &index_zmm1, + index_type &index_zmm2) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + // 1) First step of a merging network: coex of zmm1 and zmm2 reversed + key_zmm2 = vtype::permutexvar(rev_index, key_zmm2); + index_zmm2 = zmm_vector::permutexvar(rev_index, index_zmm2); + + zmm_t key_zmm3 = vtype::min(key_zmm1, key_zmm2); + zmm_t key_zmm4 = vtype::max(key_zmm1, key_zmm2); + + index_type index_zmm3 = zmm_vector::mask_mov( + index_zmm2, vtype::eq(key_zmm3, key_zmm1), index_zmm1); + index_type index_zmm4 = zmm_vector::mask_mov( + index_zmm1, vtype::eq(key_zmm3, key_zmm1), index_zmm2); + + // 2) Recursive half cleaner for each + key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); + index_zmm1 = index_zmm3; + index_zmm2 = index_zmm4; +} +// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive +// half cleaner +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, + index_type *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + // 1) First step of a merging network + zmm_t key_zmm2r = vtype::permutexvar(rev_index, key_zmm[2]); + zmm_t key_zmm3r = vtype::permutexvar(rev_index, key_zmm[3]); + index_type index_zmm2r + = zmm_vector::permutexvar(rev_index, index_zmm[2]); + index_type index_zmm3r + = zmm_vector::permutexvar(rev_index, index_zmm[3]); + + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm3r); + zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm2r); + zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm3r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); + + index_type index_zmm_t1 = zmm_vector::mask_mov( + index_zmm3r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = zmm_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); + index_type index_zmm_t2 = zmm_vector::mask_mov( + index_zmm2r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = zmm_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); + + // 2) Recursive half clearer: 16 + zmm_t key_zmm_t3 = vtype::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t4 = vtype::permutexvar(rev_index, key_zmm_m1); + index_type index_zmm_t3 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t4 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); + + zmm_t key_zmm0 = vtype::min(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm1 = vtype::max(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm2 = vtype::min(key_zmm_t3, key_zmm_t4); + zmm_t key_zmm3 = vtype::max(key_zmm_t3, key_zmm_t4); + + index_type index_zmm0 = zmm_vector::mask_mov( + index_zmm_t2, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t1); + index_type index_zmm1 = zmm_vector::mask_mov( + index_zmm_t1, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t2); + index_type index_zmm2 = zmm_vector::mask_mov( + 
index_zmm_t4, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t3); + index_type index_zmm3 = zmm_vector::mask_mov( + index_zmm_t3, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t4); + + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + + index_zmm[0] = index_zmm0; + index_zmm[1] = index_zmm1; + index_zmm[2] = index_zmm2; + index_zmm[3] = index_zmm3; +} +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, + index_type *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t key_zmm4r = vtype::permutexvar(rev_index, key_zmm[4]); + zmm_t key_zmm5r = vtype::permutexvar(rev_index, key_zmm[5]); + zmm_t key_zmm6r = vtype::permutexvar(rev_index, key_zmm[6]); + zmm_t key_zmm7r = vtype::permutexvar(rev_index, key_zmm[7]); + index_type index_zmm4r + = zmm_vector::permutexvar(rev_index, index_zmm[4]); + index_type index_zmm5r + = zmm_vector::permutexvar(rev_index, index_zmm[5]); + index_type index_zmm6r + = zmm_vector::permutexvar(rev_index, index_zmm[6]); + index_type index_zmm7r + = zmm_vector::permutexvar(rev_index, index_zmm[7]); + + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm7r); + zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm6r); + zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm5r); + zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm4r); + + zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm7r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm6r); + zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm5r); + zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); + + index_type index_zmm_t1 = zmm_vector::mask_mov( + index_zmm7r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = zmm_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); + index_type index_zmm_t2 = zmm_vector::mask_mov( + index_zmm6r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = zmm_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); + index_type index_zmm_t3 = zmm_vector::mask_mov( + index_zmm5r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_type index_zmm_m3 = zmm_vector::mask_mov( + index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); + index_type index_zmm_t4 = zmm_vector::mask_mov( + index_zmm4r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_type index_zmm_m4 = zmm_vector::mask_mov( + index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); + + zmm_t key_zmm_t5 = vtype::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t6 = vtype::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t7 = vtype::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t8 = vtype::permutexvar(rev_index, key_zmm_m1); + index_type index_zmm_t5 + = zmm_vector::permutexvar(rev_index, index_zmm_m4); + index_type index_zmm_t6 + = zmm_vector::permutexvar(rev_index, index_zmm_m3); + index_type index_zmm_t7 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t8 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + 
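// Illustrative scalar analogue, not part of the patch, of the key/value COEX used in
// this ladder: put the two keys in (min, max) order and swap the payloads only when
// the keys swapped, which is what the vtype::eq() mask passed to mask_mov() selects
// lane by lane. scalar_coex_kv is a hypothetical name; assumes <utility> for std::swap.
template <typename K, typename V>
static void scalar_coex_kv(K &key1, K &key2, V &val1, V &val2)
{
    if (key2 < key1) {
        std::swap(key1, key2);
        std::swap(val1, val2);
    }
}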
COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + + index_zmm[0] = index_zmm_t1; + index_zmm[1] = index_zmm_t2; + index_zmm[2] = index_zmm_t3; + index_zmm[3] = index_zmm_t4; + index_zmm[4] = index_zmm_t5; + index_zmm[5] = index_zmm_t6; + index_zmm[6] = index_zmm_t7; + index_zmm[7] = index_zmm_t8; +} +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, + index_type *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t key_zmm8r = vtype::permutexvar(rev_index, key_zmm[8]); + zmm_t key_zmm9r = vtype::permutexvar(rev_index, key_zmm[9]); + zmm_t key_zmm10r = vtype::permutexvar(rev_index, key_zmm[10]); + zmm_t key_zmm11r = vtype::permutexvar(rev_index, key_zmm[11]); + zmm_t key_zmm12r = vtype::permutexvar(rev_index, key_zmm[12]); + zmm_t key_zmm13r = vtype::permutexvar(rev_index, key_zmm[13]); + zmm_t key_zmm14r = vtype::permutexvar(rev_index, key_zmm[14]); + zmm_t key_zmm15r = vtype::permutexvar(rev_index, key_zmm[15]); + + index_type index_zmm8r + = zmm_vector::permutexvar(rev_index, index_zmm[8]); + index_type index_zmm9r + = zmm_vector::permutexvar(rev_index, index_zmm[9]); + index_type index_zmm10r + = zmm_vector::permutexvar(rev_index, index_zmm[10]); + index_type index_zmm11r + = zmm_vector::permutexvar(rev_index, index_zmm[11]); + index_type index_zmm12r + = zmm_vector::permutexvar(rev_index, index_zmm[12]); + index_type index_zmm13r + = zmm_vector::permutexvar(rev_index, index_zmm[13]); + index_type index_zmm14r + = zmm_vector::permutexvar(rev_index, index_zmm[14]); + index_type index_zmm15r + = zmm_vector::permutexvar(rev_index, index_zmm[15]); + + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm15r); + zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm14r); + zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm13r); + zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm12r); + zmm_t key_zmm_t5 = vtype::min(key_zmm[4], key_zmm11r); + zmm_t key_zmm_t6 = vtype::min(key_zmm[5], key_zmm10r); + zmm_t key_zmm_t7 = vtype::min(key_zmm[6], key_zmm9r); + zmm_t key_zmm_t8 = vtype::min(key_zmm[7], key_zmm8r); + + zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm15r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm14r); + zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm13r); + zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm12r); + zmm_t key_zmm_m5 = vtype::max(key_zmm[4], key_zmm11r); + zmm_t key_zmm_m6 = vtype::max(key_zmm[5], key_zmm10r); + zmm_t key_zmm_m7 = vtype::max(key_zmm[6], key_zmm9r); + zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); + + index_type index_zmm_t1 = zmm_vector::mask_mov( + index_zmm15r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = zmm_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); + index_type index_zmm_t2 = zmm_vector::mask_mov( + index_zmm14r, vtype::eq(key_zmm_t2, key_zmm[1]), 
index_zmm[1]); + index_type index_zmm_m2 = zmm_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); + index_type index_zmm_t3 = zmm_vector::mask_mov( + index_zmm13r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_type index_zmm_m3 = zmm_vector::mask_mov( + index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); + index_type index_zmm_t4 = zmm_vector::mask_mov( + index_zmm12r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_type index_zmm_m4 = zmm_vector::mask_mov( + index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); + + index_type index_zmm_t5 = zmm_vector::mask_mov( + index_zmm11r, vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); + index_type index_zmm_m5 = zmm_vector::mask_mov( + index_zmm[4], vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); + index_type index_zmm_t6 = zmm_vector::mask_mov( + index_zmm10r, vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); + index_type index_zmm_m6 = zmm_vector::mask_mov( + index_zmm[5], vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); + index_type index_zmm_t7 = zmm_vector::mask_mov( + index_zmm9r, vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); + index_type index_zmm_m7 = zmm_vector::mask_mov( + index_zmm[6], vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); + index_type index_zmm_t8 = zmm_vector::mask_mov( + index_zmm8r, vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); + index_type index_zmm_m8 = zmm_vector::mask_mov( + index_zmm[7], vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); + + zmm_t key_zmm_t9 = vtype::permutexvar(rev_index, key_zmm_m8); + zmm_t key_zmm_t10 = vtype::permutexvar(rev_index, key_zmm_m7); + zmm_t key_zmm_t11 = vtype::permutexvar(rev_index, key_zmm_m6); + zmm_t key_zmm_t12 = vtype::permutexvar(rev_index, key_zmm_m5); + zmm_t key_zmm_t13 = vtype::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t14 = vtype::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t15 = vtype::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t16 = vtype::permutexvar(rev_index, key_zmm_m1); + index_type index_zmm_t9 + = zmm_vector::permutexvar(rev_index, index_zmm_m8); + index_type index_zmm_t10 + = zmm_vector::permutexvar(rev_index, index_zmm_m7); + index_type index_zmm_t11 + = zmm_vector::permutexvar(rev_index, index_zmm_m6); + index_type index_zmm_t12 + = zmm_vector::permutexvar(rev_index, index_zmm_m5); + index_type index_zmm_t13 + = zmm_vector::permutexvar(rev_index, index_zmm_m4); + index_type index_zmm_t14 + = zmm_vector::permutexvar(rev_index, index_zmm_m3); + index_type index_zmm_t15 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t16 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); + COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); + COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); + COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); + COEX(key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); + COEX(key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); + COEX(key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); + COEX(key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); + 
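// Scalar sketch, illustrative only, of the "recursive half cleaner" these COEX
// ladders implement across registers: on a bitonic range of power-of-two length n,
// compare element i with element i + n/2, then recurse into both halves; the result
// is fully sorted, and key and payload always move together. scalar_bitonic_merge is
// a hypothetical name; assumes <utility> for std::swap.
template <typename K, typename V>
static void scalar_bitonic_merge(K *keys, V *vals, int n) // n must be a power of 2
{
    if (n < 2) return;
    for (int i = 0; i < n / 2; ++i) {
        if (keys[i + n / 2] < keys[i]) {
            std::swap(keys[i], keys[i + n / 2]);
            std::swap(vals[i], vals[i + n / 2]);
        }
    }
    scalar_bitonic_merge(keys, vals, n / 2);
    scalar_bitonic_merge(keys + n / 2, vals + n / 2, n / 2);
}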
COEX(key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); + COEX(key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); + COEX(key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); + COEX(key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); + // + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, index_zmm_t10); + key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, index_zmm_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, index_zmm_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, index_zmm_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, index_zmm_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, index_zmm_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, index_zmm_t16); + + index_zmm[0] = index_zmm_t1; + index_zmm[1] = index_zmm_t2; + index_zmm[2] = index_zmm_t3; + index_zmm[3] = index_zmm_t4; + index_zmm[4] = index_zmm_t5; + index_zmm[5] = index_zmm_t6; + index_zmm[6] = index_zmm_t7; + index_zmm[7] = index_zmm_t8; + index_zmm[8] = index_zmm_t9; + index_zmm[9] = index_zmm_t10; + index_zmm[10] = index_zmm_t11; + index_zmm[11] = index_zmm_t12; + index_zmm[12] = index_zmm_t13; + index_zmm[13] = index_zmm_t14; + index_zmm[14] = index_zmm_t15; + index_zmm[15] = index_zmm_t16; +} +template +X86_SIMD_SORT_INLINE void +sort_8_64bit(type_t *keys, uint64_t *indexes, int32_t N) +{ + typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; + typename vtype::zmm_t key_zmm + = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys); + + zmm_vector::zmm_t index_zmm = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask, indexes); + vtype::mask_storeu( + keys, load_mask, sort_zmm_64bit(key_zmm, index_zmm)); + zmm_vector::mask_storeu(indexes, load_mask, index_zmm); +} + +template +X86_SIMD_SORT_INLINE void +sort_16_64bit(type_t *keys, uint64_t *indexes, int32_t N) +{ + if (N <= 8) { + sort_8_64bit(keys, indexes, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using index_type = zmm_vector::zmm_t; + + typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; + + zmm_t key_zmm1 = vtype::loadu(keys); + zmm_t key_zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys + 8); + + index_type index_zmm1 = zmm_vector::loadu(indexes); + index_type index_zmm2 = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask, indexes + 8); + + key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); + key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); + bitonic_merge_two_zmm_64bit( + key_zmm1, key_zmm2, index_zmm1, index_zmm2); + + zmm_vector::storeu(indexes, index_zmm1); + zmm_vector::mask_storeu(indexes + 
8, load_mask, index_zmm2); + + vtype::storeu(keys, key_zmm1); + vtype::mask_storeu(keys + 8, load_mask, key_zmm2); +} + +template +X86_SIMD_SORT_INLINE void +sort_32_64bit(type_t *keys, uint64_t *indexes, int32_t N) +{ + if (N <= 16) { + sort_16_64bit(keys, indexes, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + using index_type = zmm_vector::zmm_t; + zmm_t key_zmm[4]; + index_type index_zmm[4]; + + key_zmm[0] = vtype::loadu(keys); + key_zmm[1] = vtype::loadu(keys + 8); + + index_zmm[0] = zmm_vector::loadu(indexes); + index_zmm[1] = zmm_vector::loadu(indexes + 8); + + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + key_zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 16); + key_zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 24); + + index_zmm[2] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 16); + index_zmm[3] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 24); + + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + + zmm_vector::storeu(indexes, index_zmm[0]); + zmm_vector::storeu(indexes + 8, index_zmm[1]); + zmm_vector::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); + zmm_vector::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); + + vtype::storeu(keys, key_zmm[0]); + vtype::storeu(keys + 8, key_zmm[1]); + vtype::mask_storeu(keys + 16, load_mask1, key_zmm[2]); + vtype::mask_storeu(keys + 24, load_mask2, key_zmm[3]); +} + +template +X86_SIMD_SORT_INLINE void +sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) +{ + if (N <= 32) { + sort_32_64bit(keys, indexes, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + using index_type = zmm_vector::zmm_t; + zmm_t key_zmm[8]; + index_type index_zmm[8]; + + key_zmm[0] = vtype::loadu(keys); + key_zmm[1] = vtype::loadu(keys + 8); + key_zmm[2] = vtype::loadu(keys + 16); + key_zmm[3] = vtype::loadu(keys + 24); + + index_zmm[0] = zmm_vector::loadu(indexes); + index_zmm[1] = zmm_vector::loadu(indexes + 8); + index_zmm[2] = zmm_vector::loadu(indexes + 16); + index_zmm[3] = zmm_vector::loadu(indexes + 24); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + // N-32 >= 1 + uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + key_zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 32); + key_zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 40); + key_zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 48); + key_zmm[7] = 
vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 56); + + index_zmm[4] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 32); + index_zmm[5] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 40); + index_zmm[6] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask3, indexes + 48); + index_zmm[7] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask4, indexes + 56); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit( + key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + + zmm_vector::storeu(indexes, index_zmm[0]); + zmm_vector::storeu(indexes + 8, index_zmm[1]); + zmm_vector::storeu(indexes + 16, index_zmm[2]); + zmm_vector::storeu(indexes + 24, index_zmm[3]); + zmm_vector::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); + zmm_vector::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); + zmm_vector::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); + zmm_vector::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); + + vtype::storeu(keys, key_zmm[0]); + vtype::storeu(keys + 8, key_zmm[1]); + vtype::storeu(keys + 16, key_zmm[2]); + vtype::storeu(keys + 24, key_zmm[3]); + vtype::mask_storeu(keys + 32, load_mask1, key_zmm[4]); + vtype::mask_storeu(keys + 40, load_mask2, key_zmm[5]); + vtype::mask_storeu(keys + 48, load_mask3, key_zmm[6]); + vtype::mask_storeu(keys + 56, load_mask4, key_zmm[7]); +} + +template +X86_SIMD_SORT_INLINE void +sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) +{ + if (N <= 64) { + sort_64_64bit(keys, indexes, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using index_type = zmm_vector::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t key_zmm[16]; + index_type index_zmm[16]; + + key_zmm[0] = vtype::loadu(keys); + key_zmm[1] = vtype::loadu(keys + 8); + key_zmm[2] = vtype::loadu(keys + 16); + key_zmm[3] = vtype::loadu(keys + 24); + key_zmm[4] = vtype::loadu(keys + 32); + key_zmm[5] = vtype::loadu(keys + 40); + key_zmm[6] = vtype::loadu(keys + 48); + key_zmm[7] = vtype::loadu(keys + 56); + + index_zmm[0] = zmm_vector::loadu(indexes); + index_zmm[1] = zmm_vector::loadu(indexes + 8); + index_zmm[2] = zmm_vector::loadu(indexes + 16); + index_zmm[3] = zmm_vector::loadu(indexes + 24); + index_zmm[4] = zmm_vector::loadu(indexes + 32); + index_zmm[5] = zmm_vector::loadu(indexes + 40); + index_zmm[6] = zmm_vector::loadu(indexes + 48); + index_zmm[7] = zmm_vector::loadu(indexes + 56); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + + opmask_t 
load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; + opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; + if (N != 128) { + uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + load_mask5 = (combined_mask >> 32) & 0xFF; + load_mask6 = (combined_mask >> 40) & 0xFF; + load_mask7 = (combined_mask >> 48) & 0xFF; + load_mask8 = (combined_mask >> 56) & 0xFF; + } + key_zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 64); + key_zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 72); + key_zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 80); + key_zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 88); + key_zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, keys + 96); + key_zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, keys + 104); + key_zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, keys + 112); + key_zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, keys + 120); + + index_zmm[8] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 64); + index_zmm[9] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 72); + index_zmm[10] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask3, indexes + 80); + index_zmm[11] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask4, indexes + 88); + index_zmm[12] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask5, indexes + 96); + index_zmm[13] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask6, indexes + 104); + index_zmm[14] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask7, indexes + 112); + index_zmm[15] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask8, indexes + 120); + key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); + key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); + key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); + key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); + key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); + key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); + key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); + key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit( + key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); + bitonic_merge_two_zmm_64bit( + key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); + bitonic_merge_two_zmm_64bit( + key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); + bitonic_merge_two_zmm_64bit( + key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); + bitonic_merge_two_zmm_64bit( + key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_sixteen_zmm_64bit(key_zmm, 
index_zmm); + zmm_vector::storeu(indexes, index_zmm[0]); + zmm_vector::storeu(indexes + 8, index_zmm[1]); + zmm_vector::storeu(indexes + 16, index_zmm[2]); + zmm_vector::storeu(indexes + 24, index_zmm[3]); + zmm_vector::storeu(indexes + 32, index_zmm[4]); + zmm_vector::storeu(indexes + 40, index_zmm[5]); + zmm_vector::storeu(indexes + 48, index_zmm[6]); + zmm_vector::storeu(indexes + 56, index_zmm[7]); + zmm_vector::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); + zmm_vector::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); + zmm_vector::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); + zmm_vector::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); + zmm_vector::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); + zmm_vector::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); + zmm_vector::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); + zmm_vector::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); + + vtype::storeu(keys, key_zmm[0]); + vtype::storeu(keys + 8, key_zmm[1]); + vtype::storeu(keys + 16, key_zmm[2]); + vtype::storeu(keys + 24, key_zmm[3]); + vtype::storeu(keys + 32, key_zmm[4]); + vtype::storeu(keys + 40, key_zmm[5]); + vtype::storeu(keys + 48, key_zmm[6]); + vtype::storeu(keys + 56, key_zmm[7]); + vtype::mask_storeu(keys + 64, load_mask1, key_zmm[8]); + vtype::mask_storeu(keys + 72, load_mask2, key_zmm[9]); + vtype::mask_storeu(keys + 80, load_mask3, key_zmm[10]); + vtype::mask_storeu(keys + 88, load_mask4, key_zmm[11]); + vtype::mask_storeu(keys + 96, load_mask5, key_zmm[12]); + vtype::mask_storeu(keys + 104, load_mask6, key_zmm[13]); + vtype::mask_storeu(keys + 112, load_mask7, key_zmm[14]); + vtype::mask_storeu(keys + 120, load_mask8, key_zmm[15]); +} + +template +void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) +{ + int64_t i = idx; + while (true) { + int64_t j = 2 * i + 1; + if (j >= size || j < 0) { break; } + int k = j + 1; + if (k < size && keys[j] < keys[k]) { j = k; } + if (keys[j] < keys[i]) { break; } + std::swap(keys[i], keys[j]); + std::swap(indexes[i], indexes[j]); + i = j; + } +} +template +void heap_sort(type_t *keys, uint64_t *indexes, int64_t size) +{ + for (int64_t i = size / 2 - 1; i >= 0; i--) { + heapify(keys, indexes, i, size); + } + for (int64_t i = size - 1; i > 0; i--) { + std::swap(keys[0], keys[i]); + std::swap(indexes[0], indexes[i]); + heapify(keys, indexes, 0, i); + } +} + +template +struct sortkv_t { + T key; + uint64_t value; +}; +template +void qsort_64bit_(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + int64_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + //std::sort(keys+left,keys+right+1); + heap_sort(keys + left, indexes + left, right - left + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 128 + */ + if (right + 1 - left <= 128) { + + sort_128_64bit( + keys + left, indexes + left, (int32_t)(right + 1 - left)); + return; + } + + type_t pivot = get_pivot_64bit(keys, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512( + keys, indexes, left, right + 1, pivot, &smallest, &biggest); + if (pivot != smallest) { + qsort_64bit_( + keys, indexes, left, pivot_index - 1, max_iters - 1); + } + if (pivot != biggest) { + qsort_64bit_(keys, indexes, pivot_index, right, max_iters - 1); + } +} + +template <> +void avx512_qsort_kv(int64_t *keys, uint64_t *indexes, int64_t arrsize) +{ + if (arrsize > 
1) { + qsort_64bit_, int64_t>( + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_qsort_kv(uint64_t *keys, + uint64_t *indexes, + int64_t arrsize) +{ + if (arrsize > 1) { + qsort_64bit_, uint64_t>( + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize) +{ + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf(keys, arrsize); + qsort_64bit_, double>( + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(keys, arrsize, nan_count); + } +} +#endif // AVX512_QSORT_64BIT_KV diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 7e8db546..62000549 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -7,7 +7,7 @@ #ifndef AVX512_QSORT_64BIT #define AVX512_QSORT_64BIT -#include "avx512-common-qsort.h" +#include "avx512-64bit-common.h" /* * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic @@ -15,341 +15,6 @@ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) */ // ZMM 7, 6, 5, 4, 3, 2, 1, 0 -#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 -#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 -#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 -#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 - -template <> -struct zmm_vector { - using type_t = int64_t; - using zmm_t = __m512i; - using ymm_t = __m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT64; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT64; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi64(type_max()); - } // TODO: this should broadcast bits as is? 
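// Usage sketch, illustrative only, for the avx512_qsort_kv() entry points defined in
// avx512-64bit-keyvaluesort.hpp above: keys are sorted in place and the same
// permutation is applied to the uint64_t payload array. kv_demo is a hypothetical
// name; assumes an AVX-512 capable CPU (the benchmark builds with -march=icelake-client).
#include <cstdint>
#include <vector>
// #include "avx512-64bit-keyvaluesort.hpp"
void kv_demo()
{
    std::vector<double> keys = {3.5, 1.25, 2.0, 0.5};
    std::vector<uint64_t> payload = {0, 1, 2, 3}; // e.g. original positions
    avx512_qsort_kv<double>(keys.data(), payload.data(), (int64_t)keys.size());
    // keys    -> {0.5, 1.25, 2.0, 3.5}
    // payload -> {3, 1, 2, 0}
}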
- - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); - } - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi64(index, base, scale); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epi64(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epi64(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epi64(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epi64(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi64(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; -template <> -struct zmm_vector { - using type_t = uint64_t; - using zmm_t = __m512i; - using ymm_t = __m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_UINT64; - } - static type_t type_min() - { - return 0; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi64(type_max()); - } - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi64(index, base, scale); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epu64(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epu64(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - 
return _mm512_reduce_max_epu64(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epu64(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi64(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; -template <> -struct zmm_vector { - using type_t = double; - using zmm_t = __m512d; - using ymm_t = __m512d; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_INFINITY; - } - static type_t type_min() - { - return -X86_SIMD_SORT_INFINITY; - } - static zmm_t zmm_max() - { - return _mm512_set1_pd(type_max()); - } - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); - } - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_pd(index, base, scale); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_pd(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_pd(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_pd(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_pd(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_pd(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_pd(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_pd(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_pd(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_pd(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_pd(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_pd(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_pd(mem, x); - } -}; - -/* - * Assumes zmm is random and performs a full sorting network defined in - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg - */ -template -X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - zmm = cmp_merge( - zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), zmm), - 0xCC); - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - zmm = cmp_merge(zmm, vtype::permutexvar(rev_index, zmm), 0xF0); - zmm = cmp_merge( - zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), - 0xCC); - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - return zmm; -} // Assumes zmm is bitonic and performs a recursive half cleaner template @@ -371,7 +36,6 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) zmm, vtype::template shuffle(zmm), 0xAA); return zmm; } - // Assumes zmm1 and zmm2 are sorted 
and performs a recursive half cleaner template X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1, zmm_t &zmm2) @@ -385,7 +49,6 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1, zmm_t &zmm2) zmm1 = bitonic_merge_zmm_64bit(zmm3); zmm2 = bitonic_merge_zmm_64bit(zmm4); } - // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive // half cleaner template @@ -409,7 +72,6 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *zmm) zmm[2] = bitonic_merge_zmm_64bit(zmm2); zmm[3] = bitonic_merge_zmm_64bit(zmm3); } - template X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) { @@ -443,7 +105,6 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); } - template X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) { @@ -717,28 +378,6 @@ X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N) vtype::mask_storeu(arr + 120, load_mask8, zmm[15]); } -template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, - const int64_t left, - const int64_t right) -{ - // median of 8 - int64_t size = (right - left) / 8; - using zmm_t = typename vtype::zmm_t; - __m512i rand_index = _mm512_set_epi64(left + size, - left + 2 * size, - left + 3 * size, - left + 4 * size, - left + 5 * size, - left + 6 * size, - left + 7 * size, - left + 8 * size); - zmm_t rand_vec = vtype::template i64gather(rand_index, arr); - // pivot will never be a nan, since there are no nan's! - zmm_t sort = sort_zmm_64bit(rand_vec); - return ((type_t *)&sort)[4]; -} - template static void qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) @@ -769,31 +408,6 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) qsort_64bit_(arr, pivot_index, right, max_iters - 1); } -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask8 loadmask = 0xFF; - while (arrsize > 0) { - if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } - __m512d in_zmm = _mm512_maskz_loadu_pd(loadmask, arr); - __mmask8 nanmask = _mm512_cmp_pd_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm512_mask_storeu_pd(arr, nanmask, ZMM_MAX_DOUBLE); - arr += 8; - arrsize -= 8; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) -{ - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = std::nan("1"); - nan_count -= 1; - } -} - template <> void avx512_qsort(int64_t *arr, int64_t arrsize) { diff --git a/src/avx512-common-keyvaluesort.h b/src/avx512-common-keyvaluesort.h new file mode 100644 index 00000000..f2821072 --- /dev/null +++ b/src/avx512-common-keyvaluesort.h @@ -0,0 +1,240 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * Copyright (C) 2021 Serge Sans Paille + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Liu Zhuan + * Tang Xi + * ****************************************************************/ + +#ifndef AVX512_QSORT_COMMON_KV +#define AVX512_QSORT_COMMON_KV + +/* + * Quicksort using AVX-512. The ideas and code are based on these two research + * papers [1] and [2]. On a high level, the idea is to vectorize quicksort + * partitioning using AVX-512 compressstore instructions. 
If the array size is + * < 128, then use Bitonic sorting network implemented on 512-bit registers. + * The precise network definitions depend on the dtype and are defined in + * separate files: avx512-16bit-qsort.hpp, avx512-32bit-qsort.hpp and + * avx512-64bit-qsort.hpp. Article [4] is a good resource for bitonic sorting + * networks. The core implementations of the vectorized qsort functions + * avx512_qsort(T*, int64_t) are modified versions of the AVX2 quicksort + * presented in the paper [2] and source code associated with that paper [3]. + * + * [1] Fast and Robust Vectorized In-Place Sorting of Primitive Types + * https://drops.dagstuhl.de/opus/volltexte/2021/13775/ + * + * [2] A Novel Hybrid Quicksort Algorithm Vectorized using AVX-512 on Intel + * Skylake https://arxiv.org/pdf/1704.08579.pdf + * + * [3] https://github.com/simd-sorting/fast-and-robust: SPDX-License-Identifier: MIT + * + * [4] http://mitp-content-server.mit.edu:18180/books/content/sectbyfn?collid=books_pres_0&fn=Chapter%2027.pdf&id=8030 + * + */ + +#include "avx512-64bit-common.h" + +template +void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); + +using index_t = __m512i; + +template > +static void COEX(mm_t &key1, mm_t &key2, index_t &index1, index_t &index2) +{ + mm_t key_t1 = vtype::min(key1, key2); + mm_t key_t2 = vtype::max(key1, key2); + + index_t index_t1 + = index_type::mask_mov(index2, vtype::eq(key_t1, key1), index1); + index_t index_t2 + = index_type::mask_mov(index1, vtype::eq(key_t1, key1), index2); + + key1 = key_t1; + key2 = key_t2; + index1 = index_t1; + index2 = index_t2; +} +template > +static inline zmm_t cmp_merge(zmm_t in1, + zmm_t in2, + index_t &indexes1, + index_t indexes2, + opmask_t mask) +{ + zmm_t tmp_keys = cmp_merge(in1, in2, mask); + indexes1 = index_type::mask_mov( + indexes2, vtype::eq(tmp_keys, in1), indexes1); + return tmp_keys; // 0 -> min, 1 -> max +} +/* + * Partition one ZMM register based on the pivot and returns the index of the + * last element that is less than or equal to the pivot. + */ +template > +static inline int32_t partition_vec(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + const zmm_t keys_vec, + const index_t indexes_vec, + const zmm_t pivot_vec, + zmm_t *smallest_vec, + zmm_t *biggest_vec) +{ + /* which elements are larger than or equal to the pivot */ + typename vtype::opmask_t gt_mask = vtype::ge(keys_vec, pivot_vec); + int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); + vtype::mask_compressstoreu( + keys + left, vtype::knot_opmask(gt_mask), keys_vec); + vtype::mask_compressstoreu( + keys + right - amount_gt_pivot, gt_mask, keys_vec); + index_type::mask_compressstoreu( + indexes + left, index_type::knot_opmask(gt_mask), indexes_vec); + index_type::mask_compressstoreu( + indexes + right - amount_gt_pivot, gt_mask, indexes_vec); + *smallest_vec = vtype::min(keys_vec, *smallest_vec); + *biggest_vec = vtype::max(keys_vec, *biggest_vec); + return amount_gt_pivot; +} +/* + * Partition an array based on the pivot and returns the index of the + * last element that is less than or equal to the pivot.
+ */ +template > +static inline int64_t partition_avx512(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, keys[left]); + *biggest = std::max(*biggest, keys[left]); + if (keys[left] > pivot) { + right--; + std::swap(keys[left], keys[right]); + std::swap(indexes[left], indexes[right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using zmm_t = typename vtype::zmm_t; + zmm_t pivot_vec = vtype::set1(pivot); + zmm_t min_vec = vtype::set1(*smallest); + zmm_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + zmm_t keys_vec = vtype::loadu(keys + left); + int32_t amount_gt_pivot; + + index_t indexes_vec = index_type::loadu(indexes + left); + amount_gt_pivot = partition_vec(keys, + indexes, + left, + left + vtype::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return left + (vtype::numlanes - amount_gt_pivot); + } + + // first and last vtype::numlanes values are partitioned at the end + zmm_t keys_vec_left = vtype::loadu(keys + left); + zmm_t keys_vec_right = vtype::loadu(keys + (right - vtype::numlanes)); + index_t indexes_vec_left; + index_t indexes_vec_right; + indexes_vec_left = index_type::loadu(indexes + left); + indexes_vec_right = index_type::loadu(indexes + (right - vtype::numlanes)); + + // store points of the vectors + int64_t r_store = right - vtype::numlanes; + int64_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + zmm_t keys_vec; + index_t indexes_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= vtype::numlanes; + keys_vec = vtype::loadu(keys + right); + indexes_vec = index_type::loadu(indexes + right); + } + else { + keys_vec = vtype::loadu(keys + left); + indexes_vec = index_type::loadu(indexes + left); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_gt_pivot; + + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + r_store + vtype::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + r_store -= amount_gt_pivot; + l_store += (vtype::numlanes - amount_gt_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_gt_pivot; + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + r_store + vtype::numlanes, + keys_vec_left, + indexes_vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + l_store + vtype::numlanes, + keys_vec_right, + indexes_vec_right, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} +#endif // AVX512_QSORT_COMMON_KV diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index d1f6cbb4..1816baf3 100644 --- a/src/avx512-common-qsort.h 
+++ b/src/avx512-common-qsort.h @@ -103,7 +103,6 @@ static void COEX(mm_t &a, mm_t &b) a = vtype::min(a, b); b = vtype::max(temp, b); } - template @@ -113,7 +112,6 @@ static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2, opmask_t mask) zmm_t max = vtype::max(in2, in1); return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max } - /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. @@ -138,7 +136,6 @@ static inline int32_t partition_vec(type_t *arr, *biggest_vec = vtype::max(curr_vec, *biggest_vec); return amount_gt_pivot; } - /* * Parition an array based on the pivot and returns the index of the * last element that is less than equal to the pivot. diff --git a/tests/meson.build b/tests/meson.build index 7d51ba26..40cd4685 100644 --- a/tests/meson.build +++ b/tests/meson.build @@ -1,19 +1,15 @@ libtests = [] -if cc.has_argument('-march=icelake-client') - libtests += static_library( - 'tests_', - files( - 'test_all.cpp', - ), - dependencies : gtest_dep, - include_directories : [ - src, - utils, - ], - cpp_args : [ - '-O3', - '-march=icelake-client', - ], - ) -endif + if cc.has_argument('-march=icelake-client') libtests + += static_library('tests_', files('test_all.cpp', ), dependencies + : gtest_dep, include_directories + : + [ + src, + utils, + ], + cpp_args + : [ + '-O3', + '-march=icelake-client', + ], ) endif diff --git a/tests/test_all.cpp b/tests/test_all.cpp index 6d82a35b..ff65fff0 100644 --- a/tests/test_all.cpp +++ b/tests/test_all.cpp @@ -5,6 +5,7 @@ #include "avx512-16bit-qsort.hpp" #include "avx512-32bit-qsort.hpp" +#include "avx512-64bit-keyvaluesort.hpp" #include "avx512-64bit-qsort.hpp" #include "cpuinfo.h" #include "rand_array.h" @@ -56,3 +57,60 @@ using Types = testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512_sort, Types); + +template +struct sorted_t { + K key; + K value; +}; +template +bool compare(sorted_t a, sorted_t b) +{ + return a.key == b.key ? 
a.value < b.value : a.key < b.key; +} + +template +class TestKeyValueSort : public ::testing::Test { +}; + +TYPED_TEST_SUITE_P(TestKeyValueSort); + +TYPED_TEST_P(TestKeyValueSort, KeyValueSort) +{ + std::vector keysizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + keysizes.push_back((TypeParam)ii); + } + std::vector keys; + std::vector values; + std::vector> sortedarr; + + for (size_t ii = 0; ii < keysizes.size(); ++ii) { + /* Random array */ + keys = get_uniform_rand_array_key(keysizes[ii]); + values = get_uniform_rand_array(keysizes[ii]); + for (size_t i = 0; i < keys.size(); i++) { + sorted_t tmp_s; + tmp_s.key = keys[i]; + tmp_s.value = values[i]; + sortedarr.emplace_back(tmp_s); + } + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), + sortedarr.end(), + compare); + avx512_qsort_kv(keys.data(), values.data(), keys.size()); + for (size_t i = 0; i < keys.size(); i++) { + ASSERT_EQ(keys[i], sortedarr[i].key); + ASSERT_EQ(values[i], sortedarr[i].value); + } + keys.clear(); + values.clear(); + sortedarr.clear(); + } +} + +REGISTER_TYPED_TEST_SUITE_P(TestKeyValueSort, KeyValueSort); + +using TypesKv = testing::Types; +INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefixKv, TestKeyValueSort, TypesKv); \ No newline at end of file diff --git a/utils/rand_array.h b/utils/rand_array.h index 0842a0b4..804226f7 100644 --- a/utils/rand_array.h +++ b/utils/rand_array.h @@ -3,6 +3,7 @@ * * SPDX-License-Identifier: BSD-3-Clause * *******************************************/ +#include #include #include #include @@ -33,10 +34,59 @@ static std::vector get_uniform_rand_array( { std::random_device rd; std::mt19937 gen(rd()); - std::uniform_real_distribution<> dis(min, max); + std::uniform_real_distribution dis(min, max); std::vector arr; for (int64_t ii = 0; ii < arrsize; ++ii) { arr.emplace_back(dis(gen)); } return arr; } +template +static std::vector get_uniform_rand_array_key( + int64_t arrsize, + T max = std::numeric_limits::max(), + T min = std::numeric_limits::min(), + typename std::enable_if::value>::type * = 0) +{ + std::vector arr; + std::random_device r; + std::default_random_engine e1(r()); + std::uniform_int_distribution uniform_dist(min, max); + for (int64_t ii = 0; ii < arrsize; ++ii) { + + while (true) { + T tmp = uniform_dist(e1); + auto iter = std::find(arr.begin(), arr.end(), tmp); + if (iter == arr.end()) { + arr.emplace_back(tmp); + break; + } + } + } + return arr; +} + +template +static std::vector get_uniform_rand_array_key( + int64_t arrsize, + T max = std::numeric_limits::max(), + T min = std::numeric_limits::min(), + typename std::enable_if::value>::type * = 0) +{ + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(min, max); + std::vector arr; + for (int64_t ii = 0; ii < arrsize; ++ii) { + + while (true) { + T tmp = dis(gen); + auto iter = std::find(arr.begin(), arr.end(), tmp); + if (iter == arr.end()) { + arr.emplace_back(tmp); + break; + } + } + } + return arr; +} \ No newline at end of file
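The sorting-network code in this patch (sort_zmm_64bit and the bitonic_merge_*_64bit routines) is built entirely out of masked compare-exchange steps. The following scalar sketch, illustrative only and not part of the diff, shows the per-lane behaviour of cmp_merge for 8 x 64-bit lanes; the helper name scalar_cmp_merge is hypothetical.

// Scalar sketch of the masked compare-exchange behind cmp_merge<vtype>().
// For each lane i: the result holds max(a[i], b[i]) if bit i of the mask is
// set, otherwise min(a[i], b[i]). With b = a shuffled pairwise and
// mask = 0xAA (0b10101010), this sorts each adjacent pair of an 8-lane
// vector, which is the first stage of the bitonic network in sort_zmm_64bit.
#include <algorithm>
#include <array>
#include <cstdint>

// Hypothetical scalar stand-in for one 8 x 64-bit ZMM register.
using lanes8 = std::array<int64_t, 8>;

static lanes8 scalar_cmp_merge(const lanes8 &a, const lanes8 &b, uint8_t mask)
{
    lanes8 out {};
    for (int i = 0; i < 8; ++i) {
        bool take_max = (mask >> i) & 1;
        out[i] = take_max ? std::max(a[i], b[i]) : std::min(a[i], b[i]);
    }
    return out;
}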
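Likewise, partition_vec in the new key-value header packs keys that compare below the pivot toward the left write cursor and the remaining keys toward the right write cursor, moving each key's companion index with it via mask_compressstoreu. A scalar sketch of a single such step, with illustrative names and without the min/max bookkeeping, might look as follows; the chunk buffers stand in for the key and index vectors already loaded into registers.

// Scalar sketch of one partition_vec step: keys below the pivot are packed to
// the left write position, keys >= pivot to the right write position, and each
// key's companion index moves with it. Returns how many keys were >= pivot,
// mirroring amount_gt_pivot in the vectorized code (which uses vtype::ge).
#include <cstdint>

static int32_t scalar_partition_step(int64_t *keys,
                                     uint64_t *indexes,
                                     int64_t left,
                                     int64_t right,
                                     const int64_t *keys_chunk,
                                     const uint64_t *indexes_chunk,
                                     int64_t pivot,
                                     int nlanes /* 8 for 64-bit keys */)
{
    int32_t amount_gt_pivot = 0;
    for (int i = 0; i < nlanes; ++i)
        amount_gt_pivot += (keys_chunk[i] >= pivot);

    int64_t lo = left;                       // next slot on the small side
    int64_t hi = right - amount_gt_pivot;    // first slot on the large side
    for (int i = 0; i < nlanes; ++i) {
        int64_t dst = (keys_chunk[i] >= pivot) ? hi++ : lo++;
        keys[dst] = keys_chunk[i];
        indexes[dst] = indexes_chunk[i];     // index travels with its key
    }
    return amount_gt_pivot;
}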
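Finally, a minimal usage sketch of the new entry point avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize), modelled on the call pattern in the new TestKeyValueSort test. It assumes an AVX-512 capable CPU (the tests and benchmarks build with -march=icelake-client, and callers are expected to check cpu_has_avx512bw()) and that int64_t is among the supported key types.

// Usage sketch (not part of the diff): sort keys and carry a uint64_t value,
// here the original position, along with each key.
#include <cstdint>
#include <cstdio>
#include <vector>
#include "avx512-64bit-keyvaluesort.hpp"

int main()
{
    std::vector<int64_t> keys = {42, -3, 19, 7, 7};
    std::vector<uint64_t> values(keys.size());
    for (uint64_t i = 0; i < values.size(); ++i)
        values[i] = i; // remember where each key came from

    // keys ends up sorted; values is permuted by the same reordering.
    avx512_qsort_kv<int64_t>(keys.data(), values.data(), keys.size());

    for (size_t i = 0; i < keys.size(); ++i)
        printf("%lld (was at %llu)\n",
               (long long)keys[i],
               (unsigned long long)values[i]);
    return 0;
}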