From 8335a0bb1ab7f19ede2b18bf5b350246d84a3664 Mon Sep 17 00:00:00 2001 From: ruclz Date: Mon, 24 Oct 2022 15:22:12 +0800 Subject: [PATCH 01/16] add key-value file! --- src/avx512-qsort-key-value.hpp | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/avx512-qsort-key-value.hpp diff --git a/src/avx512-qsort-key-value.hpp b/src/avx512-qsort-key-value.hpp new file mode 100644 index 00000000..e69de29b From 156142a94763471a389fa03e6cedcfdf361c7cea Mon Sep 17 00:00:00 2001 From: ruclz Date: Thu, 17 Nov 2022 14:04:48 +0800 Subject: [PATCH 02/16] Completed the key-value sort with uint64_t. --- .vscode/settings.json | 57 +++ Makefile | 4 +- benchmarks/bench.hpp | 2 +- src/avx512-16bit-qsort.hpp | 4 +- src/avx512-32bit-qsort.hpp | 6 +- src/avx512-64bit-qsort.hpp | 859 +++++++++++++++++++++++++++------ src/avx512-common-qsort.h | 219 ++++++++- src/avx512-qsort-key-value.hpp | 0 tests/test_all.cpp | 2 +- 9 files changed, 992 insertions(+), 161 deletions(-) create mode 100644 .vscode/settings.json delete mode 100644 src/avx512-qsort-key-value.hpp diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..5abdc8c3 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,57 @@ +{ + "files.associations": { + "*.tcc": "cpp", + "functional": "cpp", + "string_view": "cpp", + "random": "cpp", + "istream": "cpp", + "limits": "cpp", + "algorithm": "cpp", + "bit": "cpp", + "numeric": "cpp", + "cctype": "cpp", + "clocale": "cpp", + "cmath": "cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdio": "cpp", + "cstdlib": "cpp", + "cstring": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "array": "cpp", + "atomic": "cpp", + "cstdint": "cpp", + "deque": "cpp", + "unordered_map": "cpp", + "vector": "cpp", + "exception": "cpp", + "iterator": "cpp", + "memory": "cpp", + "memory_resource": "cpp", + "optional": "cpp", + "string": "cpp", + "system_error": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "utility": "cpp", + "fstream": "cpp", + "initializer_list": "cpp", + "iomanip": "cpp", + "iosfwd": "cpp", + "iostream": "cpp", + "new": "cpp", + "ostream": "cpp", + "sstream": "cpp", + "stdexcept": "cpp", + "streambuf": "cpp", + "cinttypes": "cpp", + "typeinfo": "cpp", + "compare": "cpp", + "concepts": "cpp", + "numbers": "cpp", + "map": "cpp", + "set": "cpp" + } +} \ No newline at end of file diff --git a/Makefile b/Makefile index 938dbe5b..04bf5d78 100644 --- a/Makefile +++ b/Makefile @@ -15,10 +15,10 @@ LD_FLAGS = -L /usr/local/lib -l $(GTEST_LIB) -l pthread all : test bench $(TESTDIR)/%.o : $(TESTDIR)/%.cpp $(SRCS) - $(CXX) -march=icelake-client -O3 $(CXXFLAGS) -c $< -o $@ + $(CXX) -march=icelake-client -O3 $(CXXFLAGS) -l $(GTEST_LIB) -c $< -o $@ test: $(TESTDIR)/main.cpp $(TESTOBJS) $(SRCS) - $(CXX) tests/main.cpp $(TESTOBJS) $(CXXFLAGS) $(LD_FLAGS) -o testexe + $(CXX) tests/main.cpp $(TESTOBJS) $(CXXFLAGS) $(LD_FLAGS) -o testexe bench: $(BENCHDIR)/main.cpp $(SRCS) $(CXX) $(BENCHDIR)/main.cpp $(CXXFLAGS) -march=icelake-client -O3 -o benchexe diff --git a/benchmarks/bench.hpp b/benchmarks/bench.hpp index 0837a6ca..48fca6cd 100644 --- a/benchmarks/bench.hpp +++ b/benchmarks/bench.hpp @@ -49,7 +49,7 @@ std::tuple bench_sort(const std::vector arr, uint64_t start(0), end(0); for (uint64_t ii = 0; ii < iters; ++ii) { start = cycles_start(); - avx512_qsort(arr_bckup.data(), arr_bckup.size()); + avx512_qsort(arr_bckup.data(), NULL, arr_bckup.size()); end = cycles_end(); runtimes1.emplace_back(end - start); arr_bckup = arr; diff --git 
a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 39b32515..cef86a2f 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -484,7 +484,7 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) } template <> -void avx512_qsort(int16_t *arr, int64_t arrsize) +inline void avx512_qsort(int16_t *arr,uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_16bit_, int16_t>( @@ -493,7 +493,7 @@ void avx512_qsort(int16_t *arr, int64_t arrsize) } template <> -void avx512_qsort(uint16_t *arr, int64_t arrsize) +inline void avx512_qsort(uint16_t *arr,uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_16bit_, uint16_t>( diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index efc4a4f2..9b8bcea5 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -684,7 +684,7 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) } template <> -void avx512_qsort(int32_t *arr, int64_t arrsize) +inline void avx512_qsort(int32_t *arr,uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_32bit_, int32_t>( @@ -693,7 +693,7 @@ void avx512_qsort(int32_t *arr, int64_t arrsize) } template <> -void avx512_qsort(uint32_t *arr, int64_t arrsize) +inline void avx512_qsort(uint32_t *arr,uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_32bit_, uint32_t>( @@ -702,7 +702,7 @@ void avx512_qsort(uint32_t *arr, int64_t arrsize) } template <> -void avx512_qsort(float *arr, int64_t arrsize) +inline void avx512_qsort(float *arr,uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf(arr, arrsize); diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index b2bd0ed4..6176adf1 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -56,6 +56,10 @@ struct zmm_vector { { return _knot_mask8(x); } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); + } static opmask_t ge(zmm_t x, zmm_t y) { return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); @@ -162,6 +166,10 @@ struct zmm_vector { { return _knot_mask8(x); } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); + } static opmask_t ge(zmm_t x, zmm_t y) { return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); @@ -258,6 +266,10 @@ struct zmm_vector { { return _knot_mask8(x); } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OS); + } static opmask_t ge(zmm_t x, zmm_t y) { return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); @@ -329,6 +341,7 @@ struct zmm_vector { template X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_64bit(zmm_t zmm) { + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAA); @@ -348,6 +361,36 @@ X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_64bit(zmm_t zmm) return zmm; } +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_64bit(zmm_t key_zmm,index_t &index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + key_zmm = cmp_merge( + key_zmm, vtype::template shuffle(key_zmm), + index_zmm, zmm_vector::template shuffle(index_zmm), + 0xAA); + key_zmm = cmp_merge( + key_zmm,vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), + index_zmm,zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), index_zmm), + 0xCC); + key_zmm = cmp_merge( + key_zmm, vtype::template shuffle(key_zmm), + index_zmm, zmm_vector::template shuffle(index_zmm), + 0xAA); + key_zmm = 
cmp_merge( + key_zmm, vtype::permutexvar(rev_index, key_zmm), + index_zmm,zmm_vector::permutexvar(rev_index, index_zmm), + 0xF0); + key_zmm = cmp_merge( + key_zmm,vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + index_zmm,zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + 0xCC); + key_zmm = cmp_merge( + key_zmm, vtype::template shuffle(key_zmm), + index_zmm, zmm_vector::template shuffle(index_zmm), + 0xAA); + return key_zmm; +} // Assumes zmm is bitonic and performs a recursive half cleaner template X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) @@ -368,7 +411,32 @@ X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) zmm, vtype::template shuffle(zmm), 0xAA); return zmm; } +// Assumes zmm is bitonic and performs a recursive half cleaner +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm,zmm_vector::zmm_t &index_zmm) +{ + // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), + index_zmm, + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), index_zmm), + 0xF0); + // 2) half_cleaner[4] + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + index_zmm, + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + 0xCC); + // 3) half_cleaner[1] + key_zmm = cmp_merge( + key_zmm, vtype::template shuffle(key_zmm), + index_zmm, zmm_vector::template shuffle(index_zmm), + 0xAA); + return key_zmm; +} // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner template X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1, @@ -383,7 +451,30 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1, zmm1 = bitonic_merge_zmm_64bit(zmm3); zmm2 = bitonic_merge_zmm_64bit(zmm4); } +// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, + zmm_t &key_zmm2, index_t &index_zmm1, + index_t &index_zmm2) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + // 1) First step of a merging network: coex of zmm1 and zmm2 reversed + key_zmm2 = vtype::permutexvar(rev_index, key_zmm2); + index_zmm2 = zmm_vector::permutexvar(rev_index, index_zmm2); + + zmm_t key_zmm3 = vtype::min(key_zmm1, key_zmm2); + zmm_t key_zmm4 = vtype::max(key_zmm1, key_zmm2); + + index_t index_zmm3=zmm_vector::mask_mov(index_zmm2,vtype::eq(key_zmm3,key_zmm1),index_zmm1); + index_t index_zmm4=zmm_vector::mask_mov(index_zmm1,vtype::eq(key_zmm3,key_zmm1),index_zmm2); + + // 2) Recursive half cleaner for each + key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3,index_zmm3); + key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4,index_zmm4); + index_zmm1=index_zmm3; + index_zmm2=index_zmm4; +} // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive // half cleaner template @@ -407,7 +498,55 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_64bit(zmm_t *zmm) zmm[2] = bitonic_merge_zmm_64bit(zmm2); zmm[3] = bitonic_merge_zmm_64bit(zmm3); } +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,index_t *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + // 1) First step of a merging network + zmm_t key_zmm2r = vtype::permutexvar(rev_index, key_zmm[2]); + zmm_t key_zmm3r = vtype::permutexvar(rev_index, key_zmm[3]); + index_t index_zmm2r 
= zmm_vector::permutexvar(rev_index, index_zmm[2]); + index_t index_zmm3r = zmm_vector::permutexvar(rev_index, index_zmm[3]); + + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm3r); + zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm2r); + zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm3r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); + + index_t index_zmm_t1=zmm_vector::mask_mov(index_zmm3r,vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm[0]); + index_t index_zmm_m1=zmm_vector::mask_mov(index_zmm[0],vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm3r); + index_t index_zmm_t2=zmm_vector::mask_mov(index_zmm2r,vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm[1]); + index_t index_zmm_m2=zmm_vector::mask_mov(index_zmm[1],vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm2r); + + + // 2) Recursive half clearer: 16 + zmm_t key_zmm_t3 = vtype::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t4 = vtype::permutexvar(rev_index, key_zmm_m1); + index_t index_zmm_t3 = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t4 = zmm_vector::permutexvar(rev_index, index_zmm_m1); + + zmm_t key_zmm0 = vtype::min(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm1 = vtype::max(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm2 = vtype::min(key_zmm_t3, key_zmm_t4); + zmm_t key_zmm3 = vtype::max(key_zmm_t3, key_zmm_t4); + + index_t index_zmm0 = zmm_vector::mask_mov(index_zmm_t2,vtype::eq(key_zmm0,key_zmm_t1),index_zmm_t1); + index_t index_zmm1 = zmm_vector::mask_mov(index_zmm_t1,vtype::eq(key_zmm0,key_zmm_t1),index_zmm_t2); + index_t index_zmm2 = zmm_vector::mask_mov(index_zmm_t4,vtype::eq(key_zmm2,key_zmm_t3),index_zmm_t3); + index_t index_zmm3 = zmm_vector::mask_mov(index_zmm_t3,vtype::eq(key_zmm2,key_zmm_t3),index_zmm_t4); + + + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0,index_zmm0); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1,index_zmm1); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2,index_zmm2); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3,index_zmm3); + + index_zmm[0]=index_zmm0; + index_zmm[1]=index_zmm1; + index_zmm[2]=index_zmm2; + index_zmm[3]=index_zmm3; +} template X86_SIMD_SORT_FORCEINLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) { @@ -441,7 +580,79 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); } +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm,index_t *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t key_zmm4r = vtype::permutexvar(rev_index, key_zmm[4]); + zmm_t key_zmm5r = vtype::permutexvar(rev_index, key_zmm[5]); + zmm_t key_zmm6r = vtype::permutexvar(rev_index, key_zmm[6]); + zmm_t key_zmm7r = vtype::permutexvar(rev_index, key_zmm[7]); + index_t index_zmm4r = zmm_vector::permutexvar(rev_index, index_zmm[4]); + index_t index_zmm5r = zmm_vector::permutexvar(rev_index, index_zmm[5]); + index_t index_zmm6r = zmm_vector::permutexvar(rev_index, index_zmm[6]); + index_t index_zmm7r = zmm_vector::permutexvar(rev_index, index_zmm[7]); + + + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm7r); + zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm6r); + zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm5r); + zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm4r); + + zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm7r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm6r); + zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm5r); + zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); + + index_t 
index_zmm_t1= zmm_vector::mask_mov(index_zmm7r, vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm[0]); + index_t index_zmm_m1= zmm_vector::mask_mov(index_zmm[0], vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm7r); + index_t index_zmm_t2= zmm_vector::mask_mov(index_zmm6r, vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm[1]); + index_t index_zmm_m2= zmm_vector::mask_mov(index_zmm[1], vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm6r); + index_t index_zmm_t3= zmm_vector::mask_mov(index_zmm5r, vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm[2]); + index_t index_zmm_m3= zmm_vector::mask_mov(index_zmm[2], vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm5r); + index_t index_zmm_t4= zmm_vector::mask_mov(index_zmm4r, vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm[3]); + index_t index_zmm_m4= zmm_vector::mask_mov(index_zmm[3], vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm4r); + + + + zmm_t key_zmm_t5 = vtype::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t6 = vtype::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t7 = vtype::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t8 = vtype::permutexvar(rev_index, key_zmm_m1); + index_t index_zmm_t5 = zmm_vector::permutexvar(rev_index, index_zmm_m4); + index_t index_zmm_t6 = zmm_vector::permutexvar(rev_index, index_zmm_m3); + index_t index_zmm_t7 = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t8 = zmm_vector::permutexvar(rev_index, index_zmm_m1); + + + COEX(key_zmm_t1, key_zmm_t3,index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4,index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7,index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8,index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t1, key_zmm_t2,index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4,index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6,index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8,index_zmm_t7, index_zmm_t8); + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1,index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2,index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3,index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4,index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5,index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6,index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7,index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8,index_zmm_t8); + + index_zmm[0]=index_zmm_t1; + index_zmm[1]=index_zmm_t2; + index_zmm[2]=index_zmm_t3; + index_zmm[3]=index_zmm_t4; + index_zmm[4]=index_zmm_t5; + index_zmm[5]=index_zmm_t6; + index_zmm[6]=index_zmm_t7; + index_zmm[7]=index_zmm_t8; + +} template X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) { @@ -515,83 +726,279 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); } +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm,index_t *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t key_zmm8r = vtype::permutexvar(rev_index, key_zmm[8]); + zmm_t key_zmm9r = vtype::permutexvar(rev_index, key_zmm[9]); + zmm_t key_zmm10r = vtype::permutexvar(rev_index, key_zmm[10]); + zmm_t key_zmm11r = vtype::permutexvar(rev_index, key_zmm[11]); + zmm_t key_zmm12r = vtype::permutexvar(rev_index, key_zmm[12]); + zmm_t key_zmm13r = vtype::permutexvar(rev_index, key_zmm[13]); + zmm_t key_zmm14r = vtype::permutexvar(rev_index, key_zmm[14]); + zmm_t 
key_zmm15r = vtype::permutexvar(rev_index, key_zmm[15]); + + index_t index_zmm8r = zmm_vector::permutexvar(rev_index, index_zmm[8]); + index_t index_zmm9r = zmm_vector::permutexvar(rev_index, index_zmm[9]); + index_t index_zmm10r = zmm_vector::permutexvar(rev_index, index_zmm[10]); + index_t index_zmm11r = zmm_vector::permutexvar(rev_index, index_zmm[11]); + index_t index_zmm12r = zmm_vector::permutexvar(rev_index, index_zmm[12]); + index_t index_zmm13r = zmm_vector::permutexvar(rev_index, index_zmm[13]); + index_t index_zmm14r = zmm_vector::permutexvar(rev_index, index_zmm[14]); + index_t index_zmm15r = zmm_vector::permutexvar(rev_index, index_zmm[15]); + + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm15r); + zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm14r); + zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm13r); + zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm12r); + zmm_t key_zmm_t5 = vtype::min(key_zmm[4], key_zmm11r); + zmm_t key_zmm_t6 = vtype::min(key_zmm[5], key_zmm10r); + zmm_t key_zmm_t7 = vtype::min(key_zmm[6], key_zmm9r); + zmm_t key_zmm_t8 = vtype::min(key_zmm[7], key_zmm8r); + + zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm15r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm14r); + zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm13r); + zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm12r); + zmm_t key_zmm_m5 = vtype::max(key_zmm[4], key_zmm11r); + zmm_t key_zmm_m6 = vtype::max(key_zmm[5], key_zmm10r); + zmm_t key_zmm_m7 = vtype::max(key_zmm[6], key_zmm9r); + zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); + + index_t index_zmm_t1=zmm_vector::mask_mov(index_zmm15r,vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm[0]); + index_t index_zmm_m1=zmm_vector::mask_mov(index_zmm[0],vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm15r); + index_t index_zmm_t2=zmm_vector::mask_mov(index_zmm14r,vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm[1]); + index_t index_zmm_m2=zmm_vector::mask_mov(index_zmm[1],vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm14r); + index_t index_zmm_t3=zmm_vector::mask_mov(index_zmm13r,vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm[2]); + index_t index_zmm_m3=zmm_vector::mask_mov(index_zmm[2],vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm13r); + index_t index_zmm_t4=zmm_vector::mask_mov(index_zmm12r,vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm[3]); + index_t index_zmm_m4=zmm_vector::mask_mov(index_zmm[3],vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm12r); + + index_t index_zmm_t5=zmm_vector::mask_mov(index_zmm11r,vtype::eq(key_zmm_t5,key_zmm[4]),index_zmm[4]); + index_t index_zmm_m5=zmm_vector::mask_mov(index_zmm[4],vtype::eq(key_zmm_t5,key_zmm[4]),index_zmm11r); + index_t index_zmm_t6=zmm_vector::mask_mov(index_zmm10r,vtype::eq(key_zmm_t6,key_zmm[5]),index_zmm[5]); + index_t index_zmm_m6=zmm_vector::mask_mov(index_zmm[5],vtype::eq(key_zmm_t6,key_zmm[5]),index_zmm10r); + index_t index_zmm_t7=zmm_vector::mask_mov(index_zmm9r,vtype::eq(key_zmm_t7,key_zmm[6]),index_zmm[6]); + index_t index_zmm_m7=zmm_vector::mask_mov(index_zmm[6],vtype::eq(key_zmm_t7,key_zmm[6]),index_zmm9r); + index_t index_zmm_t8=zmm_vector::mask_mov(index_zmm8r,vtype::eq(key_zmm_t8,key_zmm[7]),index_zmm[7]); + index_t index_zmm_m8=zmm_vector::mask_mov(index_zmm[7],vtype::eq(key_zmm_t8,key_zmm[7]),index_zmm8r); + + zmm_t key_zmm_t9 = vtype::permutexvar(rev_index, key_zmm_m8); + zmm_t key_zmm_t10 = vtype::permutexvar(rev_index, key_zmm_m7); + zmm_t key_zmm_t11 = vtype::permutexvar(rev_index, key_zmm_m6); + zmm_t key_zmm_t12 = vtype::permutexvar(rev_index, key_zmm_m5); + zmm_t key_zmm_t13 = 
vtype::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t14 = vtype::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t15 = vtype::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t16 = vtype::permutexvar(rev_index, key_zmm_m1); + index_t index_zmm_t9 = zmm_vector::permutexvar(rev_index, index_zmm_m8); + index_t index_zmm_t10 = zmm_vector::permutexvar(rev_index, index_zmm_m7); + index_t index_zmm_t11 = zmm_vector::permutexvar(rev_index, index_zmm_m6); + index_t index_zmm_t12 = zmm_vector::permutexvar(rev_index, index_zmm_m5); + index_t index_zmm_t13 = zmm_vector::permutexvar(rev_index, index_zmm_m4); + index_t index_zmm_t14 = zmm_vector::permutexvar(rev_index, index_zmm_m3); + index_t index_zmm_t15 = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t16 = zmm_vector::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t5,index_zmm_t1, index_zmm_t5); + COEX(key_zmm_t2, key_zmm_t6,index_zmm_t2, index_zmm_t6); + COEX(key_zmm_t3, key_zmm_t7,index_zmm_t3, index_zmm_t7); + COEX(key_zmm_t4, key_zmm_t8,index_zmm_t4, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t13,index_zmm_t9, index_zmm_t13); + COEX(key_zmm_t10, key_zmm_t14,index_zmm_t10, index_zmm_t14); + COEX(key_zmm_t11, key_zmm_t15,index_zmm_t11, index_zmm_t15); + COEX(key_zmm_t12, key_zmm_t16,index_zmm_t12, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t3,index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4,index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7,index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8,index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t11,index_zmm_t9, index_zmm_t11); + COEX(key_zmm_t10, key_zmm_t12,index_zmm_t10, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t15,index_zmm_t13, index_zmm_t15); + COEX(key_zmm_t14, key_zmm_t16,index_zmm_t14, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t2,index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4,index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6,index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8,index_zmm_t7, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t10,index_zmm_t9, index_zmm_t10); + COEX(key_zmm_t11, key_zmm_t12,index_zmm_t11, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t14,index_zmm_t13, index_zmm_t14); + COEX(key_zmm_t15, key_zmm_t16,index_zmm_t15, index_zmm_t16); + // + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1,index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2,index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3,index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4,index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5,index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6,index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7,index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8,index_zmm_t8); + key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9,index_zmm_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10,index_zmm_t10); + key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11,index_zmm_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12,index_zmm_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13,index_zmm_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14,index_zmm_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15,index_zmm_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16,index_zmm_t16); + + index_zmm[0]=index_zmm_t1; + index_zmm[1]=index_zmm_t2; + index_zmm[2]=index_zmm_t3; + index_zmm[3]=index_zmm_t4; + index_zmm[4]=index_zmm_t5; + index_zmm[5]=index_zmm_t6; + 
index_zmm[6]=index_zmm_t7; + index_zmm[7]=index_zmm_t8; + index_zmm[8]=index_zmm_t9; + index_zmm[9]=index_zmm_t10; + index_zmm[10]=index_zmm_t11; + index_zmm[11]=index_zmm_t12; + index_zmm[12]=index_zmm_t13; + index_zmm[13]=index_zmm_t14; + index_zmm[14]=index_zmm_t15; + index_zmm[15]=index_zmm_t16; +} template -X86_SIMD_SORT_FORCEINLINE void sort_8_64bit(type_t *arr, int32_t N) +X86_SIMD_SORT_FORCEINLINE void sort_8_64bit(type_t *keys,uint64_t *indexes, int32_t N) { typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; - typename vtype::zmm_t zmm - = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr); - vtype::mask_storeu(arr, load_mask, sort_zmm_64bit(zmm)); + typename vtype::zmm_t key_zmm + = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys); + if(indexes){ + zmm_vector::zmm_t index_zmm + = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask, indexes); + vtype::mask_storeu(keys, load_mask, sort_zmm_64bit(key_zmm,index_zmm)); + zmm_vector::mask_storeu(indexes, load_mask,index_zmm); + }else{ + vtype::mask_storeu(keys, load_mask, sort_zmm_64bit(key_zmm)); + } + } template -X86_SIMD_SORT_FORCEINLINE void sort_16_64bit(type_t *arr, int32_t N) +X86_SIMD_SORT_FORCEINLINE void sort_16_64bit(type_t *keys,uint64_t *indexes, int32_t N) { if (N <= 8) { - sort_8_64bit(arr, N); + sort_8_64bit(keys,indexes, N); return; } using zmm_t = typename vtype::zmm_t; - zmm_t zmm1 = vtype::loadu(arr); + using index_t = zmm_vector::zmm_t; + zmm_t key_zmm1 = vtype::loadu(keys); typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 8); - zmm1 = sort_zmm_64bit(zmm1); - zmm2 = sort_zmm_64bit(zmm2); - bitonic_merge_two_zmm_64bit(zmm1, zmm2); - vtype::storeu(arr, zmm1); - vtype::mask_storeu(arr + 8, load_mask, zmm2); + zmm_t key_zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys + 8); + + if(indexes){ + index_t index_zmm1 = zmm_vector::loadu(indexes); + index_t index_zmm2 = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask, indexes + 8); + key_zmm1 = sort_zmm_64bit(key_zmm1,index_zmm1); + key_zmm2 = sort_zmm_64bit(key_zmm2,index_zmm2); + bitonic_merge_two_zmm_64bit(key_zmm1,key_zmm2,index_zmm1,index_zmm2); + zmm_vector::storeu(indexes, index_zmm1); + zmm_vector::mask_storeu(indexes + 8, load_mask, index_zmm2); + }else{ + key_zmm1 = sort_zmm_64bit(key_zmm1); + key_zmm2 = sort_zmm_64bit(key_zmm2); + bitonic_merge_two_zmm_64bit(key_zmm1,key_zmm2); + } + + vtype::storeu(keys, key_zmm1); + vtype::mask_storeu(keys + 8, load_mask, key_zmm2); } template -X86_SIMD_SORT_FORCEINLINE void sort_32_64bit(type_t *arr, int32_t N) +X86_SIMD_SORT_FORCEINLINE void sort_32_64bit(type_t *keys,uint64_t *indexes, int32_t N) { if (N <= 16) { - sort_16_64bit(arr, N); + sort_16_64bit(keys, indexes, N); return; } using zmm_t = typename vtype::zmm_t; using opmask_t = typename vtype::opmask_t; - zmm_t zmm[4]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); + using index_t = zmm_vector::zmm_t; + zmm_t key_zmm[4]; + index_t index_zmm[4]; + + key_zmm[0] = vtype::loadu(keys); + key_zmm[1] = vtype::loadu(keys + 8); + if(indexes){ + index_zmm[0] = zmm_vector::loadu(indexes); + index_zmm[1] = zmm_vector::loadu(indexes + 8); + key_zmm[0] = sort_zmm_64bit(key_zmm[0],index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1],index_zmm[1]); + }else{ + key_zmm[0] = sort_zmm_64bit(key_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1]); + } opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; 
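
For reference, the recurring pattern in the bitonic_merge_*_zmm_64bit and COEX overloads above is that the index register always travels with its key: keys are combined with vtype::min/max, and zmm_vector<uint64_t>::mask_mov picks each index by testing vtype::eq(min, key1). A minimal scalar sketch of that rule follows; the helper name is hypothetical and it is not part of the patch, it only models one lane of the vectorized keyed compare-exchange.

    #include <algorithm>
    #include <cstdint>

    // Scalar model of the keyed compare-exchange: kmin/kmax mirror
    // vtype::min/max, and (kmin == key1) mirrors the vtype::eq mask fed to
    // zmm_vector<uint64_t>::mask_mov when pairing each index with its key.
    static void coex_key_index(uint64_t &key1, uint64_t &key2,
                               uint64_t &idx1, uint64_t &idx2)
    {
        const uint64_t kmin = std::min(key1, key2);
        const uint64_t kmax = std::max(key1, key2);
        const bool first_is_min = (kmin == key1); // eq(key_min, key1)
        const uint64_t imin = first_is_min ? idx1 : idx2;
        const uint64_t imax = first_is_min ? idx2 : idx1;
        key1 = kmin;
        key2 = kmax;
        idx1 = imin;
        idx2 = imax;
    }
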
load_mask1 = (combined_mask)&0xFF; load_mask2 = (combined_mask >> 8) & 0xFF; - zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 16); - zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 24); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_four_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::mask_storeu(arr + 16, load_mask1, zmm[2]); - vtype::mask_storeu(arr + 24, load_mask2, zmm[3]); + key_zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 16); + key_zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 24); + + if(indexes){ + index_zmm[2] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask1, indexes + 16); + index_zmm[3] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask2, indexes + 24); + key_zmm[2] = sort_zmm_64bit(key_zmm[2],index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3],index_zmm[3]); + bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1],index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3],index_zmm[2], index_zmm[3]); + bitonic_merge_four_zmm_64bit(key_zmm,index_zmm); + zmm_vector::storeu(indexes, index_zmm[0]); + zmm_vector::storeu(indexes + 8, index_zmm[1]); + zmm_vector::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); + zmm_vector::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); + }else{ + key_zmm[2] = sort_zmm_64bit(key_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3]); + bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1]); + bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3]); + bitonic_merge_four_zmm_64bit(key_zmm); + } + vtype::storeu(keys, key_zmm[0]); + vtype::storeu(keys + 8, key_zmm[1]); + vtype::mask_storeu(keys + 16, load_mask1, key_zmm[2]); + vtype::mask_storeu(keys + 24, load_mask2, key_zmm[3]); } template -X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *arr, int32_t N) +X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) { if (N <= 32) { - sort_32_64bit(arr, N); + sort_32_64bit(keys,indexes, N); return; } using zmm_t = typename vtype::zmm_t; using opmask_t = typename vtype::opmask_t; - zmm_t zmm[8]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); - zmm[2] = vtype::loadu(arr + 16); - zmm[3] = vtype::loadu(arr + 24); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); + using index_t = zmm_vector::zmm_t; + zmm_t key_zmm[8]; + index_t index_zmm[8]; + + key_zmm[0] = vtype::loadu(keys); + key_zmm[1] = vtype::loadu(keys + 8); + key_zmm[2] = vtype::loadu(keys + 16); + key_zmm[3] = vtype::loadu(keys + 24); + if(indexes){ + index_zmm[0] = zmm_vector::loadu(indexes); + index_zmm[1] = zmm_vector::loadu(indexes + 8); + index_zmm[2] = zmm_vector::loadu(indexes + 16); + index_zmm[3] = zmm_vector::loadu(indexes + 24); + key_zmm[0] = sort_zmm_64bit(key_zmm[0],index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1],index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2],index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3],index_zmm[3]); + }else{ + key_zmm[0] = sort_zmm_64bit(key_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3]); + } opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; opmask_t load_mask3 = 0xFF, 
load_mask4 = 0xFF; // N-32 >= 1 @@ -600,57 +1007,108 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *arr, int32_t N) load_mask2 = (combined_mask >> 8) & 0xFF; load_mask3 = (combined_mask >> 16) & 0xFF; load_mask4 = (combined_mask >> 24) & 0xFF; - zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 32); - zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 40); - zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 48); - zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 56); - zmm[4] = sort_zmm_64bit(zmm[4]); - zmm[5] = sort_zmm_64bit(zmm[5]); - zmm[6] = sort_zmm_64bit(zmm[6]); - zmm[7] = sort_zmm_64bit(zmm[7]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); - bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); - bitonic_merge_four_zmm_64bit(zmm); - bitonic_merge_four_zmm_64bit(zmm + 4); - bitonic_merge_eight_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::storeu(arr + 16, zmm[2]); - vtype::storeu(arr + 24, zmm[3]); - vtype::mask_storeu(arr + 32, load_mask1, zmm[4]); - vtype::mask_storeu(arr + 40, load_mask2, zmm[5]); - vtype::mask_storeu(arr + 48, load_mask3, zmm[6]); - vtype::mask_storeu(arr + 56, load_mask4, zmm[7]); + key_zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 32); + key_zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 40); + key_zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 48); + key_zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 56); + + if(indexes){ + index_zmm[4] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask1, indexes + 32); + index_zmm[5] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask2, indexes + 40); + index_zmm[6] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask3, indexes + 48); + index_zmm[7] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask4, indexes + 56); + key_zmm[4] = sort_zmm_64bit(key_zmm[4],index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5],index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6],index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7],index_zmm[7]); + bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1],index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3],index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit(key_zmm[4], key_zmm[5],index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7],index_zmm[6], index_zmm[7]); + bitonic_merge_four_zmm_64bit(key_zmm,index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4,index_zmm + 4); + bitonic_merge_eight_zmm_64bit(key_zmm,index_zmm); + zmm_vector::storeu(indexes, index_zmm[0]); + zmm_vector::storeu(indexes + 8, index_zmm[1]); + zmm_vector::storeu(indexes + 16, index_zmm[2]); + zmm_vector::storeu(indexes + 24, index_zmm[3]); + zmm_vector::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); + zmm_vector::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); + zmm_vector::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); + zmm_vector::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); + }else{ + key_zmm[4] = sort_zmm_64bit(key_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7]); + bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1]); + bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3]); + bitonic_merge_two_zmm_64bit(key_zmm[4], key_zmm[5]); + 
bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7]); + bitonic_merge_four_zmm_64bit(key_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4); + bitonic_merge_eight_zmm_64bit(key_zmm); + } + + vtype::storeu(keys, key_zmm[0]); + vtype::storeu(keys + 8, key_zmm[1]); + vtype::storeu(keys + 16, key_zmm[2]); + vtype::storeu(keys + 24, key_zmm[3]); + vtype::mask_storeu(keys + 32, load_mask1, key_zmm[4]); + vtype::mask_storeu(keys + 40, load_mask2, key_zmm[5]); + vtype::mask_storeu(keys + 48, load_mask3, key_zmm[6]); + vtype::mask_storeu(keys + 56, load_mask4, key_zmm[7]); } template -X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *arr, int32_t N) +X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) { if (N <= 64) { - sort_64_64bit(arr, N); + sort_64_64bit(keys,indexes, N); return; } using zmm_t = typename vtype::zmm_t; + using index_t = zmm_vector::zmm_t; using opmask_t = typename vtype::opmask_t; - zmm_t zmm[16]; - zmm[0] = vtype::loadu(arr); - zmm[1] = vtype::loadu(arr + 8); - zmm[2] = vtype::loadu(arr + 16); - zmm[3] = vtype::loadu(arr + 24); - zmm[4] = vtype::loadu(arr + 32); - zmm[5] = vtype::loadu(arr + 40); - zmm[6] = vtype::loadu(arr + 48); - zmm[7] = vtype::loadu(arr + 56); - zmm[0] = sort_zmm_64bit(zmm[0]); - zmm[1] = sort_zmm_64bit(zmm[1]); - zmm[2] = sort_zmm_64bit(zmm[2]); - zmm[3] = sort_zmm_64bit(zmm[3]); - zmm[4] = sort_zmm_64bit(zmm[4]); - zmm[5] = sort_zmm_64bit(zmm[5]); - zmm[6] = sort_zmm_64bit(zmm[6]); - zmm[7] = sort_zmm_64bit(zmm[7]); + zmm_t key_zmm[16]; + index_t index_zmm[16]; + + key_zmm[0] = vtype::loadu(keys); + key_zmm[1] = vtype::loadu(keys + 8); + key_zmm[2] = vtype::loadu(keys + 16); + key_zmm[3] = vtype::loadu(keys + 24); + key_zmm[4] = vtype::loadu(keys + 32); + key_zmm[5] = vtype::loadu(keys + 40); + key_zmm[6] = vtype::loadu(keys + 48); + key_zmm[7] = vtype::loadu(keys + 56); + if(indexes!=NULL){ + index_zmm[0] = zmm_vector::loadu(indexes); + index_zmm[1] = zmm_vector::loadu(indexes + 8); + index_zmm[2] = zmm_vector::loadu(indexes + 16); + index_zmm[3] = zmm_vector::loadu(indexes + 24); + index_zmm[4] = zmm_vector::loadu(indexes + 32); + index_zmm[5] = zmm_vector::loadu(indexes + 40); + index_zmm[6] = zmm_vector::loadu(indexes + 48); + index_zmm[7] = zmm_vector::loadu(indexes + 56); + key_zmm[0] = sort_zmm_64bit(key_zmm[0],index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1],index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2],index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3],index_zmm[3]); + key_zmm[4] = sort_zmm_64bit(key_zmm[4],index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5],index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6],index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7],index_zmm[7]); + }else{ + key_zmm[0] = sort_zmm_64bit(key_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3]); + key_zmm[4] = sort_zmm_64bit(key_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7]); + } + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; @@ -666,63 +1124,117 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *arr, int32_t N) load_mask7 = (combined_mask >> 48) & 0xFF; load_mask8 = (combined_mask >> 56) & 0xFF; } - zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64); - zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, 
arr + 72); - zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 80); - zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 88); - zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 96); - zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 104); - zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 112); - zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 120); - zmm[8] = sort_zmm_64bit(zmm[8]); - zmm[9] = sort_zmm_64bit(zmm[9]); - zmm[10] = sort_zmm_64bit(zmm[10]); - zmm[11] = sort_zmm_64bit(zmm[11]); - zmm[12] = sort_zmm_64bit(zmm[12]); - zmm[13] = sort_zmm_64bit(zmm[13]); - zmm[14] = sort_zmm_64bit(zmm[14]); - zmm[15] = sort_zmm_64bit(zmm[15]); - bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); - bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); - bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); - bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); - bitonic_merge_two_zmm_64bit(zmm[8], zmm[9]); - bitonic_merge_two_zmm_64bit(zmm[10], zmm[11]); - bitonic_merge_two_zmm_64bit(zmm[12], zmm[13]); - bitonic_merge_two_zmm_64bit(zmm[14], zmm[15]); - bitonic_merge_four_zmm_64bit(zmm); - bitonic_merge_four_zmm_64bit(zmm + 4); - bitonic_merge_four_zmm_64bit(zmm + 8); - bitonic_merge_four_zmm_64bit(zmm + 12); - bitonic_merge_eight_zmm_64bit(zmm); - bitonic_merge_eight_zmm_64bit(zmm + 8); - bitonic_merge_sixteen_zmm_64bit(zmm); - vtype::storeu(arr, zmm[0]); - vtype::storeu(arr + 8, zmm[1]); - vtype::storeu(arr + 16, zmm[2]); - vtype::storeu(arr + 24, zmm[3]); - vtype::storeu(arr + 32, zmm[4]); - vtype::storeu(arr + 40, zmm[5]); - vtype::storeu(arr + 48, zmm[6]); - vtype::storeu(arr + 56, zmm[7]); - vtype::mask_storeu(arr + 64, load_mask1, zmm[8]); - vtype::mask_storeu(arr + 72, load_mask2, zmm[9]); - vtype::mask_storeu(arr + 80, load_mask3, zmm[10]); - vtype::mask_storeu(arr + 88, load_mask4, zmm[11]); - vtype::mask_storeu(arr + 96, load_mask5, zmm[12]); - vtype::mask_storeu(arr + 104, load_mask6, zmm[13]); - vtype::mask_storeu(arr + 112, load_mask7, zmm[14]); - vtype::mask_storeu(arr + 120, load_mask8, zmm[15]); + key_zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 64); + key_zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 72); + key_zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 80); + key_zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 88); + key_zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, keys + 96); + key_zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, keys + 104); + key_zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, keys + 112); + key_zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, keys + 120); + + if(indexes){ + index_zmm[8] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask1, indexes + 64); + index_zmm[9] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask2, indexes + 72); + index_zmm[10] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask3, indexes + 80); + index_zmm[11] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask4, indexes + 88); + index_zmm[12] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask5, indexes + 96); + index_zmm[13] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask6, indexes + 104); + index_zmm[14] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask7, indexes + 112); + index_zmm[15] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask8, indexes + 120); + key_zmm[8] = sort_zmm_64bit(key_zmm[8],index_zmm[8]); + key_zmm[9] = 
sort_zmm_64bit(key_zmm[9],index_zmm[9]); + key_zmm[10] = sort_zmm_64bit(key_zmm[10],index_zmm[10]); + key_zmm[11] = sort_zmm_64bit(key_zmm[11],index_zmm[11]); + key_zmm[12] = sort_zmm_64bit(key_zmm[12],index_zmm[12]); + key_zmm[13] = sort_zmm_64bit(key_zmm[13],index_zmm[13]); + key_zmm[14] = sort_zmm_64bit(key_zmm[14],index_zmm[14]); + key_zmm[15] = sort_zmm_64bit(key_zmm[15],index_zmm[15]); + + bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1],index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3],index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit(key_zmm[4], key_zmm[5],index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7],index_zmm[6], index_zmm[7]); + bitonic_merge_two_zmm_64bit(key_zmm[8], key_zmm[9],index_zmm[8], index_zmm[9]); + bitonic_merge_two_zmm_64bit(key_zmm[10], key_zmm[11],index_zmm[10], index_zmm[11]); + bitonic_merge_two_zmm_64bit(key_zmm[12], key_zmm[13],index_zmm[12], index_zmm[13]); + bitonic_merge_two_zmm_64bit(key_zmm[14], key_zmm[15],index_zmm[14], index_zmm[15]); + bitonic_merge_four_zmm_64bit(key_zmm,index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4,index_zmm + 4); + bitonic_merge_four_zmm_64bit(key_zmm + 8,index_zmm + 8); + bitonic_merge_four_zmm_64bit(key_zmm + 12,index_zmm + 12); + bitonic_merge_eight_zmm_64bit(key_zmm,index_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm + 8,index_zmm + 8); + bitonic_merge_sixteen_zmm_64bit(key_zmm,index_zmm); + zmm_vector::storeu(indexes, index_zmm[0]); + zmm_vector::storeu(indexes + 8, index_zmm[1]); + zmm_vector::storeu(indexes + 16, index_zmm[2]); + zmm_vector::storeu(indexes + 24, index_zmm[3]); + zmm_vector::storeu(indexes + 32, index_zmm[4]); + zmm_vector::storeu(indexes + 40, index_zmm[5]); + zmm_vector::storeu(indexes + 48, index_zmm[6]); + zmm_vector::storeu(indexes + 56, index_zmm[7]); + zmm_vector::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); + zmm_vector::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); + zmm_vector::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); + zmm_vector::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); + zmm_vector::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); + zmm_vector::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); + zmm_vector::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); + zmm_vector::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); + }else{ + key_zmm[8] = sort_zmm_64bit(key_zmm[8]); + key_zmm[9] = sort_zmm_64bit(key_zmm[9]); + key_zmm[10] = sort_zmm_64bit(key_zmm[10]); + key_zmm[11] = sort_zmm_64bit(key_zmm[11]); + key_zmm[12] = sort_zmm_64bit(key_zmm[12]); + key_zmm[13] = sort_zmm_64bit(key_zmm[13]); + key_zmm[14] = sort_zmm_64bit(key_zmm[14]); + key_zmm[15] = sort_zmm_64bit(key_zmm[15]); + bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1]); + bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3]); + bitonic_merge_two_zmm_64bit(key_zmm[4], key_zmm[5]); + bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7]); + bitonic_merge_two_zmm_64bit(key_zmm[8], key_zmm[9]); + bitonic_merge_two_zmm_64bit(key_zmm[10], key_zmm[11]); + bitonic_merge_two_zmm_64bit(key_zmm[12], key_zmm[13]); + bitonic_merge_two_zmm_64bit(key_zmm[14], key_zmm[15]); + bitonic_merge_four_zmm_64bit(key_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4); + bitonic_merge_four_zmm_64bit(key_zmm + 8); + bitonic_merge_four_zmm_64bit(key_zmm + 12); + bitonic_merge_eight_zmm_64bit(key_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm + 8); + bitonic_merge_sixteen_zmm_64bit(key_zmm); + } + vtype::storeu(keys, 
key_zmm[0]); + vtype::storeu(keys + 8, key_zmm[1]); + vtype::storeu(keys + 16, key_zmm[2]); + vtype::storeu(keys + 24, key_zmm[3]); + vtype::storeu(keys + 32, key_zmm[4]); + vtype::storeu(keys + 40, key_zmm[5]); + vtype::storeu(keys + 48, key_zmm[6]); + vtype::storeu(keys + 56, key_zmm[7]); + vtype::mask_storeu(keys + 64, load_mask1, key_zmm[8]); + vtype::mask_storeu(keys + 72, load_mask2, key_zmm[9]); + vtype::mask_storeu(keys + 80, load_mask3, key_zmm[10]); + vtype::mask_storeu(keys + 88, load_mask4, key_zmm[11]); + vtype::mask_storeu(keys + 96, load_mask5, key_zmm[12]); + vtype::mask_storeu(keys + 104, load_mask6, key_zmm[13]); + vtype::mask_storeu(keys + 112, load_mask7, key_zmm[14]); + vtype::mask_storeu(keys + 120, load_mask8, key_zmm[15]); + } template -X86_SIMD_SORT_FORCEINLINE type_t get_pivot_64bit(type_t *arr, +X86_SIMD_SORT_FORCEINLINE type_t get_pivot_64bit(type_t *keys, uint64_t *indexes, const int64_t left, const int64_t right) { // median of 8 int64_t size = (right - left) / 8; using zmm_t = typename vtype::zmm_t; + using index_t = zmm_vector::zmm_t; __m512i rand_index = _mm512_set_epi64(left + size, left + 2 * size, left + 3 * size, @@ -731,40 +1243,89 @@ X86_SIMD_SORT_FORCEINLINE type_t get_pivot_64bit(type_t *arr, left + 6 * size, left + 7 * size, left + 8 * size); - zmm_t rand_vec = vtype::template i64gather(rand_index, arr); + zmm_t key_vec = vtype::template i64gather(rand_index, keys); + + index_t index_vec; + zmm_t sort; + if(indexes) + { + index_vec=zmm_vector::template i64gather(rand_index, indexes); + sort = sort_zmm_64bit(key_vec,index_vec); + }else{ + //index_vec=vtype::template i64gather(rand_index, indexes); + sort = sort_zmm_64bit(key_vec); + } // pivot will never be a nan, since there are no nan's! - zmm_t sort = sort_zmm_64bit(rand_vec); + return ((type_t *)&sort)[4]; } +template +inline void +heapify(type_t* keys, uint64_t* indexes, int64_t idx, int64_t size) +{ + int64_t i = idx; + while(true) { + int64_t j = 2 * i + 1; + if (j >= size || j < 0) { + break; + } + int k = j + 1; + if (k < size && keys[j] < keys[k]) { + j = k; + } + if (keys[j] < keys[i]) { + break; + } + std::swap(keys[i], keys[j]); + std::swap(indexes[i], indexes[j]); + i = j; + } +} +template +inline void +heap_sort(type_t* keys, uint64_t* indexes, int64_t size) +{ + for (int64_t i = size / 2 - 1; i >= 0; i--) { + heapify(keys, indexes, i, size); + } + for (int64_t i = size - 1; i > 0; i--) { + std::swap(keys[0], keys[i]); + std::swap(indexes[0], indexes[i]); + heapify(keys, indexes, 0, i); + } +} + template inline void -qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) +qsort_64bit_(type_t *keys,uint64_t *indexes, int64_t left, int64_t right, int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress */ if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1); + if(indexes)heap_sort(keys + left, indexes + left, right - left + 1); + else std::sort(keys + left, keys + right + 1); return; } /* * Base case: use bitonic networks to sort arrays <= 128 */ if (right + 1 - left <= 128) { - sort_128_64bit(arr + left, (int32_t)(right + 1 - left)); + if(indexes) sort_128_64bit(keys + left, indexes + left, (int32_t)(right + 1 - left)); + else sort_128_64bit(keys + left, (uint64_t*)NULL, (int32_t)(right + 1 - left)); return; } - type_t pivot = get_pivot_64bit(arr, left, right); + type_t pivot = get_pivot_64bit(keys, indexes,left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); int64_t pivot_index = 
partition_avx512( - arr, left, right + 1, pivot, &smallest, &biggest); + keys, indexes, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) - qsort_64bit_(arr, left, pivot_index - 1, max_iters - 1); + qsort_64bit_(keys,indexes, left, pivot_index - 1, max_iters - 1); if (pivot != biggest) - qsort_64bit_(arr, pivot_index, right, max_iters - 1); + qsort_64bit_(keys,indexes, pivot_index, right, max_iters - 1); } X86_SIMD_SORT_FORCEINLINE int64_t replace_nan_with_inf(double *arr, @@ -794,31 +1355,31 @@ replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) } template <> -void avx512_qsort(int64_t *arr, int64_t arrsize) +inline void avx512_qsort(int64_t *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_64bit_, int64_t>( - arr, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize))); + keys,indexes, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize))); } } template <> -void avx512_qsort(uint64_t *arr, int64_t arrsize) +inline void avx512_qsort(uint64_t *keys,uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_64bit_, uint64_t>( - arr, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize))); + keys,indexes, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize))); } } template <> -void avx512_qsort(double *arr, int64_t arrsize) +inline void avx512_qsort(double *keys,uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(arr, arrsize); + int64_t nan_count = replace_nan_with_inf(keys, arrsize); qsort_64bit_, double>( - arr, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize))); - replace_inf_with_nan(arr, arrsize, nan_count); + keys,indexes, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize))); + replace_inf_with_nan(keys, arrsize, nan_count); } } #endif // AVX512_QSORT_64BIT diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 493cd436..09d5c861 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -71,9 +71,12 @@ template struct zmm_vector; + template -void avx512_qsort(T *arr, int64_t arrsize); +inline void avx512_qsort(T *keys, uint64_t *indexes, int64_t arrsize); +using index_t = __m512i; +//using index_type = zmm_vector; /* * COEX == Compare and Exchange two registers by swapping min and max values */ @@ -84,7 +87,20 @@ static void COEX(mm_t &a, mm_t &b) a = vtype::min(a, b); b = vtype::max(temp, b); } +template > +static void COEX(mm_t &key1, mm_t &key2,index_t &index1, index_t &index2) +{ + //COEX(key1,key2); + mm_t key_t1=vtype::min(key1,key2); + mm_t key_t2=vtype::max(key1,key2); + + index_t index_t1=index_type::mask_mov(index2,vtype::eq(key_t1,key1),index1); + index_t index_t2=index_type::mask_mov(index1,vtype::eq(key_t1,key1),index2); + key1=key_t1;key2=key_t2; + index1=index_t1;index2=index_t2; + +} template @@ -94,7 +110,16 @@ static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2, opmask_t mask) zmm_t max = vtype::max(in2, in1); return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max } - +template > +static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2,index_t & indexes1,index_t indexes2, opmask_t mask) +{ + zmm_t tmp_keys=cmp_merge(in1,in2,mask); + indexes1=index_type::mask_mov(indexes2,vtype::eq(tmp_keys, in1),indexes1); + return tmp_keys; // 0 -> min, 1 -> max +} /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. 
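
For context, the extended interface declared above is avx512_qsort<T>(T *keys, uint64_t *indexes, int64_t arrsize), and passing a null indexes pointer keeps the previous keys-only behaviour used by the existing tests and benchmarks. Below is a minimal usage sketch, not part of the patch; it assumes the header is included from src/ and that the translation unit is built with an AVX-512 capable -march such as icelake-client.

    #include <cstdint>
    #include <numeric>
    #include <vector>
    #include "avx512-64bit-qsort.hpp"

    int main()
    {
        std::vector<uint64_t> keys = {42, 7, 19, 3, 88, 7, 0, 56};
        std::vector<uint64_t> idx(keys.size());
        std::iota(idx.begin(), idx.end(), 0); // record each key's original position

        // Key-value sort: keys end up ascending and idx is permuted the same
        // way, so idx[i] afterwards is the original slot of the key now at i.
        avx512_qsort<uint64_t>(keys.data(), idx.data(), (int64_t)keys.size());

        // Keys-only sort, matching the NULL calls in tests/test_all.cpp and
        // benchmarks/bench.hpp.
        std::vector<uint64_t> keys_only = {5, 1, 4, 9};
        avx512_qsort<uint64_t>(keys_only.data(), nullptr, (int64_t)keys_only.size());
        return 0;
    }
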
@@ -119,7 +144,32 @@ static inline int32_t partition_vec(type_t *arr, *biggest_vec = vtype::max(curr_vec, *biggest_vec); return amount_gt_pivot; } - +template > +static inline int32_t partition_vec(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + const zmm_t keys_vec, + const index_t indexes_vec, + const zmm_t pivot_vec, + zmm_t *smallest_vec, + zmm_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t gt_mask = vtype::ge(keys_vec, pivot_vec); + int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); + vtype::mask_compressstoreu( + keys + left, vtype::knot_opmask(gt_mask), keys_vec); + vtype::mask_compressstoreu( + keys + right - amount_gt_pivot, gt_mask, keys_vec); + index_type::mask_compressstoreu( + indexes + left, index_type::knot_opmask(gt_mask), indexes_vec); + index_type::mask_compressstoreu( + indexes + right - amount_gt_pivot, gt_mask, indexes_vec); + *smallest_vec = vtype::min(keys_vec, *smallest_vec); + *biggest_vec = vtype::max(keys_vec, *biggest_vec); + return amount_gt_pivot; +} /* * Parition an array based on the pivot and returns the index of the * last element that is less than equal to the pivot. @@ -223,4 +273,167 @@ static inline int64_t partition_avx512(type_t *arr, *biggest = vtype::reducemax(max_vec); return l_store; } + +template > +static inline int64_t partition_avx512(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, keys[left]); + *biggest = std::max(*biggest, keys[left]); + if (keys[left] > pivot) { + right--; + std::swap(keys[left], keys[right]); + if(indexes) std::swap(indexes[left], indexes[right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using zmm_t = typename vtype::zmm_t; + zmm_t pivot_vec = vtype::set1(pivot); + zmm_t min_vec = vtype::set1(*smallest); + zmm_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + zmm_t keys_vec = vtype::loadu(keys + left); + int32_t amount_gt_pivot; + if(indexes) { + index_t indexes_vec = index_type::loadu(indexes + left); + amount_gt_pivot = partition_vec(keys, + indexes, + left, + left + vtype::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + }else{ + amount_gt_pivot = partition_vec(keys, + left, + left + vtype::numlanes, + keys_vec, + pivot_vec, + &min_vec, + &max_vec); + } + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return left + (vtype::numlanes - amount_gt_pivot); + } + + // first and last vtype::numlanes values are partitioned at the end + zmm_t keys_vec_left = vtype::loadu(keys + left); + zmm_t keys_vec_right = vtype::loadu(keys + (right - vtype::numlanes)); + index_t indexes_vec_left; + index_t indexes_vec_right; + if(indexes){ + indexes_vec_left = index_type::loadu(indexes + left); + indexes_vec_right = index_type::loadu(indexes + (right - vtype::numlanes)); + } + + + // store points of the vectors + int64_t r_store = right - vtype::numlanes; + int64_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + zmm_t keys_vec; + index_t indexes_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements 
are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= vtype::numlanes; + keys_vec = vtype::loadu(keys + right); + if(indexes) indexes_vec = index_type::loadu(indexes + right); + } + else { + keys_vec = vtype::loadu(keys + left); + if(indexes) indexes_vec = index_type::loadu(indexes + left); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_gt_pivot; + if(indexes) + amount_gt_pivot= partition_vec(keys, + indexes, + l_store, + r_store + vtype::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + else amount_gt_pivot= partition_vec(keys, + l_store, + r_store + vtype::numlanes, + keys_vec, + pivot_vec, + &min_vec, + &max_vec); + + r_store -= amount_gt_pivot; + l_store += (vtype::numlanes - amount_gt_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_gt_pivot; + if(indexes){ + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + r_store + vtype::numlanes, + keys_vec_left, + indexes_vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + l_store + vtype::numlanes, + keys_vec_right, + indexes_vec_right, + pivot_vec, + &min_vec, + &max_vec); + }else{ + amount_gt_pivot = partition_vec(keys, + l_store, + r_store + vtype::numlanes, + keys_vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec(keys, + l_store, + l_store + vtype::numlanes, + keys_vec_right, + pivot_vec, + &min_vec, + &max_vec); + } + l_store += (vtype::numlanes - amount_gt_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} #endif // AVX512_QSORT_COMMON diff --git a/src/avx512-qsort-key-value.hpp b/src/avx512-qsort-key-value.hpp deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/test_all.cpp b/tests/test_all.cpp index 6d82a35b..6309c7bc 100644 --- a/tests/test_all.cpp +++ b/tests/test_all.cpp @@ -34,7 +34,7 @@ TYPED_TEST_P(avx512_sort, test_arrsizes) sortedarr = arr; /* Sort with std::sort for comparison */ std::sort(sortedarr.begin(), sortedarr.end()); - avx512_qsort(arr.data(), arr.size()); + avx512_qsort(arr.data(),NULL, arr.size()); ASSERT_EQ(sortedarr, arr); arr.clear(); sortedarr.clear(); From a05f0625e2411a1924d699a8af6859140a343215 Mon Sep 17 00:00:00 2001 From: ruclz Date: Mon, 13 Feb 2023 14:08:11 +0800 Subject: [PATCH 03/16] first --- .gitignore | 35 ++ .vscode/settings.json | 27 +- Makefile | 14 +- Makefile.bak | 27 ++ src/avx512-16bit-qsort.hpp | 4 +- src/avx512-32bit-qsort.hpp | 8 +- src/avx512-64bit-qsort.hpp | 931 ++++++++++++++++++++++--------------- src/avx512-common-qsort.h | 209 +++++---- tests/meson.build | 30 +- tests/test_all.cpp | 46 +- utils/rand_array.h | 27 +- 11 files changed, 830 insertions(+), 528 deletions(-) create mode 100644 .gitignore create mode 100644 Makefile.bak diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..669f6732 --- /dev/null +++ b/.gitignore @@ -0,0 +1,35 @@ +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +**/.vscode + diff --git 
a/.vscode/settings.json b/.vscode/settings.json index 5abdc8c3..f1aaf8a2 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,14 +1,5 @@ { "files.associations": { - "*.tcc": "cpp", - "functional": "cpp", - "string_view": "cpp", - "random": "cpp", - "istream": "cpp", - "limits": "cpp", - "algorithm": "cpp", - "bit": "cpp", - "numeric": "cpp", "cctype": "cpp", "clocale": "cpp", "cmath": "cpp", @@ -22,16 +13,25 @@ "cwctype": "cpp", "array": "cpp", "atomic": "cpp", + "bit": "cpp", + "*.tcc": "cpp", "cstdint": "cpp", "deque": "cpp", + "map": "cpp", + "set": "cpp", "unordered_map": "cpp", "vector": "cpp", "exception": "cpp", + "algorithm": "cpp", + "functional": "cpp", "iterator": "cpp", "memory": "cpp", "memory_resource": "cpp", + "numeric": "cpp", "optional": "cpp", + "random": "cpp", "string": "cpp", + "string_view": "cpp", "system_error": "cpp", "tuple": "cpp", "type_traits": "cpp", @@ -41,17 +41,14 @@ "iomanip": "cpp", "iosfwd": "cpp", "iostream": "cpp", + "istream": "cpp", + "limits": "cpp", "new": "cpp", "ostream": "cpp", "sstream": "cpp", "stdexcept": "cpp", "streambuf": "cpp", "cinttypes": "cpp", - "typeinfo": "cpp", - "compare": "cpp", - "concepts": "cpp", - "numbers": "cpp", - "map": "cpp", - "set": "cpp" + "typeinfo": "cpp" } } \ No newline at end of file diff --git a/Makefile b/Makefile index 04bf5d78..899463a3 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,5 @@ -CXX ?= g++ -SRCDIR = ./src -TESTDIR = ./tests -BENCHDIR = ./benchmarks -UTILS = ./utils -SRCS = $(wildcard $(SRCDIR)/*.hpp) +CXX ? = g++ SRCDIR =./ src TESTDIR =./ tests BENCHDIR =./ benchmarks UTILS + =./ utils SRCS = $(wildcard $(SRCDIR)/*.hpp) TESTS = $(wildcard $(TESTDIR)/*.cpp) TESTOBJS = $(patsubst $(TESTDIR)/%.cpp,$(TESTDIR)/%.o,$(TESTS)) TESTOBJS := $(filter-out $(TESTDIR)/main.o ,$(TESTOBJS)) @@ -15,13 +11,13 @@ LD_FLAGS = -L /usr/local/lib -l $(GTEST_LIB) -l pthread all : test bench $(TESTDIR)/%.o : $(TESTDIR)/%.cpp $(SRCS) - $(CXX) -march=icelake-client -O3 $(CXXFLAGS) -l $(GTEST_LIB) -c $< -o $@ + $(CXX) -march=icelake-client -O3 $(CXXFLAGS) -c $< -o $@ test: $(TESTDIR)/main.cpp $(TESTOBJS) $(SRCS) - $(CXX) tests/main.cpp $(TESTOBJS) $(CXXFLAGS) $(LD_FLAGS) -o testexe + $(CXX) tests/main.cpp $(TESTOBJS) $(CXXFLAGS) $(LD_FLAGS) -o testexe bench: $(BENCHDIR)/main.cpp $(SRCS) $(CXX) $(BENCHDIR)/main.cpp $(CXXFLAGS) -march=icelake-client -O3 -o benchexe clean: - rm -f $(TESTDIR)/*.o testexe benchexe + rm -f $(TESTDIR)/*.o testexe benchexe \ No newline at end of file diff --git a/Makefile.bak b/Makefile.bak new file mode 100644 index 00000000..07c7818d --- /dev/null +++ b/Makefile.bak @@ -0,0 +1,27 @@ +CXX ?= g++ +SRCDIR = ./src +TESTDIR = ./tests +BENCHDIR = ./benchmarks +UTILS = ./utils +SRCS = $(wildcard $(SRCDIR)/*.hpp) +TESTS = $(wildcard $(TESTDIR)/*.cpp) +TESTOBJS = $(patsubst $(TESTDIR)/%.cpp,$(TESTDIR)/%.o,$(TESTS)) +TESTOBJS := $(filter-out $(TESTDIR)/main.o ,$(TESTOBJS)) +GTEST_LIB = gtest +GTEST_INCLUDE = /usr/local/include +CXXFLAGS += -I$(SRCDIR) -I$(GTEST_INCLUDE) -I$(UTILS) +LD_FLAGS = -L /usr/local/lib -l $(GTEST_LIB) -l pthread + +all : test bench + +$(TESTDIR)/%.o : $(TESTDIR)/%.cpp $(SRCS) + $(CXX) -march=icelake-client -O3 $(CXXFLAGS) -c $< -o $@ + +test: $(TESTDIR)/main.cpp $(TESTOBJS) $(SRCS) + $(CXX) tests/main.cpp $(TESTOBJS) $(CXXFLAGS) $(LD_FLAGS) -o testexe + +bench: $(BENCHDIR)/main.cpp $(SRCS) + $(CXX) $(BENCHDIR)/main.cpp $(CXXFLAGS) -march=icelake-client -O3 -o benchexe + +clean: + rm -f $(TESTDIR)/*.o testexe benchexe \ No newline at end of file diff --git 
a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index cef86a2f..6e909336 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -484,7 +484,7 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) } template <> -inline void avx512_qsort(int16_t *arr,uint64_t *indexes, int64_t arrsize) +inline void avx512_qsort(int16_t *arr, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_16bit_, int16_t>( @@ -493,7 +493,7 @@ inline void avx512_qsort(int16_t *arr,uint64_t *indexes, int64_t arrsize) } template <> -inline void avx512_qsort(uint16_t *arr,uint64_t *indexes, int64_t arrsize) +inline void avx512_qsort(uint16_t *arr, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_16bit_, uint16_t>( diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 9b8bcea5..fe6e49d8 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -684,7 +684,8 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) } template <> -inline void avx512_qsort(int32_t *arr,uint64_t *indexes, int64_t arrsize) +inline void +avx512_qsort(int32_t *arr, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_32bit_, int32_t>( @@ -693,7 +694,8 @@ inline void avx512_qsort(int32_t *arr,uint64_t *indexes, int64_t arrsiz } template <> -inline void avx512_qsort(uint32_t *arr,uint64_t *indexes, int64_t arrsize) +inline void +avx512_qsort(uint32_t *arr, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_32bit_, uint32_t>( @@ -702,7 +704,7 @@ inline void avx512_qsort(uint32_t *arr,uint64_t *indexes, int64_t arrs } template <> -inline void avx512_qsort(float *arr,uint64_t *indexes, int64_t arrsize) +inline void avx512_qsort(float *arr, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf(arr, arrsize); diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 6176adf1..f1d02e17 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -361,33 +361,53 @@ X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_64bit(zmm_t zmm) return zmm; } -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_64bit(zmm_t key_zmm,index_t &index_zmm) +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, + index_t &index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); key_zmm = cmp_merge( - key_zmm, vtype::template shuffle(key_zmm), - index_zmm, zmm_vector::template shuffle(index_zmm), + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), 0xAA); key_zmm = cmp_merge( - key_zmm,vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), - index_zmm,zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), index_zmm), + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), + index_zmm, + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), + index_zmm), 0xCC); key_zmm = cmp_merge( - key_zmm, vtype::template shuffle(key_zmm), - index_zmm, zmm_vector::template shuffle(index_zmm), + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), 0xAA); key_zmm = cmp_merge( - key_zmm, vtype::permutexvar(rev_index, key_zmm), - index_zmm,zmm_vector::permutexvar(rev_index, index_zmm), + key_zmm, + vtype::permutexvar(rev_index, key_zmm), + index_zmm, + zmm_vector::permutexvar(rev_index, index_zmm), 0xF0); key_zmm = cmp_merge( - key_zmm,vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), 
key_zmm), - index_zmm,zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + index_zmm, + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), + index_zmm), 0xCC); key_zmm = cmp_merge( - key_zmm, vtype::template shuffle(key_zmm), - index_zmm, zmm_vector::template shuffle(index_zmm), + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), 0xAA); return key_zmm; } @@ -412,8 +432,11 @@ X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) return zmm; } // Assumes zmm is bitonic and performs a recursive half cleaner -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm,zmm_vector::zmm_t &index_zmm) +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE zmm_t +bitonic_merge_zmm_64bit(zmm_t key_zmm, zmm_vector::zmm_t &index_zmm) { // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 @@ -421,19 +444,24 @@ X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm,zmm_vector key_zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), index_zmm), + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), + index_zmm), 0xF0); // 2) half_cleaner[4] key_zmm = cmp_merge( key_zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), + index_zmm), 0xCC); // 3) half_cleaner[1] key_zmm = cmp_merge( - key_zmm, vtype::template shuffle(key_zmm), - index_zmm, zmm_vector::template shuffle(index_zmm), + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), 0xAA); return key_zmm; } @@ -452,9 +480,12 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1, zmm2 = bitonic_merge_zmm_64bit(zmm4); } // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template ::zmm_t> +template ::zmm_t> X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, - zmm_t &key_zmm2, index_t &index_zmm1, + zmm_t &key_zmm2, + index_t &index_zmm1, index_t &index_zmm2) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); @@ -462,18 +493,19 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, key_zmm2 = vtype::permutexvar(rev_index, key_zmm2); index_zmm2 = zmm_vector::permutexvar(rev_index, index_zmm2); - zmm_t key_zmm3 = vtype::min(key_zmm1, key_zmm2); zmm_t key_zmm4 = vtype::max(key_zmm1, key_zmm2); - index_t index_zmm3=zmm_vector::mask_mov(index_zmm2,vtype::eq(key_zmm3,key_zmm1),index_zmm1); - index_t index_zmm4=zmm_vector::mask_mov(index_zmm1,vtype::eq(key_zmm3,key_zmm1),index_zmm2); + index_t index_zmm3 = zmm_vector::mask_mov( + index_zmm2, vtype::eq(key_zmm3, key_zmm1), index_zmm1); + index_t index_zmm4 = zmm_vector::mask_mov( + index_zmm1, vtype::eq(key_zmm3, key_zmm1), index_zmm2); // 2) Recursive half cleaner for each - key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3,index_zmm3); - key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4,index_zmm4); - index_zmm1=index_zmm3; - index_zmm2=index_zmm4; + key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); + index_zmm1 = index_zmm3; + index_zmm2 = index_zmm4; } // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive // half cleaner @@ -498,54 
+530,66 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_64bit(zmm_t *zmm) zmm[2] = bitonic_merge_zmm_64bit(zmm2); zmm[3] = bitonic_merge_zmm_64bit(zmm3); } -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,index_t *index_zmm) +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, + index_t *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); // 1) First step of a merging network zmm_t key_zmm2r = vtype::permutexvar(rev_index, key_zmm[2]); zmm_t key_zmm3r = vtype::permutexvar(rev_index, key_zmm[3]); - index_t index_zmm2r = zmm_vector::permutexvar(rev_index, index_zmm[2]); - index_t index_zmm3r = zmm_vector::permutexvar(rev_index, index_zmm[3]); - + index_t index_zmm2r + = zmm_vector::permutexvar(rev_index, index_zmm[2]); + index_t index_zmm3r + = zmm_vector::permutexvar(rev_index, index_zmm[3]); zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm3r); zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm2r); zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm3r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); - - index_t index_zmm_t1=zmm_vector::mask_mov(index_zmm3r,vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm[0]); - index_t index_zmm_m1=zmm_vector::mask_mov(index_zmm[0],vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm3r); - index_t index_zmm_t2=zmm_vector::mask_mov(index_zmm2r,vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm[1]); - index_t index_zmm_m2=zmm_vector::mask_mov(index_zmm[1],vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm2r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); + index_t index_zmm_t1 = zmm_vector::mask_mov( + index_zmm3r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_t index_zmm_m1 = zmm_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); + index_t index_zmm_t2 = zmm_vector::mask_mov( + index_zmm2r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_t index_zmm_m2 = zmm_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); // 2) Recursive half clearer: 16 zmm_t key_zmm_t3 = vtype::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t4 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t3 = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t4 = zmm_vector::permutexvar(rev_index, index_zmm_m1); + index_t index_zmm_t3 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t4 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); zmm_t key_zmm0 = vtype::min(key_zmm_t1, key_zmm_t2); zmm_t key_zmm1 = vtype::max(key_zmm_t1, key_zmm_t2); zmm_t key_zmm2 = vtype::min(key_zmm_t3, key_zmm_t4); zmm_t key_zmm3 = vtype::max(key_zmm_t3, key_zmm_t4); - index_t index_zmm0 = zmm_vector::mask_mov(index_zmm_t2,vtype::eq(key_zmm0,key_zmm_t1),index_zmm_t1); - index_t index_zmm1 = zmm_vector::mask_mov(index_zmm_t1,vtype::eq(key_zmm0,key_zmm_t1),index_zmm_t2); - index_t index_zmm2 = zmm_vector::mask_mov(index_zmm_t4,vtype::eq(key_zmm2,key_zmm_t3),index_zmm_t3); - index_t index_zmm3 = zmm_vector::mask_mov(index_zmm_t3,vtype::eq(key_zmm2,key_zmm_t3),index_zmm_t4); - - - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0,index_zmm0); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1,index_zmm1); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2,index_zmm2); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3,index_zmm3); - - index_zmm[0]=index_zmm0; - index_zmm[1]=index_zmm1; - index_zmm[2]=index_zmm2; - index_zmm[3]=index_zmm3; + index_t index_zmm0 = zmm_vector::mask_mov( + 
index_zmm_t2, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t1); + index_t index_zmm1 = zmm_vector::mask_mov( + index_zmm_t1, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t2); + index_t index_zmm2 = zmm_vector::mask_mov( + index_zmm_t4, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t3); + index_t index_zmm3 = zmm_vector::mask_mov( + index_zmm_t3, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t4); + + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + + index_zmm[0] = index_zmm0; + index_zmm[1] = index_zmm1; + index_zmm[2] = index_zmm2; + index_zmm[3] = index_zmm3; } template X86_SIMD_SORT_FORCEINLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) @@ -580,19 +624,25 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); } -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm,index_t *index_zmm) +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, + index_t *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); zmm_t key_zmm4r = vtype::permutexvar(rev_index, key_zmm[4]); zmm_t key_zmm5r = vtype::permutexvar(rev_index, key_zmm[5]); zmm_t key_zmm6r = vtype::permutexvar(rev_index, key_zmm[6]); zmm_t key_zmm7r = vtype::permutexvar(rev_index, key_zmm[7]); - index_t index_zmm4r = zmm_vector::permutexvar(rev_index, index_zmm[4]); - index_t index_zmm5r = zmm_vector::permutexvar(rev_index, index_zmm[5]); - index_t index_zmm6r = zmm_vector::permutexvar(rev_index, index_zmm[6]); - index_t index_zmm7r = zmm_vector::permutexvar(rev_index, index_zmm[7]); - + index_t index_zmm4r + = zmm_vector::permutexvar(rev_index, index_zmm[4]); + index_t index_zmm5r + = zmm_vector::permutexvar(rev_index, index_zmm[5]); + index_t index_zmm6r + = zmm_vector::permutexvar(rev_index, index_zmm[6]); + index_t index_zmm7r + = zmm_vector::permutexvar(rev_index, index_zmm[7]); zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm7r); zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm6r); @@ -602,56 +652,63 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm,inde zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm7r); zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm6r); zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm5r); - zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); - - index_t index_zmm_t1= zmm_vector::mask_mov(index_zmm7r, vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm[0]); - index_t index_zmm_m1= zmm_vector::mask_mov(index_zmm[0], vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm7r); - index_t index_zmm_t2= zmm_vector::mask_mov(index_zmm6r, vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm[1]); - index_t index_zmm_m2= zmm_vector::mask_mov(index_zmm[1], vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm6r); - index_t index_zmm_t3= zmm_vector::mask_mov(index_zmm5r, vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm[2]); - index_t index_zmm_m3= zmm_vector::mask_mov(index_zmm[2], vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm5r); - index_t index_zmm_t4= zmm_vector::mask_mov(index_zmm4r, vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm[3]); - index_t index_zmm_m4= zmm_vector::mask_mov(index_zmm[3], vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm4r); - - + zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); + + index_t index_zmm_t1 = zmm_vector::mask_mov( 
+ index_zmm7r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_t index_zmm_m1 = zmm_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); + index_t index_zmm_t2 = zmm_vector::mask_mov( + index_zmm6r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_t index_zmm_m2 = zmm_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); + index_t index_zmm_t3 = zmm_vector::mask_mov( + index_zmm5r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_t index_zmm_m3 = zmm_vector::mask_mov( + index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); + index_t index_zmm_t4 = zmm_vector::mask_mov( + index_zmm4r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_t index_zmm_m4 = zmm_vector::mask_mov( + index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); zmm_t key_zmm_t5 = vtype::permutexvar(rev_index, key_zmm_m4); zmm_t key_zmm_t6 = vtype::permutexvar(rev_index, key_zmm_m3); zmm_t key_zmm_t7 = vtype::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t8 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t5 = zmm_vector::permutexvar(rev_index, index_zmm_m4); - index_t index_zmm_t6 = zmm_vector::permutexvar(rev_index, index_zmm_m3); - index_t index_zmm_t7 = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t8 = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - - COEX(key_zmm_t1, key_zmm_t3,index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4,index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7,index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8,index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t1, key_zmm_t2,index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4,index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6,index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8,index_zmm_t7, index_zmm_t8); - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1,index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2,index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3,index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4,index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5,index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6,index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7,index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8,index_zmm_t8); - - index_zmm[0]=index_zmm_t1; - index_zmm[1]=index_zmm_t2; - index_zmm[2]=index_zmm_t3; - index_zmm[3]=index_zmm_t4; - index_zmm[4]=index_zmm_t5; - index_zmm[5]=index_zmm_t6; - index_zmm[6]=index_zmm_t7; - index_zmm[7]=index_zmm_t8; - - + index_t index_zmm_t5 + = zmm_vector::permutexvar(rev_index, index_zmm_m4); + index_t index_zmm_t6 + = zmm_vector::permutexvar(rev_index, index_zmm_m3); + index_t index_zmm_t7 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t8 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + 
key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + + index_zmm[0] = index_zmm_t1; + index_zmm[1] = index_zmm_t2; + index_zmm[2] = index_zmm_t3; + index_zmm[3] = index_zmm_t4; + index_zmm[4] = index_zmm_t5; + index_zmm[5] = index_zmm_t6; + index_zmm[6] = index_zmm_t7; + index_zmm[7] = index_zmm_t8; } template X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) @@ -726,8 +783,11 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); } -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm,index_t *index_zmm) +template ::zmm_t> +X86_SIMD_SORT_FORCEINLINE void +bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, index_t *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); zmm_t key_zmm8r = vtype::permutexvar(rev_index, key_zmm[8]); @@ -739,15 +799,23 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm,in zmm_t key_zmm14r = vtype::permutexvar(rev_index, key_zmm[14]); zmm_t key_zmm15r = vtype::permutexvar(rev_index, key_zmm[15]); - index_t index_zmm8r = zmm_vector::permutexvar(rev_index, index_zmm[8]); - index_t index_zmm9r = zmm_vector::permutexvar(rev_index, index_zmm[9]); - index_t index_zmm10r = zmm_vector::permutexvar(rev_index, index_zmm[10]); - index_t index_zmm11r = zmm_vector::permutexvar(rev_index, index_zmm[11]); - index_t index_zmm12r = zmm_vector::permutexvar(rev_index, index_zmm[12]); - index_t index_zmm13r = zmm_vector::permutexvar(rev_index, index_zmm[13]); - index_t index_zmm14r = zmm_vector::permutexvar(rev_index, index_zmm[14]); - index_t index_zmm15r = zmm_vector::permutexvar(rev_index, index_zmm[15]); - + index_t index_zmm8r + = zmm_vector::permutexvar(rev_index, index_zmm[8]); + index_t index_zmm9r + = zmm_vector::permutexvar(rev_index, index_zmm[9]); + index_t index_zmm10r + = zmm_vector::permutexvar(rev_index, index_zmm[10]); + index_t index_zmm11r + = zmm_vector::permutexvar(rev_index, index_zmm[11]); + index_t index_zmm12r + = zmm_vector::permutexvar(rev_index, index_zmm[12]); + index_t index_zmm13r + = zmm_vector::permutexvar(rev_index, index_zmm[13]); + index_t index_zmm14r + = zmm_vector::permutexvar(rev_index, index_zmm[14]); + index_t index_zmm15r + = zmm_vector::permutexvar(rev_index, index_zmm[15]); + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm15r); zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm14r); zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm13r); @@ -764,26 +832,41 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm,in zmm_t key_zmm_m5 = vtype::max(key_zmm[4], key_zmm11r); zmm_t key_zmm_m6 = vtype::max(key_zmm[5], key_zmm10r); zmm_t key_zmm_m7 = vtype::max(key_zmm[6], key_zmm9r); - zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); - - index_t index_zmm_t1=zmm_vector::mask_mov(index_zmm15r,vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm[0]); - index_t index_zmm_m1=zmm_vector::mask_mov(index_zmm[0],vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm15r); - index_t index_zmm_t2=zmm_vector::mask_mov(index_zmm14r,vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm[1]); - index_t 
index_zmm_m2=zmm_vector::mask_mov(index_zmm[1],vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm14r); - index_t index_zmm_t3=zmm_vector::mask_mov(index_zmm13r,vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm[2]); - index_t index_zmm_m3=zmm_vector::mask_mov(index_zmm[2],vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm13r); - index_t index_zmm_t4=zmm_vector::mask_mov(index_zmm12r,vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm[3]); - index_t index_zmm_m4=zmm_vector::mask_mov(index_zmm[3],vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm12r); - - index_t index_zmm_t5=zmm_vector::mask_mov(index_zmm11r,vtype::eq(key_zmm_t5,key_zmm[4]),index_zmm[4]); - index_t index_zmm_m5=zmm_vector::mask_mov(index_zmm[4],vtype::eq(key_zmm_t5,key_zmm[4]),index_zmm11r); - index_t index_zmm_t6=zmm_vector::mask_mov(index_zmm10r,vtype::eq(key_zmm_t6,key_zmm[5]),index_zmm[5]); - index_t index_zmm_m6=zmm_vector::mask_mov(index_zmm[5],vtype::eq(key_zmm_t6,key_zmm[5]),index_zmm10r); - index_t index_zmm_t7=zmm_vector::mask_mov(index_zmm9r,vtype::eq(key_zmm_t7,key_zmm[6]),index_zmm[6]); - index_t index_zmm_m7=zmm_vector::mask_mov(index_zmm[6],vtype::eq(key_zmm_t7,key_zmm[6]),index_zmm9r); - index_t index_zmm_t8=zmm_vector::mask_mov(index_zmm8r,vtype::eq(key_zmm_t8,key_zmm[7]),index_zmm[7]); - index_t index_zmm_m8=zmm_vector::mask_mov(index_zmm[7],vtype::eq(key_zmm_t8,key_zmm[7]),index_zmm8r); - + zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); + + index_t index_zmm_t1 = zmm_vector::mask_mov( + index_zmm15r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_t index_zmm_m1 = zmm_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); + index_t index_zmm_t2 = zmm_vector::mask_mov( + index_zmm14r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_t index_zmm_m2 = zmm_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); + index_t index_zmm_t3 = zmm_vector::mask_mov( + index_zmm13r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_t index_zmm_m3 = zmm_vector::mask_mov( + index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); + index_t index_zmm_t4 = zmm_vector::mask_mov( + index_zmm12r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_t index_zmm_m4 = zmm_vector::mask_mov( + index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); + + index_t index_zmm_t5 = zmm_vector::mask_mov( + index_zmm11r, vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); + index_t index_zmm_m5 = zmm_vector::mask_mov( + index_zmm[4], vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); + index_t index_zmm_t6 = zmm_vector::mask_mov( + index_zmm10r, vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); + index_t index_zmm_m6 = zmm_vector::mask_mov( + index_zmm[5], vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); + index_t index_zmm_t7 = zmm_vector::mask_mov( + index_zmm9r, vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); + index_t index_zmm_m7 = zmm_vector::mask_mov( + index_zmm[6], vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); + index_t index_zmm_t8 = zmm_vector::mask_mov( + index_zmm8r, vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); + index_t index_zmm_m8 = zmm_vector::mask_mov( + index_zmm[7], vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); zmm_t key_zmm_t9 = vtype::permutexvar(rev_index, key_zmm_m8); zmm_t key_zmm_t10 = vtype::permutexvar(rev_index, key_zmm_m7); @@ -793,98 +876,110 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm,in zmm_t key_zmm_t14 = vtype::permutexvar(rev_index, key_zmm_m3); zmm_t key_zmm_t15 = 
vtype::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t16 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t9 = zmm_vector::permutexvar(rev_index, index_zmm_m8); - index_t index_zmm_t10 = zmm_vector::permutexvar(rev_index, index_zmm_m7); - index_t index_zmm_t11 = zmm_vector::permutexvar(rev_index, index_zmm_m6); - index_t index_zmm_t12 = zmm_vector::permutexvar(rev_index, index_zmm_m5); - index_t index_zmm_t13 = zmm_vector::permutexvar(rev_index, index_zmm_m4); - index_t index_zmm_t14 = zmm_vector::permutexvar(rev_index, index_zmm_m3); - index_t index_zmm_t15 = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t16 = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t5,index_zmm_t1, index_zmm_t5); - COEX(key_zmm_t2, key_zmm_t6,index_zmm_t2, index_zmm_t6); - COEX(key_zmm_t3, key_zmm_t7,index_zmm_t3, index_zmm_t7); - COEX(key_zmm_t4, key_zmm_t8,index_zmm_t4, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t13,index_zmm_t9, index_zmm_t13); - COEX(key_zmm_t10, key_zmm_t14,index_zmm_t10, index_zmm_t14); - COEX(key_zmm_t11, key_zmm_t15,index_zmm_t11, index_zmm_t15); - COEX(key_zmm_t12, key_zmm_t16,index_zmm_t12, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t3,index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4,index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7,index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8,index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t11,index_zmm_t9, index_zmm_t11); - COEX(key_zmm_t10, key_zmm_t12,index_zmm_t10, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t15,index_zmm_t13, index_zmm_t15); - COEX(key_zmm_t14, key_zmm_t16,index_zmm_t14, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t2,index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4,index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6,index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8,index_zmm_t7, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t10,index_zmm_t9, index_zmm_t10); - COEX(key_zmm_t11, key_zmm_t12,index_zmm_t11, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t14,index_zmm_t13, index_zmm_t14); - COEX(key_zmm_t15, key_zmm_t16,index_zmm_t15, index_zmm_t16); + index_t index_zmm_t9 + = zmm_vector::permutexvar(rev_index, index_zmm_m8); + index_t index_zmm_t10 + = zmm_vector::permutexvar(rev_index, index_zmm_m7); + index_t index_zmm_t11 + = zmm_vector::permutexvar(rev_index, index_zmm_m6); + index_t index_zmm_t12 + = zmm_vector::permutexvar(rev_index, index_zmm_m5); + index_t index_zmm_t13 + = zmm_vector::permutexvar(rev_index, index_zmm_m4); + index_t index_zmm_t14 + = zmm_vector::permutexvar(rev_index, index_zmm_m3); + index_t index_zmm_t15 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t16 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); + COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); + COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); + COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); + COEX(key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); + COEX(key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); + COEX(key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + 
COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); + COEX(key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); + COEX(key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); + COEX(key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); + COEX(key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); // - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1,index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2,index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3,index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4,index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5,index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6,index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7,index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8,index_zmm_t8); - key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9,index_zmm_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10,index_zmm_t10); - key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11,index_zmm_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12,index_zmm_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13,index_zmm_t13); - key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14,index_zmm_t14); - key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15,index_zmm_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16,index_zmm_t16); - - index_zmm[0]=index_zmm_t1; - index_zmm[1]=index_zmm_t2; - index_zmm[2]=index_zmm_t3; - index_zmm[3]=index_zmm_t4; - index_zmm[4]=index_zmm_t5; - index_zmm[5]=index_zmm_t6; - index_zmm[6]=index_zmm_t7; - index_zmm[7]=index_zmm_t8; - index_zmm[8]=index_zmm_t9; - index_zmm[9]=index_zmm_t10; - index_zmm[10]=index_zmm_t11; - index_zmm[11]=index_zmm_t12; - index_zmm[12]=index_zmm_t13; - index_zmm[13]=index_zmm_t14; - index_zmm[14]=index_zmm_t15; - index_zmm[15]=index_zmm_t16; + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, index_zmm_t10); + key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, index_zmm_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, index_zmm_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, index_zmm_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, index_zmm_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, index_zmm_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, index_zmm_t16); + + index_zmm[0] = index_zmm_t1; + index_zmm[1] = index_zmm_t2; + index_zmm[2] = index_zmm_t3; + index_zmm[3] = index_zmm_t4; + index_zmm[4] = index_zmm_t5; + index_zmm[5] = index_zmm_t6; 
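
The pattern repeated through these merge networks is: take min/max of two key vectors, then pick the matching index lanes with mask_mov, using vtype::eq(min_result, original_key) as the mask. A rough scalar model of one such compare-exchange on a (key, index) pair follows; it is illustrative only and the helper name is not part of the patch:

    #include <cstdint>
    #include <utility>

    // Scalar picture of the vector COEX on key/index pairs: order the keys
    // so that key1 <= key2 and move the indexes the same way. The AVX-512
    // code does this lane-wise with vtype::min/max on the keys and
    // zmm_vector<uint64_t>::mask_mov selected by vtype::eq(min, key1).
    template <typename K>
    static void coex_scalar(K &key1, K &key2, uint64_t &idx1, uint64_t &idx2)
    {
        if (key2 < key1) {
            std::swap(key1, key2);
            std::swap(idx1, idx2);
        }
        // Equal keys need no move: eq(min, key1) is then true in every lane,
        // so each index stays with its own key.
    }
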
+ index_zmm[6] = index_zmm_t7; + index_zmm[7] = index_zmm_t8; + index_zmm[8] = index_zmm_t9; + index_zmm[9] = index_zmm_t10; + index_zmm[10] = index_zmm_t11; + index_zmm[11] = index_zmm_t12; + index_zmm[12] = index_zmm_t13; + index_zmm[13] = index_zmm_t14; + index_zmm[14] = index_zmm_t15; + index_zmm[15] = index_zmm_t16; } template -X86_SIMD_SORT_FORCEINLINE void sort_8_64bit(type_t *keys,uint64_t *indexes, int32_t N) +X86_SIMD_SORT_FORCEINLINE void +sort_8_64bit(type_t *keys, uint64_t *indexes, int32_t N) { typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; typename vtype::zmm_t key_zmm = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys); - if(indexes){ + if (indexes) { zmm_vector::zmm_t index_zmm - = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask, indexes); - vtype::mask_storeu(keys, load_mask, sort_zmm_64bit(key_zmm,index_zmm)); - zmm_vector::mask_storeu(indexes, load_mask,index_zmm); - }else{ + = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask, indexes); + vtype::mask_storeu( + keys, load_mask, sort_zmm_64bit(key_zmm, index_zmm)); + zmm_vector::mask_storeu(indexes, load_mask, index_zmm); + } + else { vtype::mask_storeu(keys, load_mask, sort_zmm_64bit(key_zmm)); - } - + } } template -X86_SIMD_SORT_FORCEINLINE void sort_16_64bit(type_t *keys,uint64_t *indexes, int32_t N) +X86_SIMD_SORT_FORCEINLINE void +sort_16_64bit(type_t *keys, uint64_t *indexes, int32_t N) { if (N <= 8) { - sort_8_64bit(keys,indexes, N); + sort_8_64bit(keys, indexes, N); return; } using zmm_t = typename vtype::zmm_t; @@ -893,26 +988,30 @@ X86_SIMD_SORT_FORCEINLINE void sort_16_64bit(type_t *keys,uint64_t *indexes, int typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; zmm_t key_zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys + 8); - if(indexes){ + if (indexes) { index_t index_zmm1 = zmm_vector::loadu(indexes); - index_t index_zmm2 = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask, indexes + 8); - key_zmm1 = sort_zmm_64bit(key_zmm1,index_zmm1); - key_zmm2 = sort_zmm_64bit(key_zmm2,index_zmm2); - bitonic_merge_two_zmm_64bit(key_zmm1,key_zmm2,index_zmm1,index_zmm2); + index_t index_zmm2 = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask, indexes + 8); + key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); + key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); + bitonic_merge_two_zmm_64bit( + key_zmm1, key_zmm2, index_zmm1, index_zmm2); zmm_vector::storeu(indexes, index_zmm1); zmm_vector::mask_storeu(indexes + 8, load_mask, index_zmm2); - }else{ + } + else { key_zmm1 = sort_zmm_64bit(key_zmm1); key_zmm2 = sort_zmm_64bit(key_zmm2); - bitonic_merge_two_zmm_64bit(key_zmm1,key_zmm2); + bitonic_merge_two_zmm_64bit(key_zmm1, key_zmm2); } - + vtype::storeu(keys, key_zmm1); vtype::mask_storeu(keys + 8, load_mask, key_zmm2); } template -X86_SIMD_SORT_FORCEINLINE void sort_32_64bit(type_t *keys,uint64_t *indexes, int32_t N) +X86_SIMD_SORT_FORCEINLINE void +sort_32_64bit(type_t *keys, uint64_t *indexes, int32_t N) { if (N <= 16) { sort_16_64bit(keys, indexes, N); @@ -923,15 +1022,16 @@ X86_SIMD_SORT_FORCEINLINE void sort_32_64bit(type_t *keys,uint64_t *indexes, int using index_t = zmm_vector::zmm_t; zmm_t key_zmm[4]; index_t index_zmm[4]; - + key_zmm[0] = vtype::loadu(keys); key_zmm[1] = vtype::loadu(keys + 8); - if(indexes){ + if (indexes) { index_zmm[0] = zmm_vector::loadu(indexes); index_zmm[1] = zmm_vector::loadu(indexes + 8); - key_zmm[0] = sort_zmm_64bit(key_zmm[0],index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1],index_zmm[1]); - }else{ + key_zmm[0] = 
sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + } + else { key_zmm[0] = sort_zmm_64bit(key_zmm[0]); key_zmm[1] = sort_zmm_64bit(key_zmm[1]); } @@ -942,24 +1042,31 @@ X86_SIMD_SORT_FORCEINLINE void sort_32_64bit(type_t *keys,uint64_t *indexes, int key_zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 16); key_zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 24); - if(indexes){ - index_zmm[2] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask1, indexes + 16); - index_zmm[3] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask2, indexes + 24); - key_zmm[2] = sort_zmm_64bit(key_zmm[2],index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3],index_zmm[3]); - bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1],index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3],index_zmm[2], index_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm,index_zmm); + if (indexes) { + index_zmm[2] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 16); + index_zmm[3] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 24); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); zmm_vector::storeu(indexes, index_zmm[0]); zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); - zmm_vector::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); - }else{ + zmm_vector::mask_storeu( + indexes + 16, load_mask1, index_zmm[2]); + zmm_vector::mask_storeu( + indexes + 24, load_mask2, index_zmm[3]); + } + else { key_zmm[2] = sort_zmm_64bit(key_zmm[2]); key_zmm[3] = sort_zmm_64bit(key_zmm[3]); bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1]); bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm); + bitonic_merge_four_zmm_64bit(key_zmm); } vtype::storeu(keys, key_zmm[0]); vtype::storeu(keys + 8, key_zmm[1]); @@ -968,10 +1075,11 @@ X86_SIMD_SORT_FORCEINLINE void sort_32_64bit(type_t *keys,uint64_t *indexes, int } template -X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) +X86_SIMD_SORT_FORCEINLINE void +sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) { if (N <= 32) { - sort_32_64bit(keys,indexes, N); + sort_32_64bit(keys, indexes, N); return; } using zmm_t = typename vtype::zmm_t; @@ -979,21 +1087,22 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, in using index_t = zmm_vector::zmm_t; zmm_t key_zmm[8]; index_t index_zmm[8]; - + key_zmm[0] = vtype::loadu(keys); key_zmm[1] = vtype::loadu(keys + 8); key_zmm[2] = vtype::loadu(keys + 16); key_zmm[3] = vtype::loadu(keys + 24); - if(indexes){ + if (indexes) { index_zmm[0] = zmm_vector::loadu(indexes); index_zmm[1] = zmm_vector::loadu(indexes + 8); index_zmm[2] = zmm_vector::loadu(indexes + 16); index_zmm[3] = zmm_vector::loadu(indexes + 24); - key_zmm[0] = sort_zmm_64bit(key_zmm[0],index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1],index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2],index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3],index_zmm[3]); - }else{ + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], 
index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + } + else { key_zmm[0] = sort_zmm_64bit(key_zmm[0]); key_zmm[1] = sort_zmm_64bit(key_zmm[1]); key_zmm[2] = sort_zmm_64bit(key_zmm[2]); @@ -1012,31 +1121,44 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, in key_zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 48); key_zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 56); - if(indexes){ - index_zmm[4] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask1, indexes + 32); - index_zmm[5] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask2, indexes + 40); - index_zmm[6] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask3, indexes + 48); - index_zmm[7] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask4, indexes + 56); - key_zmm[4] = sort_zmm_64bit(key_zmm[4],index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5],index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6],index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7],index_zmm[7]); - bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1],index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3],index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit(key_zmm[4], key_zmm[5],index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7],index_zmm[6], index_zmm[7]); - bitonic_merge_four_zmm_64bit(key_zmm,index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4,index_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm,index_zmm); + if (indexes) { + index_zmm[4] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 32); + index_zmm[5] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 40); + index_zmm[6] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask3, indexes + 48); + index_zmm[7] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask4, indexes + 56); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit( + key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); zmm_vector::storeu(indexes, index_zmm[0]); zmm_vector::storeu(indexes + 8, index_zmm[1]); zmm_vector::storeu(indexes + 16, index_zmm[2]); zmm_vector::storeu(indexes + 24, index_zmm[3]); - zmm_vector::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); - zmm_vector::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); - zmm_vector::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); - zmm_vector::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); - }else{ + zmm_vector::mask_storeu( + indexes + 32, load_mask1, index_zmm[4]); + zmm_vector::mask_storeu( + indexes + 40, load_mask2, index_zmm[5]); + zmm_vector::mask_storeu( + indexes + 48, load_mask3, index_zmm[6]); + zmm_vector::mask_storeu( + indexes + 56, load_mask4, index_zmm[7]); + } + else { key_zmm[4] = sort_zmm_64bit(key_zmm[4]); key_zmm[5] = 
sort_zmm_64bit(key_zmm[5]); key_zmm[6] = sort_zmm_64bit(key_zmm[6]); @@ -1047,7 +1169,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, in bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7]); bitonic_merge_four_zmm_64bit(key_zmm); bitonic_merge_four_zmm_64bit(key_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm); } vtype::storeu(keys, key_zmm[0]); @@ -1061,10 +1183,11 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, in } template -X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) +X86_SIMD_SORT_FORCEINLINE void +sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) { if (N <= 64) { - sort_64_64bit(keys,indexes, N); + sort_64_64bit(keys, indexes, N); return; } using zmm_t = typename vtype::zmm_t; @@ -1072,7 +1195,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i using opmask_t = typename vtype::opmask_t; zmm_t key_zmm[16]; index_t index_zmm[16]; - + key_zmm[0] = vtype::loadu(keys); key_zmm[1] = vtype::loadu(keys + 8); key_zmm[2] = vtype::loadu(keys + 16); @@ -1081,7 +1204,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i key_zmm[5] = vtype::loadu(keys + 40); key_zmm[6] = vtype::loadu(keys + 48); key_zmm[7] = vtype::loadu(keys + 56); - if(indexes!=NULL){ + if (indexes != NULL) { index_zmm[0] = zmm_vector::loadu(indexes); index_zmm[1] = zmm_vector::loadu(indexes + 8); index_zmm[2] = zmm_vector::loadu(indexes + 16); @@ -1090,15 +1213,16 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i index_zmm[5] = zmm_vector::loadu(indexes + 40); index_zmm[6] = zmm_vector::loadu(indexes + 48); index_zmm[7] = zmm_vector::loadu(indexes + 56); - key_zmm[0] = sort_zmm_64bit(key_zmm[0],index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1],index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2],index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3],index_zmm[3]); - key_zmm[4] = sort_zmm_64bit(key_zmm[4],index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5],index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6],index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7],index_zmm[7]); - }else{ + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + } + else { key_zmm[0] = sort_zmm_64bit(key_zmm[0]); key_zmm[1] = sort_zmm_64bit(key_zmm[1]); key_zmm[2] = sort_zmm_64bit(key_zmm[2]); @@ -1133,39 +1257,55 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i key_zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, keys + 112); key_zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, keys + 120); - if(indexes){ - index_zmm[8] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask1, indexes + 64); - index_zmm[9] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask2, indexes + 72); - index_zmm[10] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask3, indexes + 80); - index_zmm[11] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask4, indexes + 88); - index_zmm[12] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), 
load_mask5, indexes + 96); - index_zmm[13] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask6, indexes + 104); - index_zmm[14] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask7, indexes + 112); - index_zmm[15] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask8, indexes + 120); - key_zmm[8] = sort_zmm_64bit(key_zmm[8],index_zmm[8]); - key_zmm[9] = sort_zmm_64bit(key_zmm[9],index_zmm[9]); - key_zmm[10] = sort_zmm_64bit(key_zmm[10],index_zmm[10]); - key_zmm[11] = sort_zmm_64bit(key_zmm[11],index_zmm[11]); - key_zmm[12] = sort_zmm_64bit(key_zmm[12],index_zmm[12]); - key_zmm[13] = sort_zmm_64bit(key_zmm[13],index_zmm[13]); - key_zmm[14] = sort_zmm_64bit(key_zmm[14],index_zmm[14]); - key_zmm[15] = sort_zmm_64bit(key_zmm[15],index_zmm[15]); - - bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1],index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3],index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit(key_zmm[4], key_zmm[5],index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7],index_zmm[6], index_zmm[7]); - bitonic_merge_two_zmm_64bit(key_zmm[8], key_zmm[9],index_zmm[8], index_zmm[9]); - bitonic_merge_two_zmm_64bit(key_zmm[10], key_zmm[11],index_zmm[10], index_zmm[11]); - bitonic_merge_two_zmm_64bit(key_zmm[12], key_zmm[13],index_zmm[12], index_zmm[13]); - bitonic_merge_two_zmm_64bit(key_zmm[14], key_zmm[15],index_zmm[14], index_zmm[15]); - bitonic_merge_four_zmm_64bit(key_zmm,index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4,index_zmm + 4); - bitonic_merge_four_zmm_64bit(key_zmm + 8,index_zmm + 8); - bitonic_merge_four_zmm_64bit(key_zmm + 12,index_zmm + 12); - bitonic_merge_eight_zmm_64bit(key_zmm,index_zmm); - bitonic_merge_eight_zmm_64bit(key_zmm + 8,index_zmm + 8); - bitonic_merge_sixteen_zmm_64bit(key_zmm,index_zmm); + if (indexes) { + index_zmm[8] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 64); + index_zmm[9] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 72); + index_zmm[10] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask3, indexes + 80); + index_zmm[11] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask4, indexes + 88); + index_zmm[12] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask5, indexes + 96); + index_zmm[13] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask6, indexes + 104); + index_zmm[14] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask7, indexes + 112); + index_zmm[15] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask8, indexes + 120); + key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); + key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); + key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); + key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); + key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); + key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); + key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); + key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit( + key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); + bitonic_merge_two_zmm_64bit( + key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); + bitonic_merge_two_zmm_64bit( + 
key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); + bitonic_merge_two_zmm_64bit( + key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); + bitonic_merge_two_zmm_64bit( + key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); zmm_vector::storeu(indexes, index_zmm[0]); zmm_vector::storeu(indexes + 8, index_zmm[1]); zmm_vector::storeu(indexes + 16, index_zmm[2]); @@ -1174,15 +1314,24 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i zmm_vector::storeu(indexes + 40, index_zmm[5]); zmm_vector::storeu(indexes + 48, index_zmm[6]); zmm_vector::storeu(indexes + 56, index_zmm[7]); - zmm_vector::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); - zmm_vector::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); - zmm_vector::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); - zmm_vector::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); - zmm_vector::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); - zmm_vector::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); - zmm_vector::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); - zmm_vector::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); - }else{ + zmm_vector::mask_storeu( + indexes + 64, load_mask1, index_zmm[8]); + zmm_vector::mask_storeu( + indexes + 72, load_mask2, index_zmm[9]); + zmm_vector::mask_storeu( + indexes + 80, load_mask3, index_zmm[10]); + zmm_vector::mask_storeu( + indexes + 88, load_mask4, index_zmm[11]); + zmm_vector::mask_storeu( + indexes + 96, load_mask5, index_zmm[12]); + zmm_vector::mask_storeu( + indexes + 104, load_mask6, index_zmm[13]); + zmm_vector::mask_storeu( + indexes + 112, load_mask7, index_zmm[14]); + zmm_vector::mask_storeu( + indexes + 120, load_mask8, index_zmm[15]); + } + else { key_zmm[8] = sort_zmm_64bit(key_zmm[8]); key_zmm[9] = sort_zmm_64bit(key_zmm[9]); key_zmm[10] = sort_zmm_64bit(key_zmm[10]); @@ -1223,11 +1372,11 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i vtype::mask_storeu(keys + 104, load_mask6, key_zmm[13]); vtype::mask_storeu(keys + 112, load_mask7, key_zmm[14]); vtype::mask_storeu(keys + 120, load_mask8, key_zmm[15]); - } template -X86_SIMD_SORT_FORCEINLINE type_t get_pivot_64bit(type_t *keys, uint64_t *indexes, +X86_SIMD_SORT_FORCEINLINE type_t get_pivot_64bit(type_t *keys, + uint64_t *indexes, const int64_t left, const int64_t right) { @@ -1247,85 +1396,88 @@ X86_SIMD_SORT_FORCEINLINE type_t get_pivot_64bit(type_t *keys, uint64_t *indexes index_t index_vec; zmm_t sort; - if(indexes) - { - index_vec=zmm_vector::template i64gather(rand_index, indexes); - sort = sort_zmm_64bit(key_vec,index_vec); - }else{ - //index_vec=vtype::template i64gather(rand_index, indexes); - sort = sort_zmm_64bit(key_vec); + if (indexes) { + index_vec = zmm_vector::template i64gather( + rand_index, indexes); + sort = sort_zmm_64bit(key_vec, index_vec); + } + else { + //index_vec=vtype::template i64gather(rand_index, indexes); + sort = sort_zmm_64bit(key_vec); } // pivot will never be a nan, since there are no nan's! 
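
In scalar terms, get_pivot_64bit gathers eight candidate keys from the range, sorts that one register, and reads back the fifth element as the pivot (the index gather is only needed when an index array is being carried along). The sketch below shows the same idea without intrinsics; the helper name and the evenly spaced sampling are assumptions for illustration, since the construction of rand_index is outside this hunk:

    #include <algorithm>
    #include <array>
    #include <cstdint>

    // Median-of-8 pivot selection, scalar version. The patch gathers the
    // sample with i64gather, sorts it with sort_zmm_64bit, and returns
    // ((type_t *)&sort)[4]; this loop mimics that on plain arrays.
    template <typename T>
    static T pivot_from_sample(const T *keys, int64_t left, int64_t right)
    {
        std::array<T, 8> sample;
        int64_t step = (right - left) / 8 + 1; // assumed spacing, illustrative
        for (int i = 0; i < 8; ++i)
            sample[i] = keys[std::min(left + i * step, right)];
        std::sort(sample.begin(), sample.end());
        return sample[4]; // same slot the vector code reads back
    }

Because NaNs are replaced with inf before sorting doubles, the sampled pivot can never be a NaN, which is what the comment above relies on.
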
- + return ((type_t *)&sort)[4]; } template -inline void -heapify(type_t* keys, uint64_t* indexes, int64_t idx, int64_t size) +inline void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) { - int64_t i = idx; - while(true) { - int64_t j = 2 * i + 1; - if (j >= size || j < 0) { - break; - } - int k = j + 1; - if (k < size && keys[j] < keys[k]) { - j = k; - } - if (keys[j] < keys[i]) { - break; - } - std::swap(keys[i], keys[j]); - std::swap(indexes[i], indexes[j]); - i = j; - } + int64_t i = idx; + while (true) { + int64_t j = 2 * i + 1; + if (j >= size || j < 0) { break; } + int k = j + 1; + if (k < size && keys[j] < keys[k]) { j = k; } + if (keys[j] < keys[i]) { break; } + std::swap(keys[i], keys[j]); + std::swap(indexes[i], indexes[j]); + i = j; + } } template -inline void -heap_sort(type_t* keys, uint64_t* indexes, int64_t size) +inline void heap_sort(type_t *keys, uint64_t *indexes, int64_t size) { - for (int64_t i = size / 2 - 1; i >= 0; i--) { - heapify(keys, indexes, i, size); - } - for (int64_t i = size - 1; i > 0; i--) { - std::swap(keys[0], keys[i]); - std::swap(indexes[0], indexes[i]); - heapify(keys, indexes, 0, i); - } + for (int64_t i = size / 2 - 1; i >= 0; i--) { + heapify(keys, indexes, i, size); + } + for (int64_t i = size - 1; i > 0; i--) { + std::swap(keys[0], keys[i]); + std::swap(indexes[0], indexes[i]); + heapify(keys, indexes, 0, i); + } } template -inline void -qsort_64bit_(type_t *keys,uint64_t *indexes, int64_t left, int64_t right, int64_t max_iters) +inline void qsort_64bit_(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress */ if (max_iters <= 0) { - if(indexes)heap_sort(keys + left, indexes + left, right - left + 1); - else std::sort(keys + left, keys + right + 1); + if (indexes) + heap_sort(keys + left, indexes + left, right - left + 1); + else + std::sort(keys + left, keys + right + 1); return; } /* * Base case: use bitonic networks to sort arrays <= 128 */ if (right + 1 - left <= 128) { - if(indexes) sort_128_64bit(keys + left, indexes + left, (int32_t)(right + 1 - left)); - else sort_128_64bit(keys + left, (uint64_t*)NULL, (int32_t)(right + 1 - left)); + if (indexes) + sort_128_64bit( + keys + left, indexes + left, (int32_t)(right + 1 - left)); + else + sort_128_64bit( + keys + left, (uint64_t *)NULL, (int32_t)(right + 1 - left)); return; } - type_t pivot = get_pivot_64bit(keys, indexes,left, right); + type_t pivot = get_pivot_64bit(keys, indexes, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); int64_t pivot_index = partition_avx512( keys, indexes, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) - qsort_64bit_(keys,indexes, left, pivot_index - 1, max_iters - 1); + qsort_64bit_( + keys, indexes, left, pivot_index - 1, max_iters - 1); if (pivot != biggest) - qsort_64bit_(keys,indexes, pivot_index, right, max_iters - 1); + qsort_64bit_(keys, indexes, pivot_index, right, max_iters - 1); } X86_SIMD_SORT_FORCEINLINE int64_t replace_nan_with_inf(double *arr, @@ -1355,30 +1507,41 @@ replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) } template <> -inline void avx512_qsort(int64_t *keys, uint64_t *indexes, int64_t arrsize) +inline void +avx512_qsort(int64_t *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_64bit_, int64_t>( - keys,indexes, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize))); + keys, + indexes, + 0, + arrsize - 1, + 2 * (63 - 
__builtin_clzll(arrsize))); } } template <> -inline void avx512_qsort(uint64_t *keys,uint64_t *indexes, int64_t arrsize) +inline void +avx512_qsort(uint64_t *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_64bit_, uint64_t>( - keys,indexes, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize))); + keys, + indexes, + 0, + arrsize - 1, + 2 * (63 - __builtin_clzll(arrsize))); } } template <> -inline void avx512_qsort(double *keys,uint64_t *indexes, int64_t arrsize) +inline void +avx512_qsort(double *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf(keys, arrsize); qsort_64bit_, double>( - keys,indexes, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize))); + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); replace_inf_with_nan(keys, arrsize, nan_count); } } diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 09d5c861..79918dbc 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -71,7 +71,6 @@ template struct zmm_vector; - template inline void avx512_qsort(T *keys, uint64_t *indexes, int64_t arrsize); @@ -87,19 +86,24 @@ static void COEX(mm_t &a, mm_t &b) a = vtype::min(a, b); b = vtype::max(temp, b); } -template > -static void COEX(mm_t &key1, mm_t &key2,index_t &index1, index_t &index2) +template > +static void COEX(mm_t &key1, mm_t &key2, index_t &index1, index_t &index2) { //COEX(key1,key2); - mm_t key_t1=vtype::min(key1,key2); - mm_t key_t2=vtype::max(key1,key2); + mm_t key_t1 = vtype::min(key1, key2); + mm_t key_t2 = vtype::max(key1, key2); - index_t index_t1=index_type::mask_mov(index2,vtype::eq(key_t1,key1),index1); - index_t index_t2=index_type::mask_mov(index1,vtype::eq(key_t1,key1),index2); - - key1=key_t1;key2=key_t2; - index1=index_t1;index2=index_t2; + index_t index_t1 + = index_type::mask_mov(index2, vtype::eq(key_t1, key1), index1); + index_t index_t2 + = index_type::mask_mov(index1, vtype::eq(key_t1, key1), index2); + key1 = key_t1; + key2 = key_t2; + index1 = index_t1; + index2 = index_t2; } template > -static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2,index_t & indexes1,index_t indexes2, opmask_t mask) + typename index_type = zmm_vector> +static inline zmm_t cmp_merge(zmm_t in1, + zmm_t in2, + index_t &indexes1, + index_t indexes2, + opmask_t mask) { - zmm_t tmp_keys=cmp_merge(in1,in2,mask); - indexes1=index_type::mask_mov(indexes2,vtype::eq(tmp_keys, in1),indexes1); + zmm_t tmp_keys = cmp_merge(in1, in2, mask); + indexes1 = index_type::mask_mov( + indexes2, vtype::eq(tmp_keys, in1), indexes1); return tmp_keys; // 0 -> min, 1 -> max } /* @@ -144,7 +153,10 @@ static inline int32_t partition_vec(type_t *arr, *biggest_vec = vtype::max(curr_vec, *biggest_vec); return amount_gt_pivot; } -template > +template > static inline int32_t partition_vec(type_t *keys, uint64_t *indexes, int64_t left, @@ -163,7 +175,7 @@ static inline int32_t partition_vec(type_t *keys, vtype::mask_compressstoreu( keys + right - amount_gt_pivot, gt_mask, keys_vec); index_type::mask_compressstoreu( - indexes + left, index_type::knot_opmask(gt_mask), indexes_vec); + indexes + left, index_type::knot_opmask(gt_mask), indexes_vec); index_type::mask_compressstoreu( indexes + right - amount_gt_pivot, gt_mask, indexes_vec); *smallest_vec = vtype::min(keys_vec, *smallest_vec); @@ -274,7 +286,9 @@ static inline int64_t partition_avx512(type_t *arr, return l_store; } -template > +template > static inline int64_t partition_avx512(type_t *keys, uint64_t *indexes, int64_t left, @@ -287,10 
+301,10 @@ static inline int64_t partition_avx512(type_t *keys, for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { *smallest = std::min(*smallest, keys[left]); *biggest = std::max(*biggest, keys[left]); - if (keys[left] > pivot) { - right--; - std::swap(keys[left], keys[right]); - if(indexes) std::swap(indexes[left], indexes[right]); + if (keys[left] > pivot) { + right--; + std::swap(keys[left], keys[right]); + if (indexes) std::swap(indexes[left], indexes[right]); } else { ++left; @@ -304,29 +318,30 @@ static inline int64_t partition_avx512(type_t *keys, zmm_t pivot_vec = vtype::set1(pivot); zmm_t min_vec = vtype::set1(*smallest); zmm_t max_vec = vtype::set1(*biggest); - + if (right - left == vtype::numlanes) { zmm_t keys_vec = vtype::loadu(keys + left); int32_t amount_gt_pivot; - if(indexes) { - index_t indexes_vec = index_type::loadu(indexes + left); - amount_gt_pivot = partition_vec(keys, - indexes, - left, - left + vtype::numlanes, - keys_vec, - indexes_vec, - pivot_vec, - &min_vec, - &max_vec); - }else{ - amount_gt_pivot = partition_vec(keys, - left, - left + vtype::numlanes, - keys_vec, - pivot_vec, - &min_vec, - &max_vec); + if (indexes) { + index_t indexes_vec = index_type::loadu(indexes + left); + amount_gt_pivot = partition_vec(keys, + indexes, + left, + left + vtype::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + } + else { + amount_gt_pivot = partition_vec(keys, + left, + left + vtype::numlanes, + keys_vec, + pivot_vec, + &min_vec, + &max_vec); } *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); @@ -337,13 +352,13 @@ static inline int64_t partition_avx512(type_t *keys, zmm_t keys_vec_left = vtype::loadu(keys + left); zmm_t keys_vec_right = vtype::loadu(keys + (right - vtype::numlanes)); index_t indexes_vec_left; - index_t indexes_vec_right; - if(indexes){ - indexes_vec_left = index_type::loadu(indexes + left); - indexes_vec_right = index_type::loadu(indexes + (right - vtype::numlanes)); + index_t indexes_vec_right; + if (indexes) { + indexes_vec_left = index_type::loadu(indexes + left); + indexes_vec_right + = index_type::loadu(indexes + (right - vtype::numlanes)); } - // store points of the vectors int64_t r_store = right - vtype::numlanes; int64_t l_store = left; @@ -361,32 +376,33 @@ static inline int64_t partition_avx512(type_t *keys, if ((r_store + vtype::numlanes) - right < left - l_store) { right -= vtype::numlanes; keys_vec = vtype::loadu(keys + right); - if(indexes) indexes_vec = index_type::loadu(indexes + right); + if (indexes) indexes_vec = index_type::loadu(indexes + right); } else { keys_vec = vtype::loadu(keys + left); - if(indexes) indexes_vec = index_type::loadu(indexes + left); + if (indexes) indexes_vec = index_type::loadu(indexes + left); left += vtype::numlanes; } // partition the current vector and save it on both sides of the array int32_t amount_gt_pivot; - if(indexes) - amount_gt_pivot= partition_vec(keys, - indexes, - l_store, - r_store + vtype::numlanes, - keys_vec, - indexes_vec, - pivot_vec, - &min_vec, - &max_vec); - else amount_gt_pivot= partition_vec(keys, - l_store, - r_store + vtype::numlanes, - keys_vec, - pivot_vec, - &min_vec, - &max_vec); + if (indexes) + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + r_store + vtype::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + else + amount_gt_pivot = partition_vec(keys, + l_store, + r_store + vtype::numlanes, + keys_vec, + pivot_vec, + &min_vec, + &max_vec); r_store -= amount_gt_pivot; 
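            // Editor's note (not part of the original patch): the key-value
            // partition_vec() called above compress-stores the lanes whose key is
            // greater than the pivot at the right store cursor and the remaining
            // lanes at the left store cursor, reusing the same comparison mask for
            // the index vector so every key stays paired with its index. The two
            // cursor updates around this note therefore move by complementary
            // amounts: r_store retreats by the number of greater-than-pivot lanes
            // and l_store advances by the rest, growing the two partitions toward
            // each other until the loop terminates.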
l_store += (vtype::numlanes - amount_gt_pivot); @@ -394,43 +410,44 @@ static inline int64_t partition_avx512(type_t *keys, /* partition and save vec_left and vec_right */ int32_t amount_gt_pivot; - if(indexes){ + if (indexes) { amount_gt_pivot = partition_vec(keys, - indexes, - l_store, - r_store + vtype::numlanes, - keys_vec_left, - indexes_vec_left, - pivot_vec, - &min_vec, - &max_vec); + indexes, + l_store, + r_store + vtype::numlanes, + keys_vec_left, + indexes_vec_left, + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); amount_gt_pivot = partition_vec(keys, - indexes, - l_store, - l_store + vtype::numlanes, - keys_vec_right, - indexes_vec_right, - pivot_vec, - &min_vec, - &max_vec); - }else{ + indexes, + l_store, + l_store + vtype::numlanes, + keys_vec_right, + indexes_vec_right, + pivot_vec, + &min_vec, + &max_vec); + } + else { amount_gt_pivot = partition_vec(keys, - l_store, - r_store + vtype::numlanes, - keys_vec_left, - pivot_vec, - &min_vec, - &max_vec); + l_store, + r_store + vtype::numlanes, + keys_vec_left, + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); amount_gt_pivot = partition_vec(keys, - l_store, - l_store + vtype::numlanes, - keys_vec_right, - pivot_vec, - &min_vec, - &max_vec); - } + l_store, + l_store + vtype::numlanes, + keys_vec_right, + pivot_vec, + &min_vec, + &max_vec); + } l_store += (vtype::numlanes - amount_gt_pivot); *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); diff --git a/tests/meson.build b/tests/meson.build index 7d51ba26..40cd4685 100644 --- a/tests/meson.build +++ b/tests/meson.build @@ -1,19 +1,15 @@ libtests = [] -if cc.has_argument('-march=icelake-client') - libtests += static_library( - 'tests_', - files( - 'test_all.cpp', - ), - dependencies : gtest_dep, - include_directories : [ - src, - utils, - ], - cpp_args : [ - '-O3', - '-march=icelake-client', - ], - ) -endif + if cc.has_argument('-march=icelake-client') libtests + += static_library('tests_', files('test_all.cpp', ), dependencies + : gtest_dep, include_directories + : + [ + src, + utils, + ], + cpp_args + : [ + '-O3', + '-march=icelake-client', + ], ) endif diff --git a/tests/test_all.cpp b/tests/test_all.cpp index 6309c7bc..90219d10 100644 --- a/tests/test_all.cpp +++ b/tests/test_all.cpp @@ -34,7 +34,7 @@ TYPED_TEST_P(avx512_sort, test_arrsizes) sortedarr = arr; /* Sort with std::sort for comparison */ std::sort(sortedarr.begin(), sortedarr.end()); - avx512_qsort(arr.data(),NULL, arr.size()); + avx512_qsort(arr.data(), NULL, arr.size()); ASSERT_EQ(sortedarr, arr); arr.clear(); sortedarr.clear(); @@ -56,3 +56,47 @@ using Types = testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512_sort, Types); + +struct sorted_t { + uint64_t key; + uint64_t value; +}; + +bool compare(sorted_t a, sorted_t b) +{ + return a.key == b.key ? 
a.value < b.value : a.key < b.key; +} +TEST(TestKeyValueSort, KeyValueSort) +{ + std::vector keysizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + keysizes.push_back((uint64_t)ii); + } + std::vector keys; + std::vector values; + std::vector sortedarr; + + for (size_t ii = 0; ii < keysizes.size(); ++ii) { + /* Random array */ + keys = get_uniform_rand_array_key(keysizes[ii]); + //keys = get_uniform_rand_array(keysizes[ii]); + values = get_uniform_rand_array(keysizes[ii]); + for (size_t i = 0; i < keys.size(); i++) { + sorted_t tmp_s; + tmp_s.key = keys[i]; + tmp_s.value = values[i]; + sortedarr.emplace_back(tmp_s); + } + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), sortedarr.end(), compare); + avx512_qsort(keys.data(), values.data(), keys.size()); + //ASSERT_EQ(sortedarr, arr); + for (size_t i = 0; i < keys.size(); i++) { + ASSERT_EQ(keys[i], sortedarr[i].key); + ASSERT_EQ(values[i], sortedarr[i].value); + } + keys.clear(); + values.clear(); + sortedarr.clear(); + } +} diff --git a/utils/rand_array.h b/utils/rand_array.h index 0842a0b4..42e0f99d 100644 --- a/utils/rand_array.h +++ b/utils/rand_array.h @@ -3,6 +3,7 @@ * * SPDX-License-Identifier: BSD-3-Clause * *******************************************/ +#include #include #include #include @@ -33,10 +34,34 @@ static std::vector get_uniform_rand_array( { std::random_device rd; std::mt19937 gen(rd()); - std::uniform_real_distribution<> dis(min, max); + std::uniform_real_distribution dis(min, max); std::vector arr; + //std::cout< +get_uniform_rand_array_key(int64_t arrsize, + uint64_t max = std::numeric_limits::max(), + uint64_t min = std::numeric_limits::min()) +{ + std::vector arr; + std::random_device r; + std::default_random_engine e1(r()); + std::uniform_int_distribution uniform_dist(min, max); + for (int64_t ii = 0; ii < arrsize; ++ii) { + + while (true) { + uint64_t tmp = uniform_dist(e1); + auto iter = std::find(arr.begin(), arr.end(), tmp); + if (iter == arr.end()) { + arr.emplace_back(tmp); + break; + } + } + } + return arr; +} From 7821e013c9ea7b614c7279dcf249c94a3b4951ef Mon Sep 17 00:00:00 2001 From: ruclz Date: Mon, 13 Feb 2023 14:25:48 +0800 Subject: [PATCH 04/16] run clang-format and add unit test for key-value sort --- .vscode/settings.json | 54 ------------------------------------------- 1 file changed, 54 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index f1aaf8a2..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,54 +0,0 @@ -{ - "files.associations": { - "cctype": "cpp", - "clocale": "cpp", - "cmath": "cpp", - "cstdarg": "cpp", - "cstddef": "cpp", - "cstdio": "cpp", - "cstdlib": "cpp", - "cstring": "cpp", - "ctime": "cpp", - "cwchar": "cpp", - "cwctype": "cpp", - "array": "cpp", - "atomic": "cpp", - "bit": "cpp", - "*.tcc": "cpp", - "cstdint": "cpp", - "deque": "cpp", - "map": "cpp", - "set": "cpp", - "unordered_map": "cpp", - "vector": "cpp", - "exception": "cpp", - "algorithm": "cpp", - "functional": "cpp", - "iterator": "cpp", - "memory": "cpp", - "memory_resource": "cpp", - "numeric": "cpp", - "optional": "cpp", - "random": "cpp", - "string": "cpp", - "string_view": "cpp", - "system_error": "cpp", - "tuple": "cpp", - "type_traits": "cpp", - "utility": "cpp", - "fstream": "cpp", - "initializer_list": "cpp", - "iomanip": "cpp", - "iosfwd": "cpp", - "iostream": "cpp", - "istream": "cpp", - "limits": "cpp", - "new": "cpp", - "ostream": "cpp", - "sstream": "cpp", - "stdexcept": "cpp", - "streambuf": 
"cpp", - "cinttypes": "cpp", - "typeinfo": "cpp" - } -} \ No newline at end of file From ae02b58116d454cc0bdaf041a63eb49b411c7d5c Mon Sep 17 00:00:00 2001 From: ruclz Date: Mon, 13 Feb 2023 14:41:25 +0800 Subject: [PATCH 05/16] add unit test and run clang_format --- .gitignore | 35 ++ .vscode/settings.json | 57 --- Makefile | 14 +- src/avx512-64bit-qsort.hpp | 930 +++++++++++++++++++++---------------- src/avx512-common-qsort.h | 209 +++++---- tests/test_all.cpp | 46 +- utils/rand_array.h | 27 +- 7 files changed, 766 insertions(+), 552 deletions(-) create mode 100644 .gitignore delete mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..669f6732 --- /dev/null +++ b/.gitignore @@ -0,0 +1,35 @@ +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +**/.vscode + diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 5abdc8c3..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,57 +0,0 @@ -{ - "files.associations": { - "*.tcc": "cpp", - "functional": "cpp", - "string_view": "cpp", - "random": "cpp", - "istream": "cpp", - "limits": "cpp", - "algorithm": "cpp", - "bit": "cpp", - "numeric": "cpp", - "cctype": "cpp", - "clocale": "cpp", - "cmath": "cpp", - "cstdarg": "cpp", - "cstddef": "cpp", - "cstdio": "cpp", - "cstdlib": "cpp", - "cstring": "cpp", - "ctime": "cpp", - "cwchar": "cpp", - "cwctype": "cpp", - "array": "cpp", - "atomic": "cpp", - "cstdint": "cpp", - "deque": "cpp", - "unordered_map": "cpp", - "vector": "cpp", - "exception": "cpp", - "iterator": "cpp", - "memory": "cpp", - "memory_resource": "cpp", - "optional": "cpp", - "string": "cpp", - "system_error": "cpp", - "tuple": "cpp", - "type_traits": "cpp", - "utility": "cpp", - "fstream": "cpp", - "initializer_list": "cpp", - "iomanip": "cpp", - "iosfwd": "cpp", - "iostream": "cpp", - "new": "cpp", - "ostream": "cpp", - "sstream": "cpp", - "stdexcept": "cpp", - "streambuf": "cpp", - "cinttypes": "cpp", - "typeinfo": "cpp", - "compare": "cpp", - "concepts": "cpp", - "numbers": "cpp", - "map": "cpp", - "set": "cpp" - } -} \ No newline at end of file diff --git a/Makefile b/Makefile index 04bf5d78..899463a3 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,5 @@ -CXX ?= g++ -SRCDIR = ./src -TESTDIR = ./tests -BENCHDIR = ./benchmarks -UTILS = ./utils -SRCS = $(wildcard $(SRCDIR)/*.hpp) +CXX ? 
= g++ SRCDIR =./ src TESTDIR =./ tests BENCHDIR =./ benchmarks UTILS + =./ utils SRCS = $(wildcard $(SRCDIR)/*.hpp) TESTS = $(wildcard $(TESTDIR)/*.cpp) TESTOBJS = $(patsubst $(TESTDIR)/%.cpp,$(TESTDIR)/%.o,$(TESTS)) TESTOBJS := $(filter-out $(TESTDIR)/main.o ,$(TESTOBJS)) @@ -15,13 +11,13 @@ LD_FLAGS = -L /usr/local/lib -l $(GTEST_LIB) -l pthread all : test bench $(TESTDIR)/%.o : $(TESTDIR)/%.cpp $(SRCS) - $(CXX) -march=icelake-client -O3 $(CXXFLAGS) -l $(GTEST_LIB) -c $< -o $@ + $(CXX) -march=icelake-client -O3 $(CXXFLAGS) -c $< -o $@ test: $(TESTDIR)/main.cpp $(TESTOBJS) $(SRCS) - $(CXX) tests/main.cpp $(TESTOBJS) $(CXXFLAGS) $(LD_FLAGS) -o testexe + $(CXX) tests/main.cpp $(TESTOBJS) $(CXXFLAGS) $(LD_FLAGS) -o testexe bench: $(BENCHDIR)/main.cpp $(SRCS) $(CXX) $(BENCHDIR)/main.cpp $(CXXFLAGS) -march=icelake-client -O3 -o benchexe clean: - rm -f $(TESTDIR)/*.o testexe benchexe + rm -f $(TESTDIR)/*.o testexe benchexe \ No newline at end of file diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index b9fc061e..e05a54b3 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -364,33 +364,52 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) return zmm; } -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_64bit(zmm_t key_zmm,index_t &index_zmm) +template ::zmm_t> +X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_t &index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); key_zmm = cmp_merge( - key_zmm, vtype::template shuffle(key_zmm), - index_zmm, zmm_vector::template shuffle(index_zmm), + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), 0xAA); key_zmm = cmp_merge( - key_zmm,vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), - index_zmm,zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), index_zmm), + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), + index_zmm, + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), + index_zmm), 0xCC); key_zmm = cmp_merge( - key_zmm, vtype::template shuffle(key_zmm), - index_zmm, zmm_vector::template shuffle(index_zmm), + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), 0xAA); key_zmm = cmp_merge( - key_zmm, vtype::permutexvar(rev_index, key_zmm), - index_zmm,zmm_vector::permutexvar(rev_index, index_zmm), + key_zmm, + vtype::permutexvar(rev_index, key_zmm), + index_zmm, + zmm_vector::permutexvar(rev_index, index_zmm), 0xF0); key_zmm = cmp_merge( - key_zmm,vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), - index_zmm,zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + index_zmm, + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), + index_zmm), 0xCC); key_zmm = cmp_merge( - key_zmm, vtype::template shuffle(key_zmm), - index_zmm, zmm_vector::template shuffle(index_zmm), + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), 0xAA); return key_zmm; } @@ -415,8 +434,11 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) return zmm; } // Assumes zmm is bitonic and performs a recursive half cleaner -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm,zmm_vector::zmm_t &index_zmm) +template ::zmm_t> +X86_SIMD_SORT_INLINE zmm_t +bitonic_merge_zmm_64bit(zmm_t key_zmm, zmm_vector::zmm_t &index_zmm) { // 1) 
half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 @@ -424,19 +446,24 @@ X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm,zmm_vector key_zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), index_zmm), + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), + index_zmm), 0xF0); // 2) half_cleaner[4] key_zmm = cmp_merge( key_zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), + index_zmm), 0xCC); // 3) half_cleaner[1] key_zmm = cmp_merge( - key_zmm, vtype::template shuffle(key_zmm), - index_zmm, zmm_vector::template shuffle(index_zmm), + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_vector::template shuffle( + index_zmm), 0xAA); return key_zmm; } @@ -454,28 +481,32 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1, zmm_t &zmm2) zmm2 = bitonic_merge_zmm_64bit(zmm4); } // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, - zmm_t &key_zmm2, index_t &index_zmm1, - index_t &index_zmm2) +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, + zmm_t &key_zmm2, + index_t &index_zmm1, + index_t &index_zmm2) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); // 1) First step of a merging network: coex of zmm1 and zmm2 reversed key_zmm2 = vtype::permutexvar(rev_index, key_zmm2); index_zmm2 = zmm_vector::permutexvar(rev_index, index_zmm2); - zmm_t key_zmm3 = vtype::min(key_zmm1, key_zmm2); zmm_t key_zmm4 = vtype::max(key_zmm1, key_zmm2); - index_t index_zmm3=zmm_vector::mask_mov(index_zmm2,vtype::eq(key_zmm3,key_zmm1),index_zmm1); - index_t index_zmm4=zmm_vector::mask_mov(index_zmm1,vtype::eq(key_zmm3,key_zmm1),index_zmm2); + index_t index_zmm3 = zmm_vector::mask_mov( + index_zmm2, vtype::eq(key_zmm3, key_zmm1), index_zmm1); + index_t index_zmm4 = zmm_vector::mask_mov( + index_zmm1, vtype::eq(key_zmm3, key_zmm1), index_zmm2); // 2) Recursive half cleaner for each - key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3,index_zmm3); - key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4,index_zmm4); - index_zmm1=index_zmm3; - index_zmm2=index_zmm4; + key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); + index_zmm1 = index_zmm3; + index_zmm2 = index_zmm4; } // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive // half cleaner @@ -500,54 +531,66 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *zmm) zmm[2] = bitonic_merge_zmm_64bit(zmm2); zmm[3] = bitonic_merge_zmm_64bit(zmm3); } -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,index_t *index_zmm) +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, + index_t *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); // 1) First step of a merging network zmm_t key_zmm2r = vtype::permutexvar(rev_index, key_zmm[2]); zmm_t key_zmm3r = vtype::permutexvar(rev_index, key_zmm[3]); - index_t index_zmm2r = zmm_vector::permutexvar(rev_index, index_zmm[2]); - index_t index_zmm3r = zmm_vector::permutexvar(rev_index, index_zmm[3]); - + index_t index_zmm2r + = zmm_vector::permutexvar(rev_index, index_zmm[2]); + index_t 
index_zmm3r + = zmm_vector::permutexvar(rev_index, index_zmm[3]); zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm3r); zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm2r); zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm3r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); - - index_t index_zmm_t1=zmm_vector::mask_mov(index_zmm3r,vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm[0]); - index_t index_zmm_m1=zmm_vector::mask_mov(index_zmm[0],vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm3r); - index_t index_zmm_t2=zmm_vector::mask_mov(index_zmm2r,vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm[1]); - index_t index_zmm_m2=zmm_vector::mask_mov(index_zmm[1],vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm2r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); + index_t index_zmm_t1 = zmm_vector::mask_mov( + index_zmm3r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_t index_zmm_m1 = zmm_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); + index_t index_zmm_t2 = zmm_vector::mask_mov( + index_zmm2r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_t index_zmm_m2 = zmm_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); // 2) Recursive half clearer: 16 zmm_t key_zmm_t3 = vtype::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t4 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t3 = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t4 = zmm_vector::permutexvar(rev_index, index_zmm_m1); + index_t index_zmm_t3 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t4 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); zmm_t key_zmm0 = vtype::min(key_zmm_t1, key_zmm_t2); zmm_t key_zmm1 = vtype::max(key_zmm_t1, key_zmm_t2); zmm_t key_zmm2 = vtype::min(key_zmm_t3, key_zmm_t4); zmm_t key_zmm3 = vtype::max(key_zmm_t3, key_zmm_t4); - index_t index_zmm0 = zmm_vector::mask_mov(index_zmm_t2,vtype::eq(key_zmm0,key_zmm_t1),index_zmm_t1); - index_t index_zmm1 = zmm_vector::mask_mov(index_zmm_t1,vtype::eq(key_zmm0,key_zmm_t1),index_zmm_t2); - index_t index_zmm2 = zmm_vector::mask_mov(index_zmm_t4,vtype::eq(key_zmm2,key_zmm_t3),index_zmm_t3); - index_t index_zmm3 = zmm_vector::mask_mov(index_zmm_t3,vtype::eq(key_zmm2,key_zmm_t3),index_zmm_t4); - - - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0,index_zmm0); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1,index_zmm1); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2,index_zmm2); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3,index_zmm3); - - index_zmm[0]=index_zmm0; - index_zmm[1]=index_zmm1; - index_zmm[2]=index_zmm2; - index_zmm[3]=index_zmm3; + index_t index_zmm0 = zmm_vector::mask_mov( + index_zmm_t2, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t1); + index_t index_zmm1 = zmm_vector::mask_mov( + index_zmm_t1, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t2); + index_t index_zmm2 = zmm_vector::mask_mov( + index_zmm_t4, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t3); + index_t index_zmm3 = zmm_vector::mask_mov( + index_zmm_t3, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t4); + + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + + index_zmm[0] = index_zmm0; + index_zmm[1] = index_zmm1; + index_zmm[2] = index_zmm2; + index_zmm[3] = index_zmm3; } template X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) @@ -582,19 +625,25 @@ 
X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); } -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm,index_t *index_zmm) +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, + index_t *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); zmm_t key_zmm4r = vtype::permutexvar(rev_index, key_zmm[4]); zmm_t key_zmm5r = vtype::permutexvar(rev_index, key_zmm[5]); zmm_t key_zmm6r = vtype::permutexvar(rev_index, key_zmm[6]); zmm_t key_zmm7r = vtype::permutexvar(rev_index, key_zmm[7]); - index_t index_zmm4r = zmm_vector::permutexvar(rev_index, index_zmm[4]); - index_t index_zmm5r = zmm_vector::permutexvar(rev_index, index_zmm[5]); - index_t index_zmm6r = zmm_vector::permutexvar(rev_index, index_zmm[6]); - index_t index_zmm7r = zmm_vector::permutexvar(rev_index, index_zmm[7]); - + index_t index_zmm4r + = zmm_vector::permutexvar(rev_index, index_zmm[4]); + index_t index_zmm5r + = zmm_vector::permutexvar(rev_index, index_zmm[5]); + index_t index_zmm6r + = zmm_vector::permutexvar(rev_index, index_zmm[6]); + index_t index_zmm7r + = zmm_vector::permutexvar(rev_index, index_zmm[7]); zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm7r); zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm6r); @@ -604,56 +653,63 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm,inde zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm7r); zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm6r); zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm5r); - zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); - - index_t index_zmm_t1= zmm_vector::mask_mov(index_zmm7r, vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm[0]); - index_t index_zmm_m1= zmm_vector::mask_mov(index_zmm[0], vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm7r); - index_t index_zmm_t2= zmm_vector::mask_mov(index_zmm6r, vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm[1]); - index_t index_zmm_m2= zmm_vector::mask_mov(index_zmm[1], vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm6r); - index_t index_zmm_t3= zmm_vector::mask_mov(index_zmm5r, vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm[2]); - index_t index_zmm_m3= zmm_vector::mask_mov(index_zmm[2], vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm5r); - index_t index_zmm_t4= zmm_vector::mask_mov(index_zmm4r, vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm[3]); - index_t index_zmm_m4= zmm_vector::mask_mov(index_zmm[3], vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm4r); - - + zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); + + index_t index_zmm_t1 = zmm_vector::mask_mov( + index_zmm7r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_t index_zmm_m1 = zmm_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); + index_t index_zmm_t2 = zmm_vector::mask_mov( + index_zmm6r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_t index_zmm_m2 = zmm_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); + index_t index_zmm_t3 = zmm_vector::mask_mov( + index_zmm5r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_t index_zmm_m3 = zmm_vector::mask_mov( + index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); + index_t index_zmm_t4 = zmm_vector::mask_mov( + index_zmm4r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_t index_zmm_m4 = zmm_vector::mask_mov( + index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); zmm_t key_zmm_t5 = 
vtype::permutexvar(rev_index, key_zmm_m4); zmm_t key_zmm_t6 = vtype::permutexvar(rev_index, key_zmm_m3); zmm_t key_zmm_t7 = vtype::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t8 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t5 = zmm_vector::permutexvar(rev_index, index_zmm_m4); - index_t index_zmm_t6 = zmm_vector::permutexvar(rev_index, index_zmm_m3); - index_t index_zmm_t7 = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t8 = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - - COEX(key_zmm_t1, key_zmm_t3,index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4,index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7,index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8,index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t1, key_zmm_t2,index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4,index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6,index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8,index_zmm_t7, index_zmm_t8); - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1,index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2,index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3,index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4,index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5,index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6,index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7,index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8,index_zmm_t8); - - index_zmm[0]=index_zmm_t1; - index_zmm[1]=index_zmm_t2; - index_zmm[2]=index_zmm_t3; - index_zmm[3]=index_zmm_t4; - index_zmm[4]=index_zmm_t5; - index_zmm[5]=index_zmm_t6; - index_zmm[6]=index_zmm_t7; - index_zmm[7]=index_zmm_t8; - - + index_t index_zmm_t5 + = zmm_vector::permutexvar(rev_index, index_zmm_m4); + index_t index_zmm_t6 + = zmm_vector::permutexvar(rev_index, index_zmm_m3); + index_t index_zmm_t7 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t8 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + + index_zmm[0] = index_zmm_t1; + index_zmm[1] = index_zmm_t2; + index_zmm[2] = index_zmm_t3; + index_zmm[3] = index_zmm_t4; + index_zmm[4] = index_zmm_t5; + index_zmm[5] = index_zmm_t6; + index_zmm[6] = index_zmm_t7; + index_zmm[7] = index_zmm_t8; } template X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) @@ -728,8 +784,11 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); zmm[15] 
= bitonic_merge_zmm_64bit(zmm_t16); } -template ::zmm_t> -X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm,index_t *index_zmm) +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, + index_t *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); zmm_t key_zmm8r = vtype::permutexvar(rev_index, key_zmm[8]); @@ -741,15 +800,23 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm,in zmm_t key_zmm14r = vtype::permutexvar(rev_index, key_zmm[14]); zmm_t key_zmm15r = vtype::permutexvar(rev_index, key_zmm[15]); - index_t index_zmm8r = zmm_vector::permutexvar(rev_index, index_zmm[8]); - index_t index_zmm9r = zmm_vector::permutexvar(rev_index, index_zmm[9]); - index_t index_zmm10r = zmm_vector::permutexvar(rev_index, index_zmm[10]); - index_t index_zmm11r = zmm_vector::permutexvar(rev_index, index_zmm[11]); - index_t index_zmm12r = zmm_vector::permutexvar(rev_index, index_zmm[12]); - index_t index_zmm13r = zmm_vector::permutexvar(rev_index, index_zmm[13]); - index_t index_zmm14r = zmm_vector::permutexvar(rev_index, index_zmm[14]); - index_t index_zmm15r = zmm_vector::permutexvar(rev_index, index_zmm[15]); - + index_t index_zmm8r + = zmm_vector::permutexvar(rev_index, index_zmm[8]); + index_t index_zmm9r + = zmm_vector::permutexvar(rev_index, index_zmm[9]); + index_t index_zmm10r + = zmm_vector::permutexvar(rev_index, index_zmm[10]); + index_t index_zmm11r + = zmm_vector::permutexvar(rev_index, index_zmm[11]); + index_t index_zmm12r + = zmm_vector::permutexvar(rev_index, index_zmm[12]); + index_t index_zmm13r + = zmm_vector::permutexvar(rev_index, index_zmm[13]); + index_t index_zmm14r + = zmm_vector::permutexvar(rev_index, index_zmm[14]); + index_t index_zmm15r + = zmm_vector::permutexvar(rev_index, index_zmm[15]); + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm15r); zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm14r); zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm13r); @@ -766,26 +833,41 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm,in zmm_t key_zmm_m5 = vtype::max(key_zmm[4], key_zmm11r); zmm_t key_zmm_m6 = vtype::max(key_zmm[5], key_zmm10r); zmm_t key_zmm_m7 = vtype::max(key_zmm[6], key_zmm9r); - zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); - - index_t index_zmm_t1=zmm_vector::mask_mov(index_zmm15r,vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm[0]); - index_t index_zmm_m1=zmm_vector::mask_mov(index_zmm[0],vtype::eq(key_zmm_t1,key_zmm[0]),index_zmm15r); - index_t index_zmm_t2=zmm_vector::mask_mov(index_zmm14r,vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm[1]); - index_t index_zmm_m2=zmm_vector::mask_mov(index_zmm[1],vtype::eq(key_zmm_t2,key_zmm[1]),index_zmm14r); - index_t index_zmm_t3=zmm_vector::mask_mov(index_zmm13r,vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm[2]); - index_t index_zmm_m3=zmm_vector::mask_mov(index_zmm[2],vtype::eq(key_zmm_t3,key_zmm[2]),index_zmm13r); - index_t index_zmm_t4=zmm_vector::mask_mov(index_zmm12r,vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm[3]); - index_t index_zmm_m4=zmm_vector::mask_mov(index_zmm[3],vtype::eq(key_zmm_t4,key_zmm[3]),index_zmm12r); - - index_t index_zmm_t5=zmm_vector::mask_mov(index_zmm11r,vtype::eq(key_zmm_t5,key_zmm[4]),index_zmm[4]); - index_t index_zmm_m5=zmm_vector::mask_mov(index_zmm[4],vtype::eq(key_zmm_t5,key_zmm[4]),index_zmm11r); - index_t index_zmm_t6=zmm_vector::mask_mov(index_zmm10r,vtype::eq(key_zmm_t6,key_zmm[5]),index_zmm[5]); - index_t 
index_zmm_m6=zmm_vector::mask_mov(index_zmm[5],vtype::eq(key_zmm_t6,key_zmm[5]),index_zmm10r); - index_t index_zmm_t7=zmm_vector::mask_mov(index_zmm9r,vtype::eq(key_zmm_t7,key_zmm[6]),index_zmm[6]); - index_t index_zmm_m7=zmm_vector::mask_mov(index_zmm[6],vtype::eq(key_zmm_t7,key_zmm[6]),index_zmm9r); - index_t index_zmm_t8=zmm_vector::mask_mov(index_zmm8r,vtype::eq(key_zmm_t8,key_zmm[7]),index_zmm[7]); - index_t index_zmm_m8=zmm_vector::mask_mov(index_zmm[7],vtype::eq(key_zmm_t8,key_zmm[7]),index_zmm8r); - + zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); + + index_t index_zmm_t1 = zmm_vector::mask_mov( + index_zmm15r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_t index_zmm_m1 = zmm_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); + index_t index_zmm_t2 = zmm_vector::mask_mov( + index_zmm14r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_t index_zmm_m2 = zmm_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); + index_t index_zmm_t3 = zmm_vector::mask_mov( + index_zmm13r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_t index_zmm_m3 = zmm_vector::mask_mov( + index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); + index_t index_zmm_t4 = zmm_vector::mask_mov( + index_zmm12r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_t index_zmm_m4 = zmm_vector::mask_mov( + index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); + + index_t index_zmm_t5 = zmm_vector::mask_mov( + index_zmm11r, vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); + index_t index_zmm_m5 = zmm_vector::mask_mov( + index_zmm[4], vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); + index_t index_zmm_t6 = zmm_vector::mask_mov( + index_zmm10r, vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); + index_t index_zmm_m6 = zmm_vector::mask_mov( + index_zmm[5], vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); + index_t index_zmm_t7 = zmm_vector::mask_mov( + index_zmm9r, vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); + index_t index_zmm_m7 = zmm_vector::mask_mov( + index_zmm[6], vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); + index_t index_zmm_t8 = zmm_vector::mask_mov( + index_zmm8r, vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); + index_t index_zmm_m8 = zmm_vector::mask_mov( + index_zmm[7], vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); zmm_t key_zmm_t9 = vtype::permutexvar(rev_index, key_zmm_m8); zmm_t key_zmm_t10 = vtype::permutexvar(rev_index, key_zmm_m7); @@ -795,98 +877,110 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm,in zmm_t key_zmm_t14 = vtype::permutexvar(rev_index, key_zmm_m3); zmm_t key_zmm_t15 = vtype::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t16 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t9 = zmm_vector::permutexvar(rev_index, index_zmm_m8); - index_t index_zmm_t10 = zmm_vector::permutexvar(rev_index, index_zmm_m7); - index_t index_zmm_t11 = zmm_vector::permutexvar(rev_index, index_zmm_m6); - index_t index_zmm_t12 = zmm_vector::permutexvar(rev_index, index_zmm_m5); - index_t index_zmm_t13 = zmm_vector::permutexvar(rev_index, index_zmm_m4); - index_t index_zmm_t14 = zmm_vector::permutexvar(rev_index, index_zmm_m3); - index_t index_zmm_t15 = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t16 = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t5,index_zmm_t1, index_zmm_t5); - COEX(key_zmm_t2, key_zmm_t6,index_zmm_t2, index_zmm_t6); - COEX(key_zmm_t3, 
key_zmm_t7,index_zmm_t3, index_zmm_t7); - COEX(key_zmm_t4, key_zmm_t8,index_zmm_t4, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t13,index_zmm_t9, index_zmm_t13); - COEX(key_zmm_t10, key_zmm_t14,index_zmm_t10, index_zmm_t14); - COEX(key_zmm_t11, key_zmm_t15,index_zmm_t11, index_zmm_t15); - COEX(key_zmm_t12, key_zmm_t16,index_zmm_t12, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t3,index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4,index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7,index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8,index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t11,index_zmm_t9, index_zmm_t11); - COEX(key_zmm_t10, key_zmm_t12,index_zmm_t10, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t15,index_zmm_t13, index_zmm_t15); - COEX(key_zmm_t14, key_zmm_t16,index_zmm_t14, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t2,index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4,index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6,index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8,index_zmm_t7, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t10,index_zmm_t9, index_zmm_t10); - COEX(key_zmm_t11, key_zmm_t12,index_zmm_t11, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t14,index_zmm_t13, index_zmm_t14); - COEX(key_zmm_t15, key_zmm_t16,index_zmm_t15, index_zmm_t16); + index_t index_zmm_t9 + = zmm_vector::permutexvar(rev_index, index_zmm_m8); + index_t index_zmm_t10 + = zmm_vector::permutexvar(rev_index, index_zmm_m7); + index_t index_zmm_t11 + = zmm_vector::permutexvar(rev_index, index_zmm_m6); + index_t index_zmm_t12 + = zmm_vector::permutexvar(rev_index, index_zmm_m5); + index_t index_zmm_t13 + = zmm_vector::permutexvar(rev_index, index_zmm_m4); + index_t index_zmm_t14 + = zmm_vector::permutexvar(rev_index, index_zmm_m3); + index_t index_zmm_t15 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t16 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); + COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); + COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); + COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); + COEX(key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); + COEX(key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); + COEX(key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); + COEX(key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); + COEX(key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); + COEX(key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); + COEX(key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); // - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1,index_zmm_t1); - key_zmm[1] = 
bitonic_merge_zmm_64bit(key_zmm_t2,index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3,index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4,index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5,index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6,index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7,index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8,index_zmm_t8); - key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9,index_zmm_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10,index_zmm_t10); - key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11,index_zmm_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12,index_zmm_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13,index_zmm_t13); - key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14,index_zmm_t14); - key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15,index_zmm_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16,index_zmm_t16); - - index_zmm[0]=index_zmm_t1; - index_zmm[1]=index_zmm_t2; - index_zmm[2]=index_zmm_t3; - index_zmm[3]=index_zmm_t4; - index_zmm[4]=index_zmm_t5; - index_zmm[5]=index_zmm_t6; - index_zmm[6]=index_zmm_t7; - index_zmm[7]=index_zmm_t8; - index_zmm[8]=index_zmm_t9; - index_zmm[9]=index_zmm_t10; - index_zmm[10]=index_zmm_t11; - index_zmm[11]=index_zmm_t12; - index_zmm[12]=index_zmm_t13; - index_zmm[13]=index_zmm_t14; - index_zmm[14]=index_zmm_t15; - index_zmm[15]=index_zmm_t16; + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, index_zmm_t10); + key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, index_zmm_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, index_zmm_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, index_zmm_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, index_zmm_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, index_zmm_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, index_zmm_t16); + + index_zmm[0] = index_zmm_t1; + index_zmm[1] = index_zmm_t2; + index_zmm[2] = index_zmm_t3; + index_zmm[3] = index_zmm_t4; + index_zmm[4] = index_zmm_t5; + index_zmm[5] = index_zmm_t6; + index_zmm[6] = index_zmm_t7; + index_zmm[7] = index_zmm_t8; + index_zmm[8] = index_zmm_t9; + index_zmm[9] = index_zmm_t10; + index_zmm[10] = index_zmm_t11; + index_zmm[11] = index_zmm_t12; + index_zmm[12] = index_zmm_t13; + index_zmm[13] = index_zmm_t14; + index_zmm[14] = index_zmm_t15; + index_zmm[15] = index_zmm_t16; } template -X86_SIMD_SORT_FORCEINLINE void sort_8_64bit(type_t *keys,uint64_t *indexes, int32_t N) +X86_SIMD_SORT_INLINE void +sort_8_64bit(type_t *keys, uint64_t *indexes, int32_t N) { typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; typename vtype::zmm_t key_zmm = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys); - if(indexes){ + if (indexes) { zmm_vector::zmm_t index_zmm - = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask, indexes); - vtype::mask_storeu(keys, load_mask, 
sort_zmm_64bit(key_zmm,index_zmm)); - zmm_vector::mask_storeu(indexes, load_mask,index_zmm); - }else{ + = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask, indexes); + vtype::mask_storeu( + keys, load_mask, sort_zmm_64bit(key_zmm, index_zmm)); + zmm_vector::mask_storeu(indexes, load_mask, index_zmm); + } + else { vtype::mask_storeu(keys, load_mask, sort_zmm_64bit(key_zmm)); - } - + } } template -X86_SIMD_SORT_FORCEINLINE void sort_16_64bit(type_t *keys,uint64_t *indexes, int32_t N) +X86_SIMD_SORT_INLINE void +sort_16_64bit(type_t *keys, uint64_t *indexes, int32_t N) { if (N <= 8) { - sort_8_64bit(keys,indexes, N); + sort_8_64bit(keys, indexes, N); return; } using zmm_t = typename vtype::zmm_t; @@ -895,26 +989,30 @@ X86_SIMD_SORT_FORCEINLINE void sort_16_64bit(type_t *keys,uint64_t *indexes, int typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; zmm_t key_zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys + 8); - if(indexes){ + if (indexes) { index_t index_zmm1 = zmm_vector::loadu(indexes); - index_t index_zmm2 = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask, indexes + 8); - key_zmm1 = sort_zmm_64bit(key_zmm1,index_zmm1); - key_zmm2 = sort_zmm_64bit(key_zmm2,index_zmm2); - bitonic_merge_two_zmm_64bit(key_zmm1,key_zmm2,index_zmm1,index_zmm2); + index_t index_zmm2 = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask, indexes + 8); + key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); + key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); + bitonic_merge_two_zmm_64bit( + key_zmm1, key_zmm2, index_zmm1, index_zmm2); zmm_vector::storeu(indexes, index_zmm1); zmm_vector::mask_storeu(indexes + 8, load_mask, index_zmm2); - }else{ + } + else { key_zmm1 = sort_zmm_64bit(key_zmm1); key_zmm2 = sort_zmm_64bit(key_zmm2); - bitonic_merge_two_zmm_64bit(key_zmm1,key_zmm2); + bitonic_merge_two_zmm_64bit(key_zmm1, key_zmm2); } - + vtype::storeu(keys, key_zmm1); vtype::mask_storeu(keys + 8, load_mask, key_zmm2); } template -X86_SIMD_SORT_FORCEINLINE void sort_32_64bit(type_t *keys,uint64_t *indexes, int32_t N) +X86_SIMD_SORT_INLINE void +sort_32_64bit(type_t *keys, uint64_t *indexes, int32_t N) { if (N <= 16) { sort_16_64bit(keys, indexes, N); @@ -925,15 +1023,16 @@ X86_SIMD_SORT_FORCEINLINE void sort_32_64bit(type_t *keys,uint64_t *indexes, int using index_t = zmm_vector::zmm_t; zmm_t key_zmm[4]; index_t index_zmm[4]; - + key_zmm[0] = vtype::loadu(keys); key_zmm[1] = vtype::loadu(keys + 8); - if(indexes){ + if (indexes) { index_zmm[0] = zmm_vector::loadu(indexes); index_zmm[1] = zmm_vector::loadu(indexes + 8); - key_zmm[0] = sort_zmm_64bit(key_zmm[0],index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1],index_zmm[1]); - }else{ + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + } + else { key_zmm[0] = sort_zmm_64bit(key_zmm[0]); key_zmm[1] = sort_zmm_64bit(key_zmm[1]); } @@ -944,24 +1043,31 @@ X86_SIMD_SORT_FORCEINLINE void sort_32_64bit(type_t *keys,uint64_t *indexes, int key_zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 16); key_zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 24); - if(indexes){ - index_zmm[2] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask1, indexes + 16); - index_zmm[3] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask2, indexes + 24); - key_zmm[2] = sort_zmm_64bit(key_zmm[2],index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3],index_zmm[3]); - bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1],index_zmm[0], index_zmm[1]); - 
bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3],index_zmm[2], index_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm,index_zmm); + if (indexes) { + index_zmm[2] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 16); + index_zmm[3] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 24); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); zmm_vector::storeu(indexes, index_zmm[0]); zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); - zmm_vector::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); - }else{ + zmm_vector::mask_storeu( + indexes + 16, load_mask1, index_zmm[2]); + zmm_vector::mask_storeu( + indexes + 24, load_mask2, index_zmm[3]); + } + else { key_zmm[2] = sort_zmm_64bit(key_zmm[2]); key_zmm[3] = sort_zmm_64bit(key_zmm[3]); bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1]); bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm); + bitonic_merge_four_zmm_64bit(key_zmm); } vtype::storeu(keys, key_zmm[0]); vtype::storeu(keys + 8, key_zmm[1]); @@ -970,10 +1076,11 @@ X86_SIMD_SORT_FORCEINLINE void sort_32_64bit(type_t *keys,uint64_t *indexes, int } template -X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) +X86_SIMD_SORT_INLINE void +sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) { if (N <= 32) { - sort_32_64bit(keys,indexes, N); + sort_32_64bit(keys, indexes, N); return; } using zmm_t = typename vtype::zmm_t; @@ -981,21 +1088,22 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, in using index_t = zmm_vector::zmm_t; zmm_t key_zmm[8]; index_t index_zmm[8]; - + key_zmm[0] = vtype::loadu(keys); key_zmm[1] = vtype::loadu(keys + 8); key_zmm[2] = vtype::loadu(keys + 16); key_zmm[3] = vtype::loadu(keys + 24); - if(indexes){ + if (indexes) { index_zmm[0] = zmm_vector::loadu(indexes); index_zmm[1] = zmm_vector::loadu(indexes + 8); index_zmm[2] = zmm_vector::loadu(indexes + 16); index_zmm[3] = zmm_vector::loadu(indexes + 24); - key_zmm[0] = sort_zmm_64bit(key_zmm[0],index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1],index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2],index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3],index_zmm[3]); - }else{ + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + } + else { key_zmm[0] = sort_zmm_64bit(key_zmm[0]); key_zmm[1] = sort_zmm_64bit(key_zmm[1]); key_zmm[2] = sort_zmm_64bit(key_zmm[2]); @@ -1014,31 +1122,44 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, in key_zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 48); key_zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 56); - if(indexes){ - index_zmm[4] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask1, indexes + 32); - index_zmm[5] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask2, indexes + 40); - index_zmm[6] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask3, indexes + 48); - index_zmm[7] = 
zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask4, indexes + 56); - key_zmm[4] = sort_zmm_64bit(key_zmm[4],index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5],index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6],index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7],index_zmm[7]); - bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1],index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3],index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit(key_zmm[4], key_zmm[5],index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7],index_zmm[6], index_zmm[7]); - bitonic_merge_four_zmm_64bit(key_zmm,index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4,index_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm,index_zmm); + if (indexes) { + index_zmm[4] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 32); + index_zmm[5] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 40); + index_zmm[6] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask3, indexes + 48); + index_zmm[7] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask4, indexes + 56); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit( + key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); zmm_vector::storeu(indexes, index_zmm[0]); zmm_vector::storeu(indexes + 8, index_zmm[1]); zmm_vector::storeu(indexes + 16, index_zmm[2]); zmm_vector::storeu(indexes + 24, index_zmm[3]); - zmm_vector::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); - zmm_vector::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); - zmm_vector::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); - zmm_vector::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); - }else{ + zmm_vector::mask_storeu( + indexes + 32, load_mask1, index_zmm[4]); + zmm_vector::mask_storeu( + indexes + 40, load_mask2, index_zmm[5]); + zmm_vector::mask_storeu( + indexes + 48, load_mask3, index_zmm[6]); + zmm_vector::mask_storeu( + indexes + 56, load_mask4, index_zmm[7]); + } + else { key_zmm[4] = sort_zmm_64bit(key_zmm[4]); key_zmm[5] = sort_zmm_64bit(key_zmm[5]); key_zmm[6] = sort_zmm_64bit(key_zmm[6]); @@ -1049,7 +1170,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, in bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7]); bitonic_merge_four_zmm_64bit(key_zmm); bitonic_merge_four_zmm_64bit(key_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm); } vtype::storeu(keys, key_zmm[0]); @@ -1063,10 +1184,11 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_64bit(type_t *keys, uint64_t *indexes, in } template -X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) +X86_SIMD_SORT_INLINE void +sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) { if (N <= 64) { - sort_64_64bit(keys,indexes, N); + sort_64_64bit(keys, indexes, N); return; } using zmm_t = 
typename vtype::zmm_t; @@ -1074,7 +1196,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i using opmask_t = typename vtype::opmask_t; zmm_t key_zmm[16]; index_t index_zmm[16]; - + key_zmm[0] = vtype::loadu(keys); key_zmm[1] = vtype::loadu(keys + 8); key_zmm[2] = vtype::loadu(keys + 16); @@ -1083,7 +1205,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i key_zmm[5] = vtype::loadu(keys + 40); key_zmm[6] = vtype::loadu(keys + 48); key_zmm[7] = vtype::loadu(keys + 56); - if(indexes!=NULL){ + if (indexes != NULL) { index_zmm[0] = zmm_vector::loadu(indexes); index_zmm[1] = zmm_vector::loadu(indexes + 8); index_zmm[2] = zmm_vector::loadu(indexes + 16); @@ -1092,15 +1214,16 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i index_zmm[5] = zmm_vector::loadu(indexes + 40); index_zmm[6] = zmm_vector::loadu(indexes + 48); index_zmm[7] = zmm_vector::loadu(indexes + 56); - key_zmm[0] = sort_zmm_64bit(key_zmm[0],index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1],index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2],index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3],index_zmm[3]); - key_zmm[4] = sort_zmm_64bit(key_zmm[4],index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5],index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6],index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7],index_zmm[7]); - }else{ + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + } + else { key_zmm[0] = sort_zmm_64bit(key_zmm[0]); key_zmm[1] = sort_zmm_64bit(key_zmm[1]); key_zmm[2] = sort_zmm_64bit(key_zmm[2]); @@ -1135,39 +1258,55 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i key_zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, keys + 112); key_zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, keys + 120); - if(indexes){ - index_zmm[8] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask1, indexes + 64); - index_zmm[9] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask2, indexes + 72); - index_zmm[10] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask3, indexes + 80); - index_zmm[11] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask4, indexes + 88); - index_zmm[12] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask5, indexes + 96); - index_zmm[13] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask6, indexes + 104); - index_zmm[14] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask7, indexes + 112); - index_zmm[15] = zmm_vector::mask_loadu(zmm_vector::zmm_max(), load_mask8, indexes + 120); - key_zmm[8] = sort_zmm_64bit(key_zmm[8],index_zmm[8]); - key_zmm[9] = sort_zmm_64bit(key_zmm[9],index_zmm[9]); - key_zmm[10] = sort_zmm_64bit(key_zmm[10],index_zmm[10]); - key_zmm[11] = sort_zmm_64bit(key_zmm[11],index_zmm[11]); - key_zmm[12] = sort_zmm_64bit(key_zmm[12],index_zmm[12]); - key_zmm[13] = sort_zmm_64bit(key_zmm[13],index_zmm[13]); - key_zmm[14] = sort_zmm_64bit(key_zmm[14],index_zmm[14]); - key_zmm[15] = sort_zmm_64bit(key_zmm[15],index_zmm[15]); - - bitonic_merge_two_zmm_64bit(key_zmm[0], 
key_zmm[1],index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3],index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit(key_zmm[4], key_zmm[5],index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7],index_zmm[6], index_zmm[7]); - bitonic_merge_two_zmm_64bit(key_zmm[8], key_zmm[9],index_zmm[8], index_zmm[9]); - bitonic_merge_two_zmm_64bit(key_zmm[10], key_zmm[11],index_zmm[10], index_zmm[11]); - bitonic_merge_two_zmm_64bit(key_zmm[12], key_zmm[13],index_zmm[12], index_zmm[13]); - bitonic_merge_two_zmm_64bit(key_zmm[14], key_zmm[15],index_zmm[14], index_zmm[15]); - bitonic_merge_four_zmm_64bit(key_zmm,index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4,index_zmm + 4); - bitonic_merge_four_zmm_64bit(key_zmm + 8,index_zmm + 8); - bitonic_merge_four_zmm_64bit(key_zmm + 12,index_zmm + 12); - bitonic_merge_eight_zmm_64bit(key_zmm,index_zmm); - bitonic_merge_eight_zmm_64bit(key_zmm + 8,index_zmm + 8); - bitonic_merge_sixteen_zmm_64bit(key_zmm,index_zmm); + if (indexes) { + index_zmm[8] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 64); + index_zmm[9] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 72); + index_zmm[10] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask3, indexes + 80); + index_zmm[11] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask4, indexes + 88); + index_zmm[12] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask5, indexes + 96); + index_zmm[13] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask6, indexes + 104); + index_zmm[14] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask7, indexes + 112); + index_zmm[15] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask8, indexes + 120); + key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); + key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); + key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); + key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); + key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); + key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); + key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); + key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit( + key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); + bitonic_merge_two_zmm_64bit( + key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); + bitonic_merge_two_zmm_64bit( + key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); + bitonic_merge_two_zmm_64bit( + key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); + bitonic_merge_two_zmm_64bit( + key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); zmm_vector::storeu(indexes, index_zmm[0]); zmm_vector::storeu(indexes + 8, index_zmm[1]); zmm_vector::storeu(indexes + 16, index_zmm[2]); @@ -1176,15 +1315,24 @@ 
X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i zmm_vector::storeu(indexes + 40, index_zmm[5]); zmm_vector::storeu(indexes + 48, index_zmm[6]); zmm_vector::storeu(indexes + 56, index_zmm[7]); - zmm_vector::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); - zmm_vector::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); - zmm_vector::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); - zmm_vector::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); - zmm_vector::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); - zmm_vector::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); - zmm_vector::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); - zmm_vector::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); - }else{ + zmm_vector::mask_storeu( + indexes + 64, load_mask1, index_zmm[8]); + zmm_vector::mask_storeu( + indexes + 72, load_mask2, index_zmm[9]); + zmm_vector::mask_storeu( + indexes + 80, load_mask3, index_zmm[10]); + zmm_vector::mask_storeu( + indexes + 88, load_mask4, index_zmm[11]); + zmm_vector::mask_storeu( + indexes + 96, load_mask5, index_zmm[12]); + zmm_vector::mask_storeu( + indexes + 104, load_mask6, index_zmm[13]); + zmm_vector::mask_storeu( + indexes + 112, load_mask7, index_zmm[14]); + zmm_vector::mask_storeu( + indexes + 120, load_mask8, index_zmm[15]); + } + else { key_zmm[8] = sort_zmm_64bit(key_zmm[8]); key_zmm[9] = sort_zmm_64bit(key_zmm[9]); key_zmm[10] = sort_zmm_64bit(key_zmm[10]); @@ -1225,13 +1373,13 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_64bit(type_t *keys, uint64_t *indexes, i vtype::mask_storeu(keys + 104, load_mask6, key_zmm[13]); vtype::mask_storeu(keys + 112, load_mask7, key_zmm[14]); vtype::mask_storeu(keys + 120, load_mask8, key_zmm[15]); - } template -X86_SIMD_SORT_FORCEINLINE type_t get_pivot_64bit(type_t *keys, uint64_t *indexes, - const int64_t left, - const int64_t right) +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *keys, + uint64_t *indexes, + const int64_t left, + const int64_t right) { // median of 8 int64_t size = (right - left) / 8; @@ -1249,85 +1397,88 @@ X86_SIMD_SORT_FORCEINLINE type_t get_pivot_64bit(type_t *keys, uint64_t *indexes index_t index_vec; zmm_t sort; - if(indexes) - { - index_vec=zmm_vector::template i64gather(rand_index, indexes); - sort = sort_zmm_64bit(key_vec,index_vec); - }else{ - //index_vec=vtype::template i64gather(rand_index, indexes); - sort = sort_zmm_64bit(key_vec); + if (indexes) { + index_vec = zmm_vector::template i64gather( + rand_index, indexes); + sort = sort_zmm_64bit(key_vec, index_vec); + } + else { + //index_vec=vtype::template i64gather(rand_index, indexes); + sort = sort_zmm_64bit(key_vec); } // pivot will never be a nan, since there are no nan's! 
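// A scalar sketch of the pivot selection above: gather eight keys at evenly
// spaced offsets in [left, right], sort them, and take the upper median,
// which is the same value the vector path produces via sort_zmm_64bit and
// ((type_t *)&sort)[4]. The helper name and the exact sampling offsets are
// illustrative assumptions, not part of this patch.
#include <algorithm>
#include <cstdint>

template <typename type_t>
type_t scalar_median_of_8(const type_t *keys, int64_t left, int64_t right)
{
    int64_t size = (right - left) / 8;
    type_t samples[8];
    for (int64_t k = 0; k < 8; ++k) {
        // assumed to mirror the rand_index spacing: left + (k + 1) * size
        samples[k] = keys[left + (k + 1) * size];
    }
    std::sort(samples, samples + 8);
    return samples[4]; // upper median of the eight samples
}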
- + return ((type_t *)&sort)[4]; } template -inline void -heapify(type_t* keys, uint64_t* indexes, int64_t idx, int64_t size) +inline void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) { - int64_t i = idx; - while(true) { - int64_t j = 2 * i + 1; - if (j >= size || j < 0) { - break; - } - int k = j + 1; - if (k < size && keys[j] < keys[k]) { - j = k; - } - if (keys[j] < keys[i]) { - break; - } - std::swap(keys[i], keys[j]); - std::swap(indexes[i], indexes[j]); - i = j; - } + int64_t i = idx; + while (true) { + int64_t j = 2 * i + 1; + if (j >= size || j < 0) { break; } + int k = j + 1; + if (k < size && keys[j] < keys[k]) { j = k; } + if (keys[j] < keys[i]) { break; } + std::swap(keys[i], keys[j]); + std::swap(indexes[i], indexes[j]); + i = j; + } } template -inline void -heap_sort(type_t* keys, uint64_t* indexes, int64_t size) +inline void heap_sort(type_t *keys, uint64_t *indexes, int64_t size) { - for (int64_t i = size / 2 - 1; i >= 0; i--) { - heapify(keys, indexes, i, size); - } - for (int64_t i = size - 1; i > 0; i--) { - std::swap(keys[0], keys[i]); - std::swap(indexes[0], indexes[i]); - heapify(keys, indexes, 0, i); - } + for (int64_t i = size / 2 - 1; i >= 0; i--) { + heapify(keys, indexes, i, size); + } + for (int64_t i = size - 1; i > 0; i--) { + std::swap(keys[0], keys[i]); + std::swap(indexes[0], indexes[i]); + heapify(keys, indexes, 0, i); + } } template -inline void -qsort_64bit_(type_t *keys,uint64_t *indexes, int64_t left, int64_t right, int64_t max_iters) +inline void qsort_64bit_(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress */ if (max_iters <= 0) { - if(indexes)heap_sort(keys + left, indexes + left, right - left + 1); - else std::sort(keys + left, keys + right + 1); + if (indexes) + heap_sort(keys + left, indexes + left, right - left + 1); + else + std::sort(keys + left, keys + right + 1); return; } /* * Base case: use bitonic networks to sort arrays <= 128 */ if (right + 1 - left <= 128) { - if(indexes) sort_128_64bit(keys + left, indexes + left, (int32_t)(right + 1 - left)); - else sort_128_64bit(keys + left, (uint64_t*)NULL, (int32_t)(right + 1 - left)); + if (indexes) + sort_128_64bit( + keys + left, indexes + left, (int32_t)(right + 1 - left)); + else + sort_128_64bit( + keys + left, (uint64_t *)NULL, (int32_t)(right + 1 - left)); return; } - type_t pivot = get_pivot_64bit(keys, indexes,left, right); + type_t pivot = get_pivot_64bit(keys, indexes, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); int64_t pivot_index = partition_avx512( keys, indexes, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) - qsort_64bit_(keys,indexes, left, pivot_index - 1, max_iters - 1); + qsort_64bit_( + keys, indexes, left, pivot_index - 1, max_iters - 1); if (pivot != biggest) - qsort_64bit_(keys,indexes, pivot_index, right, max_iters - 1); + qsort_64bit_(keys, indexes, pivot_index, right, max_iters - 1); } X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize) @@ -1356,30 +1507,33 @@ replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) } template <> -inline void avx512_qsort(int64_t *keys, uint64_t *indexes, int64_t arrsize) +inline void +avx512_qsort(int64_t *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_64bit_, int64_t>( - keys,indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + keys, indexes, 0, arrsize - 1, 2 * 
(int64_t)log2(arrsize)); } } template <> -inline void avx512_qsort(uint64_t *keys,uint64_t *indexes, int64_t arrsize) +inline void +avx512_qsort(uint64_t *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { qsort_64bit_, uint64_t>( - keys,indexes, 0, arrsize - 1,2 * (int64_t)log2(arrsize)); + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } template <> -inline void avx512_qsort(double *keys,uint64_t *indexes, int64_t arrsize) +inline void +avx512_qsort(double *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf(keys, arrsize); qsort_64bit_, double>( - keys,indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); replace_inf_with_nan(keys, arrsize, nan_count); } } diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 89ac985a..99ed4076 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -84,7 +84,6 @@ template struct zmm_vector; - template inline void avx512_qsort(T *keys, uint64_t *indexes, int64_t arrsize); @@ -107,19 +106,24 @@ static void COEX(mm_t &a, mm_t &b) a = vtype::min(a, b); b = vtype::max(temp, b); } -template > -static void COEX(mm_t &key1, mm_t &key2,index_t &index1, index_t &index2) +template > +static void COEX(mm_t &key1, mm_t &key2, index_t &index1, index_t &index2) { //COEX(key1,key2); - mm_t key_t1=vtype::min(key1,key2); - mm_t key_t2=vtype::max(key1,key2); + mm_t key_t1 = vtype::min(key1, key2); + mm_t key_t2 = vtype::max(key1, key2); - index_t index_t1=index_type::mask_mov(index2,vtype::eq(key_t1,key1),index1); - index_t index_t2=index_type::mask_mov(index1,vtype::eq(key_t1,key1),index2); - - key1=key_t1;key2=key_t2; - index1=index_t1;index2=index_t2; + index_t index_t1 + = index_type::mask_mov(index2, vtype::eq(key_t1, key1), index1); + index_t index_t2 + = index_type::mask_mov(index1, vtype::eq(key_t1, key1), index2); + key1 = key_t1; + key2 = key_t2; + index1 = index_t1; + index2 = index_t2; } template > -static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2,index_t & indexes1,index_t indexes2, opmask_t mask) + typename index_type = zmm_vector> +static inline zmm_t cmp_merge(zmm_t in1, + zmm_t in2, + index_t &indexes1, + index_t indexes2, + opmask_t mask) { - zmm_t tmp_keys=cmp_merge(in1,in2,mask); - indexes1=index_type::mask_mov(indexes2,vtype::eq(tmp_keys, in1),indexes1); + zmm_t tmp_keys = cmp_merge(in1, in2, mask); + indexes1 = index_type::mask_mov( + indexes2, vtype::eq(tmp_keys, in1), indexes1); return tmp_keys; // 0 -> min, 1 -> max } /* @@ -164,7 +173,10 @@ static inline int32_t partition_vec(type_t *arr, *biggest_vec = vtype::max(curr_vec, *biggest_vec); return amount_gt_pivot; } -template > +template > static inline int32_t partition_vec(type_t *keys, uint64_t *indexes, int64_t left, @@ -183,7 +195,7 @@ static inline int32_t partition_vec(type_t *keys, vtype::mask_compressstoreu( keys + right - amount_gt_pivot, gt_mask, keys_vec); index_type::mask_compressstoreu( - indexes + left, index_type::knot_opmask(gt_mask), indexes_vec); + indexes + left, index_type::knot_opmask(gt_mask), indexes_vec); index_type::mask_compressstoreu( indexes + right - amount_gt_pivot, gt_mask, indexes_vec); *smallest_vec = vtype::min(keys_vec, *smallest_vec); @@ -296,7 +308,9 @@ static inline int64_t partition_avx512(type_t *arr, return l_store; } -template > +template > static inline int64_t partition_avx512(type_t *keys, uint64_t *indexes, int64_t left, @@ -309,10 +323,10 @@ static inline int64_t 
partition_avx512(type_t *keys, for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { *smallest = std::min(*smallest, keys[left]); *biggest = std::max(*biggest, keys[left]); - if (keys[left] > pivot) { - right--; - std::swap(keys[left], keys[right]); - if(indexes) std::swap(indexes[left], indexes[right]); + if (keys[left] > pivot) { + right--; + std::swap(keys[left], keys[right]); + if (indexes) std::swap(indexes[left], indexes[right]); } else { ++left; @@ -326,29 +340,30 @@ static inline int64_t partition_avx512(type_t *keys, zmm_t pivot_vec = vtype::set1(pivot); zmm_t min_vec = vtype::set1(*smallest); zmm_t max_vec = vtype::set1(*biggest); - + if (right - left == vtype::numlanes) { zmm_t keys_vec = vtype::loadu(keys + left); int32_t amount_gt_pivot; - if(indexes) { - index_t indexes_vec = index_type::loadu(indexes + left); - amount_gt_pivot = partition_vec(keys, - indexes, - left, - left + vtype::numlanes, - keys_vec, - indexes_vec, - pivot_vec, - &min_vec, - &max_vec); - }else{ - amount_gt_pivot = partition_vec(keys, - left, - left + vtype::numlanes, - keys_vec, - pivot_vec, - &min_vec, - &max_vec); + if (indexes) { + index_t indexes_vec = index_type::loadu(indexes + left); + amount_gt_pivot = partition_vec(keys, + indexes, + left, + left + vtype::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + } + else { + amount_gt_pivot = partition_vec(keys, + left, + left + vtype::numlanes, + keys_vec, + pivot_vec, + &min_vec, + &max_vec); } *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); @@ -359,13 +374,13 @@ static inline int64_t partition_avx512(type_t *keys, zmm_t keys_vec_left = vtype::loadu(keys + left); zmm_t keys_vec_right = vtype::loadu(keys + (right - vtype::numlanes)); index_t indexes_vec_left; - index_t indexes_vec_right; - if(indexes){ - indexes_vec_left = index_type::loadu(indexes + left); - indexes_vec_right = index_type::loadu(indexes + (right - vtype::numlanes)); + index_t indexes_vec_right; + if (indexes) { + indexes_vec_left = index_type::loadu(indexes + left); + indexes_vec_right + = index_type::loadu(indexes + (right - vtype::numlanes)); } - // store points of the vectors int64_t r_store = right - vtype::numlanes; int64_t l_store = left; @@ -383,32 +398,33 @@ static inline int64_t partition_avx512(type_t *keys, if ((r_store + vtype::numlanes) - right < left - l_store) { right -= vtype::numlanes; keys_vec = vtype::loadu(keys + right); - if(indexes) indexes_vec = index_type::loadu(indexes + right); + if (indexes) indexes_vec = index_type::loadu(indexes + right); } else { keys_vec = vtype::loadu(keys + left); - if(indexes) indexes_vec = index_type::loadu(indexes + left); + if (indexes) indexes_vec = index_type::loadu(indexes + left); left += vtype::numlanes; } // partition the current vector and save it on both sides of the array int32_t amount_gt_pivot; - if(indexes) - amount_gt_pivot= partition_vec(keys, - indexes, - l_store, - r_store + vtype::numlanes, - keys_vec, - indexes_vec, - pivot_vec, - &min_vec, - &max_vec); - else amount_gt_pivot= partition_vec(keys, - l_store, - r_store + vtype::numlanes, - keys_vec, - pivot_vec, - &min_vec, - &max_vec); + if (indexes) + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + r_store + vtype::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + else + amount_gt_pivot = partition_vec(keys, + l_store, + r_store + vtype::numlanes, + keys_vec, + pivot_vec, + &min_vec, + &max_vec); r_store -= amount_gt_pivot; l_store += (vtype::numlanes - 
amount_gt_pivot); @@ -416,43 +432,44 @@ static inline int64_t partition_avx512(type_t *keys, /* partition and save vec_left and vec_right */ int32_t amount_gt_pivot; - if(indexes){ + if (indexes) { amount_gt_pivot = partition_vec(keys, - indexes, - l_store, - r_store + vtype::numlanes, - keys_vec_left, - indexes_vec_left, - pivot_vec, - &min_vec, - &max_vec); + indexes, + l_store, + r_store + vtype::numlanes, + keys_vec_left, + indexes_vec_left, + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); amount_gt_pivot = partition_vec(keys, - indexes, - l_store, - l_store + vtype::numlanes, - keys_vec_right, - indexes_vec_right, - pivot_vec, - &min_vec, - &max_vec); - }else{ + indexes, + l_store, + l_store + vtype::numlanes, + keys_vec_right, + indexes_vec_right, + pivot_vec, + &min_vec, + &max_vec); + } + else { amount_gt_pivot = partition_vec(keys, - l_store, - r_store + vtype::numlanes, - keys_vec_left, - pivot_vec, - &min_vec, - &max_vec); + l_store, + r_store + vtype::numlanes, + keys_vec_left, + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); amount_gt_pivot = partition_vec(keys, - l_store, - l_store + vtype::numlanes, - keys_vec_right, - pivot_vec, - &min_vec, - &max_vec); - } + l_store, + l_store + vtype::numlanes, + keys_vec_right, + pivot_vec, + &min_vec, + &max_vec); + } l_store += (vtype::numlanes - amount_gt_pivot); *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); diff --git a/tests/test_all.cpp b/tests/test_all.cpp index 6309c7bc..90219d10 100644 --- a/tests/test_all.cpp +++ b/tests/test_all.cpp @@ -34,7 +34,7 @@ TYPED_TEST_P(avx512_sort, test_arrsizes) sortedarr = arr; /* Sort with std::sort for comparison */ std::sort(sortedarr.begin(), sortedarr.end()); - avx512_qsort(arr.data(),NULL, arr.size()); + avx512_qsort(arr.data(), NULL, arr.size()); ASSERT_EQ(sortedarr, arr); arr.clear(); sortedarr.clear(); @@ -56,3 +56,47 @@ using Types = testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512_sort, Types); + +struct sorted_t { + uint64_t key; + uint64_t value; +}; + +bool compare(sorted_t a, sorted_t b) +{ + return a.key == b.key ? 
a.value < b.value : a.key < b.key; +} +TEST(TestKeyValueSort, KeyValueSort) +{ + std::vector keysizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + keysizes.push_back((uint64_t)ii); + } + std::vector keys; + std::vector values; + std::vector sortedarr; + + for (size_t ii = 0; ii < keysizes.size(); ++ii) { + /* Random array */ + keys = get_uniform_rand_array_key(keysizes[ii]); + //keys = get_uniform_rand_array(keysizes[ii]); + values = get_uniform_rand_array(keysizes[ii]); + for (size_t i = 0; i < keys.size(); i++) { + sorted_t tmp_s; + tmp_s.key = keys[i]; + tmp_s.value = values[i]; + sortedarr.emplace_back(tmp_s); + } + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), sortedarr.end(), compare); + avx512_qsort(keys.data(), values.data(), keys.size()); + //ASSERT_EQ(sortedarr, arr); + for (size_t i = 0; i < keys.size(); i++) { + ASSERT_EQ(keys[i], sortedarr[i].key); + ASSERT_EQ(values[i], sortedarr[i].value); + } + keys.clear(); + values.clear(); + sortedarr.clear(); + } +} diff --git a/utils/rand_array.h b/utils/rand_array.h index 0842a0b4..42e0f99d 100644 --- a/utils/rand_array.h +++ b/utils/rand_array.h @@ -3,6 +3,7 @@ * * SPDX-License-Identifier: BSD-3-Clause * *******************************************/ +#include #include #include #include @@ -33,10 +34,34 @@ static std::vector get_uniform_rand_array( { std::random_device rd; std::mt19937 gen(rd()); - std::uniform_real_distribution<> dis(min, max); + std::uniform_real_distribution dis(min, max); std::vector arr; + //std::cout< +get_uniform_rand_array_key(int64_t arrsize, + uint64_t max = std::numeric_limits::max(), + uint64_t min = std::numeric_limits::min()) +{ + std::vector arr; + std::random_device r; + std::default_random_engine e1(r()); + std::uniform_int_distribution uniform_dist(min, max); + for (int64_t ii = 0; ii < arrsize; ++ii) { + + while (true) { + uint64_t tmp = uniform_dist(e1); + auto iter = std::find(arr.begin(), arr.end(), tmp); + if (iter == arr.end()) { + arr.emplace_back(tmp); + break; + } + } + } + return arr; +} From 5363ba4759ef12669482d72bea1a2fe28f301fd6 Mon Sep 17 00:00:00 2001 From: ruclz <2015202025@ruc.edu.cn> Date: Mon, 13 Feb 2023 14:49:51 +0800 Subject: [PATCH 06/16] Rename Makefile to Makefile.bak2 --- Makefile => Makefile.bak2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename Makefile => Makefile.bak2 (95%) diff --git a/Makefile b/Makefile.bak2 similarity index 95% rename from Makefile rename to Makefile.bak2 index 899463a3..2f80a35d 100644 --- a/Makefile +++ b/Makefile.bak2 @@ -20,4 +20,4 @@ bench: $(BENCHDIR)/main.cpp $(SRCS) $(CXX) $(BENCHDIR)/main.cpp $(CXXFLAGS) -march=icelake-client -O3 -o benchexe clean: - rm -f $(TESTDIR)/*.o testexe benchexe \ No newline at end of file + rm -f $(TESTDIR)/*.o testexe benchexe From cd836c7464a03002dd8d909a7692f0836bf4ff44 Mon Sep 17 00:00:00 2001 From: ruclz <2015202025@ruc.edu.cn> Date: Mon, 13 Feb 2023 14:50:39 +0800 Subject: [PATCH 07/16] Rename Makefile.bak to Makefile --- Makefile.bak => Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename Makefile.bak => Makefile (95%) diff --git a/Makefile.bak b/Makefile similarity index 95% rename from Makefile.bak rename to Makefile index 07c7818d..938dbe5b 100644 --- a/Makefile.bak +++ b/Makefile @@ -24,4 +24,4 @@ bench: $(BENCHDIR)/main.cpp $(SRCS) $(CXX) $(BENCHDIR)/main.cpp $(CXXFLAGS) -march=icelake-client -O3 -o benchexe clean: - rm -f $(TESTDIR)/*.o testexe benchexe \ No newline at end of file + rm -f $(TESTDIR)/*.o testexe benchexe From 
541f9143c5d11e0a67325081d70ab1b21eaa5db4 Mon Sep 17 00:00:00 2001 From: ruclz <2015202025@ruc.edu.cn> Date: Mon, 13 Feb 2023 14:51:12 +0800 Subject: [PATCH 08/16] Delete Makefile.bak2 --- Makefile.bak2 | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 Makefile.bak2 diff --git a/Makefile.bak2 b/Makefile.bak2 deleted file mode 100644 index 2f80a35d..00000000 --- a/Makefile.bak2 +++ /dev/null @@ -1,23 +0,0 @@ -CXX ? = g++ SRCDIR =./ src TESTDIR =./ tests BENCHDIR =./ benchmarks UTILS - =./ utils SRCS = $(wildcard $(SRCDIR)/*.hpp) -TESTS = $(wildcard $(TESTDIR)/*.cpp) -TESTOBJS = $(patsubst $(TESTDIR)/%.cpp,$(TESTDIR)/%.o,$(TESTS)) -TESTOBJS := $(filter-out $(TESTDIR)/main.o ,$(TESTOBJS)) -GTEST_LIB = gtest -GTEST_INCLUDE = /usr/local/include -CXXFLAGS += -I$(SRCDIR) -I$(GTEST_INCLUDE) -I$(UTILS) -LD_FLAGS = -L /usr/local/lib -l $(GTEST_LIB) -l pthread - -all : test bench - -$(TESTDIR)/%.o : $(TESTDIR)/%.cpp $(SRCS) - $(CXX) -march=icelake-client -O3 $(CXXFLAGS) -c $< -o $@ - -test: $(TESTDIR)/main.cpp $(TESTOBJS) $(SRCS) - $(CXX) tests/main.cpp $(TESTOBJS) $(CXXFLAGS) $(LD_FLAGS) -o testexe - -bench: $(BENCHDIR)/main.cpp $(SRCS) - $(CXX) $(BENCHDIR)/main.cpp $(CXXFLAGS) -march=icelake-client -O3 -o benchexe - -clean: - rm -f $(TESTDIR)/*.o testexe benchexe From 8873ea16fd047997ae2bf766b85c98125077a74a Mon Sep 17 00:00:00 2001 From: ruclz <2015202025@ruc.edu.cn> Date: Mon, 13 Feb 2023 14:52:31 +0800 Subject: [PATCH 09/16] Update meson.build --- tests/meson.build | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/meson.build b/tests/meson.build index 40cd4685..7d51ba26 100644 --- a/tests/meson.build +++ b/tests/meson.build @@ -1,15 +1,19 @@ libtests = [] - if cc.has_argument('-march=icelake-client') libtests - += static_library('tests_', files('test_all.cpp', ), dependencies - : gtest_dep, include_directories - : - [ - src, - utils, - ], - cpp_args - : [ - '-O3', - '-march=icelake-client', - ], ) endif +if cc.has_argument('-march=icelake-client') + libtests += static_library( + 'tests_', + files( + 'test_all.cpp', + ), + dependencies : gtest_dep, + include_directories : [ + src, + utils, + ], + cpp_args : [ + '-O3', + '-march=icelake-client', + ], + ) +endif From 7ea8230240a6196877d81ba52e85192e3d52186a Mon Sep 17 00:00:00 2001 From: ruclz Date: Thu, 23 Feb 2023 13:54:18 +0800 Subject: [PATCH 10/16] Creat key-value sort hpp file. --- Makefile | 8 +- benchmarks/bench.hpp | 2 +- src/avx512-16bit-qsort.hpp | 4 +- src/avx512-32bit-qsort.hpp | 6 +- src/avx512-64bit-qsort.hpp | 1015 ++++------------------- src/avx512-common-keyvaluesort.h | 291 +++++++ src/avx512-common-qsort.h | 237 +----- src/avx512_64bit_keyvaluesort.hpp | 1276 +++++++++++++++++++++++++++++ tests/test_all.cpp | 5 +- 9 files changed, 1735 insertions(+), 1109 deletions(-) create mode 100644 src/avx512-common-keyvaluesort.h create mode 100644 src/avx512_64bit_keyvaluesort.hpp diff --git a/Makefile b/Makefile index 899463a3..07c7818d 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,9 @@ -CXX ? 
= g++ SRCDIR =./ src TESTDIR =./ tests BENCHDIR =./ benchmarks UTILS - =./ utils SRCS = $(wildcard $(SRCDIR)/*.hpp) +CXX ?= g++ +SRCDIR = ./src +TESTDIR = ./tests +BENCHDIR = ./benchmarks +UTILS = ./utils +SRCS = $(wildcard $(SRCDIR)/*.hpp) TESTS = $(wildcard $(TESTDIR)/*.cpp) TESTOBJS = $(patsubst $(TESTDIR)/%.cpp,$(TESTDIR)/%.o,$(TESTS)) TESTOBJS := $(filter-out $(TESTDIR)/main.o ,$(TESTOBJS)) diff --git a/benchmarks/bench.hpp b/benchmarks/bench.hpp index 48fca6cd..0837a6ca 100644 --- a/benchmarks/bench.hpp +++ b/benchmarks/bench.hpp @@ -49,7 +49,7 @@ std::tuple bench_sort(const std::vector arr, uint64_t start(0), end(0); for (uint64_t ii = 0; ii < iters; ++ii) { start = cycles_start(); - avx512_qsort(arr_bckup.data(), NULL, arr_bckup.size()); + avx512_qsort(arr_bckup.data(), arr_bckup.size()); end = cycles_end(); runtimes1.emplace_back(end - start); arr_bckup = arr; diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 2411c322..b2b4cb1c 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -686,7 +686,7 @@ replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count) } template <> -inline void avx512_qsort(int16_t *arr,uint64_t *indexes, int64_t arrsize) +inline void avx512_qsort(int16_t *arr, int64_t arrsize) { if (arrsize > 1) { qsort_16bit_, int16_t>( @@ -695,7 +695,7 @@ inline void avx512_qsort(int16_t *arr,uint64_t *indexes, int64_t arrsize) } template <> -inline void avx512_qsort(uint16_t *arr,uint64_t *indexes, int64_t arrsize) +inline void avx512_qsort(uint16_t *arr, int64_t arrsize) { if (arrsize > 1) { qsort_16bit_, uint16_t>( diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 918bc9ca..457df984 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -682,7 +682,7 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) } template <> -inline void avx512_qsort(int32_t *arr,uint64_t *indexes, int64_t arrsize) +inline void avx512_qsort(int32_t *arr, int64_t arrsize) { if (arrsize > 1) { qsort_32bit_, int32_t>( @@ -691,7 +691,7 @@ inline void avx512_qsort(int32_t *arr,uint64_t *indexes, int64_t arrsiz } template <> -inline void avx512_qsort(uint32_t *arr,uint64_t *indexes, int64_t arrsize) +inline void avx512_qsort(uint32_t *arr, int64_t arrsize) { if (arrsize > 1) { qsort_32bit_, uint32_t>( @@ -700,7 +700,7 @@ inline void avx512_qsort(uint32_t *arr,uint64_t *indexes, int64_t arrs } template <> -inline void avx512_qsort(float *arr,uint64_t *indexes, int64_t arrsize) +inline void avx512_qsort(float *arr, int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf(arr, arrsize); diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index e05a54b3..7e8db546 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -57,10 +57,6 @@ struct zmm_vector { { return _knot_mask8(x); } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); - } static opmask_t ge(zmm_t x, zmm_t y) { return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); @@ -168,10 +164,6 @@ struct zmm_vector { { return _knot_mask8(x); } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); - } static opmask_t ge(zmm_t x, zmm_t y) { return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); @@ -269,10 +261,6 @@ struct zmm_vector { { return _knot_mask8(x); } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OS); - } static opmask_t ge(zmm_t x, zmm_t y) { return 
_mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); @@ -344,7 +332,6 @@ struct zmm_vector { template X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) { - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAA); @@ -364,55 +351,6 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) return zmm; } -template ::zmm_t> -X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_t &index_zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - key_zmm = cmp_merge( - key_zmm, - vtype::template shuffle(key_zmm), - index_zmm, - zmm_vector::template shuffle( - index_zmm), - 0xAA); - key_zmm = cmp_merge( - key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), - index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), - index_zmm), - 0xCC); - key_zmm = cmp_merge( - key_zmm, - vtype::template shuffle(key_zmm), - index_zmm, - zmm_vector::template shuffle( - index_zmm), - 0xAA); - key_zmm = cmp_merge( - key_zmm, - vtype::permutexvar(rev_index, key_zmm), - index_zmm, - zmm_vector::permutexvar(rev_index, index_zmm), - 0xF0); - key_zmm = cmp_merge( - key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), - index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), - index_zmm), - 0xCC); - key_zmm = cmp_merge( - key_zmm, - vtype::template shuffle(key_zmm), - index_zmm, - zmm_vector::template shuffle( - index_zmm), - 0xAA); - return key_zmm; -} // Assumes zmm is bitonic and performs a recursive half cleaner template X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) @@ -433,40 +371,7 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) zmm, vtype::template shuffle(zmm), 0xAA); return zmm; } -// Assumes zmm is bitonic and performs a recursive half cleaner -template ::zmm_t> -X86_SIMD_SORT_INLINE zmm_t -bitonic_merge_zmm_64bit(zmm_t key_zmm, zmm_vector::zmm_t &index_zmm) -{ - // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 - key_zmm = cmp_merge( - key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), - index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), - index_zmm), - 0xF0); - // 2) half_cleaner[4] - key_zmm = cmp_merge( - key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), - index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), - index_zmm), - 0xCC); - // 3) half_cleaner[1] - key_zmm = cmp_merge( - key_zmm, - vtype::template shuffle(key_zmm), - index_zmm, - zmm_vector::template shuffle( - index_zmm), - 0xAA); - return key_zmm; -} // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner template X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1, zmm_t &zmm2) @@ -480,34 +385,7 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &zmm1, zmm_t &zmm2) zmm1 = bitonic_merge_zmm_64bit(zmm3); zmm2 = bitonic_merge_zmm_64bit(zmm4); } -// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template ::zmm_t> -X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, - zmm_t &key_zmm2, - index_t &index_zmm1, - index_t &index_zmm2) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - key_zmm2 = vtype::permutexvar(rev_index, key_zmm2); - index_zmm2 = zmm_vector::permutexvar(rev_index, index_zmm2); - zmm_t key_zmm3 = vtype::min(key_zmm1, key_zmm2); - zmm_t key_zmm4 = vtype::max(key_zmm1, key_zmm2); - - index_t 
index_zmm3 = zmm_vector::mask_mov( - index_zmm2, vtype::eq(key_zmm3, key_zmm1), index_zmm1); - index_t index_zmm4 = zmm_vector::mask_mov( - index_zmm1, vtype::eq(key_zmm3, key_zmm1), index_zmm2); - - // 2) Recursive half cleaner for each - key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); - index_zmm1 = index_zmm3; - index_zmm2 = index_zmm4; -} // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive // half cleaner template @@ -531,67 +409,7 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *zmm) zmm[2] = bitonic_merge_zmm_64bit(zmm2); zmm[3] = bitonic_merge_zmm_64bit(zmm3); } -template ::zmm_t> -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, - index_t *index_zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - // 1) First step of a merging network - zmm_t key_zmm2r = vtype::permutexvar(rev_index, key_zmm[2]); - zmm_t key_zmm3r = vtype::permutexvar(rev_index, key_zmm[3]); - index_t index_zmm2r - = zmm_vector::permutexvar(rev_index, index_zmm[2]); - index_t index_zmm3r - = zmm_vector::permutexvar(rev_index, index_zmm[3]); - - zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm3r); - zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm2r); - zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm3r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); - - index_t index_zmm_t1 = zmm_vector::mask_mov( - index_zmm3r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_t index_zmm_m1 = zmm_vector::mask_mov( - index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); - index_t index_zmm_t2 = zmm_vector::mask_mov( - index_zmm2r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_t index_zmm_m2 = zmm_vector::mask_mov( - index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); - - // 2) Recursive half clearer: 16 - zmm_t key_zmm_t3 = vtype::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t4 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t3 - = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t4 - = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - zmm_t key_zmm0 = vtype::min(key_zmm_t1, key_zmm_t2); - zmm_t key_zmm1 = vtype::max(key_zmm_t1, key_zmm_t2); - zmm_t key_zmm2 = vtype::min(key_zmm_t3, key_zmm_t4); - zmm_t key_zmm3 = vtype::max(key_zmm_t3, key_zmm_t4); - index_t index_zmm0 = zmm_vector::mask_mov( - index_zmm_t2, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t1); - index_t index_zmm1 = zmm_vector::mask_mov( - index_zmm_t1, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t2); - index_t index_zmm2 = zmm_vector::mask_mov( - index_zmm_t4, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t3); - index_t index_zmm3 = zmm_vector::mask_mov( - index_zmm_t3, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t4); - - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - - index_zmm[0] = index_zmm0; - index_zmm[1] = index_zmm1; - index_zmm[2] = index_zmm2; - index_zmm[3] = index_zmm3; -} template X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) { @@ -625,92 +443,7 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *zmm) zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); } -template ::zmm_t> -X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, 
- index_t *index_zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t key_zmm4r = vtype::permutexvar(rev_index, key_zmm[4]); - zmm_t key_zmm5r = vtype::permutexvar(rev_index, key_zmm[5]); - zmm_t key_zmm6r = vtype::permutexvar(rev_index, key_zmm[6]); - zmm_t key_zmm7r = vtype::permutexvar(rev_index, key_zmm[7]); - index_t index_zmm4r - = zmm_vector::permutexvar(rev_index, index_zmm[4]); - index_t index_zmm5r - = zmm_vector::permutexvar(rev_index, index_zmm[5]); - index_t index_zmm6r - = zmm_vector::permutexvar(rev_index, index_zmm[6]); - index_t index_zmm7r - = zmm_vector::permutexvar(rev_index, index_zmm[7]); - - zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm7r); - zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm6r); - zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm5r); - zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm4r); - zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm7r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm6r); - zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm5r); - zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); - - index_t index_zmm_t1 = zmm_vector::mask_mov( - index_zmm7r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_t index_zmm_m1 = zmm_vector::mask_mov( - index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); - index_t index_zmm_t2 = zmm_vector::mask_mov( - index_zmm6r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_t index_zmm_m2 = zmm_vector::mask_mov( - index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); - index_t index_zmm_t3 = zmm_vector::mask_mov( - index_zmm5r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); - index_t index_zmm_m3 = zmm_vector::mask_mov( - index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); - index_t index_zmm_t4 = zmm_vector::mask_mov( - index_zmm4r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); - index_t index_zmm_m4 = zmm_vector::mask_mov( - index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); - - zmm_t key_zmm_t5 = vtype::permutexvar(rev_index, key_zmm_m4); - zmm_t key_zmm_t6 = vtype::permutexvar(rev_index, key_zmm_m3); - zmm_t key_zmm_t7 = vtype::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t8 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t5 - = zmm_vector::permutexvar(rev_index, index_zmm_m4); - index_t index_zmm_t6 - = zmm_vector::permutexvar(rev_index, index_zmm_m3); - index_t index_zmm_t7 - = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t8 - = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, 
index_zmm_t8); - - index_zmm[0] = index_zmm_t1; - index_zmm[1] = index_zmm_t2; - index_zmm[2] = index_zmm_t3; - index_zmm[3] = index_zmm_t4; - index_zmm[4] = index_zmm_t5; - index_zmm[5] = index_zmm_t6; - index_zmm[6] = index_zmm_t7; - index_zmm[7] = index_zmm_t8; -} template X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) { @@ -784,331 +517,83 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); } -template ::zmm_t> -X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, - index_t *index_zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t key_zmm8r = vtype::permutexvar(rev_index, key_zmm[8]); - zmm_t key_zmm9r = vtype::permutexvar(rev_index, key_zmm[9]); - zmm_t key_zmm10r = vtype::permutexvar(rev_index, key_zmm[10]); - zmm_t key_zmm11r = vtype::permutexvar(rev_index, key_zmm[11]); - zmm_t key_zmm12r = vtype::permutexvar(rev_index, key_zmm[12]); - zmm_t key_zmm13r = vtype::permutexvar(rev_index, key_zmm[13]); - zmm_t key_zmm14r = vtype::permutexvar(rev_index, key_zmm[14]); - zmm_t key_zmm15r = vtype::permutexvar(rev_index, key_zmm[15]); - - index_t index_zmm8r - = zmm_vector::permutexvar(rev_index, index_zmm[8]); - index_t index_zmm9r - = zmm_vector::permutexvar(rev_index, index_zmm[9]); - index_t index_zmm10r - = zmm_vector::permutexvar(rev_index, index_zmm[10]); - index_t index_zmm11r - = zmm_vector::permutexvar(rev_index, index_zmm[11]); - index_t index_zmm12r - = zmm_vector::permutexvar(rev_index, index_zmm[12]); - index_t index_zmm13r - = zmm_vector::permutexvar(rev_index, index_zmm[13]); - index_t index_zmm14r - = zmm_vector::permutexvar(rev_index, index_zmm[14]); - index_t index_zmm15r - = zmm_vector::permutexvar(rev_index, index_zmm[15]); - - zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm15r); - zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm14r); - zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm13r); - zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm12r); - zmm_t key_zmm_t5 = vtype::min(key_zmm[4], key_zmm11r); - zmm_t key_zmm_t6 = vtype::min(key_zmm[5], key_zmm10r); - zmm_t key_zmm_t7 = vtype::min(key_zmm[6], key_zmm9r); - zmm_t key_zmm_t8 = vtype::min(key_zmm[7], key_zmm8r); - - zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm15r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm14r); - zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm13r); - zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm12r); - zmm_t key_zmm_m5 = vtype::max(key_zmm[4], key_zmm11r); - zmm_t key_zmm_m6 = vtype::max(key_zmm[5], key_zmm10r); - zmm_t key_zmm_m7 = vtype::max(key_zmm[6], key_zmm9r); - zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); - - index_t index_zmm_t1 = zmm_vector::mask_mov( - index_zmm15r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_t index_zmm_m1 = zmm_vector::mask_mov( - index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); - index_t index_zmm_t2 = zmm_vector::mask_mov( - index_zmm14r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_t index_zmm_m2 = zmm_vector::mask_mov( - index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); - index_t index_zmm_t3 = zmm_vector::mask_mov( - index_zmm13r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); - index_t index_zmm_m3 = zmm_vector::mask_mov( - index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); - index_t index_zmm_t4 = zmm_vector::mask_mov( - index_zmm12r, vtype::eq(key_zmm_t4, key_zmm[3]), 
index_zmm[3]); - index_t index_zmm_m4 = zmm_vector::mask_mov( - index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); - - index_t index_zmm_t5 = zmm_vector::mask_mov( - index_zmm11r, vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); - index_t index_zmm_m5 = zmm_vector::mask_mov( - index_zmm[4], vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); - index_t index_zmm_t6 = zmm_vector::mask_mov( - index_zmm10r, vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); - index_t index_zmm_m6 = zmm_vector::mask_mov( - index_zmm[5], vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); - index_t index_zmm_t7 = zmm_vector::mask_mov( - index_zmm9r, vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); - index_t index_zmm_m7 = zmm_vector::mask_mov( - index_zmm[6], vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); - index_t index_zmm_t8 = zmm_vector::mask_mov( - index_zmm8r, vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); - index_t index_zmm_m8 = zmm_vector::mask_mov( - index_zmm[7], vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); - zmm_t key_zmm_t9 = vtype::permutexvar(rev_index, key_zmm_m8); - zmm_t key_zmm_t10 = vtype::permutexvar(rev_index, key_zmm_m7); - zmm_t key_zmm_t11 = vtype::permutexvar(rev_index, key_zmm_m6); - zmm_t key_zmm_t12 = vtype::permutexvar(rev_index, key_zmm_m5); - zmm_t key_zmm_t13 = vtype::permutexvar(rev_index, key_zmm_m4); - zmm_t key_zmm_t14 = vtype::permutexvar(rev_index, key_zmm_m3); - zmm_t key_zmm_t15 = vtype::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t16 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t9 - = zmm_vector::permutexvar(rev_index, index_zmm_m8); - index_t index_zmm_t10 - = zmm_vector::permutexvar(rev_index, index_zmm_m7); - index_t index_zmm_t11 - = zmm_vector::permutexvar(rev_index, index_zmm_m6); - index_t index_zmm_t12 - = zmm_vector::permutexvar(rev_index, index_zmm_m5); - index_t index_zmm_t13 - = zmm_vector::permutexvar(rev_index, index_zmm_m4); - index_t index_zmm_t14 - = zmm_vector::permutexvar(rev_index, index_zmm_m3); - index_t index_zmm_t15 - = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t16 - = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); - COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); - COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); - COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); - COEX(key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); - COEX(key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); - COEX(key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); - COEX(key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); - COEX(key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); - COEX(key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); - COEX(key_zmm_t13, 
key_zmm_t14, index_zmm_t13, index_zmm_t14); - COEX(key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); - // - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); - key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, index_zmm_t10); - key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, index_zmm_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, index_zmm_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, index_zmm_t13); - key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, index_zmm_t14); - key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, index_zmm_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, index_zmm_t16); - - index_zmm[0] = index_zmm_t1; - index_zmm[1] = index_zmm_t2; - index_zmm[2] = index_zmm_t3; - index_zmm[3] = index_zmm_t4; - index_zmm[4] = index_zmm_t5; - index_zmm[5] = index_zmm_t6; - index_zmm[6] = index_zmm_t7; - index_zmm[7] = index_zmm_t8; - index_zmm[8] = index_zmm_t9; - index_zmm[9] = index_zmm_t10; - index_zmm[10] = index_zmm_t11; - index_zmm[11] = index_zmm_t12; - index_zmm[12] = index_zmm_t13; - index_zmm[13] = index_zmm_t14; - index_zmm[14] = index_zmm_t15; - index_zmm[15] = index_zmm_t16; -} template -X86_SIMD_SORT_INLINE void -sort_8_64bit(type_t *keys, uint64_t *indexes, int32_t N) +X86_SIMD_SORT_INLINE void sort_8_64bit(type_t *arr, int32_t N) { typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; - typename vtype::zmm_t key_zmm - = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys); - if (indexes) { - zmm_vector::zmm_t index_zmm - = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask, indexes); - vtype::mask_storeu( - keys, load_mask, sort_zmm_64bit(key_zmm, index_zmm)); - zmm_vector::mask_storeu(indexes, load_mask, index_zmm); - } - else { - vtype::mask_storeu(keys, load_mask, sort_zmm_64bit(key_zmm)); - } + typename vtype::zmm_t zmm + = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr); + vtype::mask_storeu(arr, load_mask, sort_zmm_64bit(zmm)); } template -X86_SIMD_SORT_INLINE void -sort_16_64bit(type_t *keys, uint64_t *indexes, int32_t N) +X86_SIMD_SORT_INLINE void sort_16_64bit(type_t *arr, int32_t N) { if (N <= 8) { - sort_8_64bit(keys, indexes, N); + sort_8_64bit(arr, N); return; } using zmm_t = typename vtype::zmm_t; - using index_t = zmm_vector::zmm_t; - zmm_t key_zmm1 = vtype::loadu(keys); + zmm_t zmm1 = vtype::loadu(arr); typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - zmm_t key_zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys + 8); - - if (indexes) { - index_t index_zmm1 = zmm_vector::loadu(indexes); - index_t index_zmm2 = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask, indexes + 8); - key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); - key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); - bitonic_merge_two_zmm_64bit( - key_zmm1, key_zmm2, index_zmm1, index_zmm2); - zmm_vector::storeu(indexes, index_zmm1); - zmm_vector::mask_storeu(indexes + 8, load_mask, index_zmm2); - } - else { - key_zmm1 = sort_zmm_64bit(key_zmm1); - key_zmm2 = 
sort_zmm_64bit(key_zmm2); - bitonic_merge_two_zmm_64bit(key_zmm1, key_zmm2); - } - - vtype::storeu(keys, key_zmm1); - vtype::mask_storeu(keys + 8, load_mask, key_zmm2); + zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 8); + zmm1 = sort_zmm_64bit(zmm1); + zmm2 = sort_zmm_64bit(zmm2); + bitonic_merge_two_zmm_64bit(zmm1, zmm2); + vtype::storeu(arr, zmm1); + vtype::mask_storeu(arr + 8, load_mask, zmm2); } template -X86_SIMD_SORT_INLINE void -sort_32_64bit(type_t *keys, uint64_t *indexes, int32_t N) +X86_SIMD_SORT_INLINE void sort_32_64bit(type_t *arr, int32_t N) { if (N <= 16) { - sort_16_64bit(keys, indexes, N); + sort_16_64bit(arr, N); return; } using zmm_t = typename vtype::zmm_t; using opmask_t = typename vtype::opmask_t; - using index_t = zmm_vector::zmm_t; - zmm_t key_zmm[4]; - index_t index_zmm[4]; - - key_zmm[0] = vtype::loadu(keys); - key_zmm[1] = vtype::loadu(keys + 8); - if (indexes) { - index_zmm[0] = zmm_vector::loadu(indexes); - index_zmm[1] = zmm_vector::loadu(indexes + 8); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - } - else { - key_zmm[0] = sort_zmm_64bit(key_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1]); - } + zmm_t zmm[4]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 8); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; load_mask1 = (combined_mask)&0xFF; load_mask2 = (combined_mask >> 8) & 0xFF; - key_zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 16); - key_zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 24); - - if (indexes) { - index_zmm[2] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask1, indexes + 16); - index_zmm[3] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask2, indexes + 24); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - zmm_vector::storeu(indexes, index_zmm[0]); - zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::mask_storeu( - indexes + 16, load_mask1, index_zmm[2]); - zmm_vector::mask_storeu( - indexes + 24, load_mask2, index_zmm[3]); - } - else { - key_zmm[2] = sort_zmm_64bit(key_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3]); - bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1]); - bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm); - } - vtype::storeu(keys, key_zmm[0]); - vtype::storeu(keys + 8, key_zmm[1]); - vtype::mask_storeu(keys + 16, load_mask1, key_zmm[2]); - vtype::mask_storeu(keys + 24, load_mask2, key_zmm[3]); + zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 16); + zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 24); + zmm[0] = sort_zmm_64bit(zmm[0]); + zmm[1] = sort_zmm_64bit(zmm[1]); + zmm[2] = sort_zmm_64bit(zmm[2]); + zmm[3] = sort_zmm_64bit(zmm[3]); + bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); + bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); + bitonic_merge_four_zmm_64bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 8, zmm[1]); + vtype::mask_storeu(arr + 16, load_mask1, zmm[2]); + vtype::mask_storeu(arr + 24, load_mask2, zmm[3]); } template -X86_SIMD_SORT_INLINE void -sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) 
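// The key-only kernels restored in this hunk all lean on the same merge
// primitive: bitonic_merge_two_zmm_64bit reverses the second sorted register,
// takes lanewise min/max, and finishes each half with the recursive half
// cleaner. A scalar sketch of that first step follows; the helper name and
// the use of std::array are illustrative assumptions.
#include <algorithm>
#include <array>

template <typename T>
void bitonic_merge_two_blocks_sketch(std::array<T, 8> &lo, std::array<T, 8> &hi)
{
    std::reverse(hi.begin(), hi.end()); // mirrors vtype::permutexvar(rev_index, zmm2)
    for (int i = 0; i < 8; ++i) {
        T a = std::min(lo[i], hi[i]); // vtype::min(zmm1, zmm2)
        T b = std::max(lo[i], hi[i]); // vtype::max(zmm1, zmm2)
        lo[i] = a;
        hi[i] = b;
    }
    // lo and hi are now bitonic; compare-exchanges at distances 4, 2 and 1
    // (the recursive half cleaner, bitonic_merge_zmm_64bit) sort each of them.
}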
+X86_SIMD_SORT_INLINE void sort_64_64bit(type_t *arr, int32_t N) { if (N <= 32) { - sort_32_64bit(keys, indexes, N); + sort_32_64bit(arr, N); return; } using zmm_t = typename vtype::zmm_t; using opmask_t = typename vtype::opmask_t; - using index_t = zmm_vector::zmm_t; - zmm_t key_zmm[8]; - index_t index_zmm[8]; - - key_zmm[0] = vtype::loadu(keys); - key_zmm[1] = vtype::loadu(keys + 8); - key_zmm[2] = vtype::loadu(keys + 16); - key_zmm[3] = vtype::loadu(keys + 24); - if (indexes) { - index_zmm[0] = zmm_vector::loadu(indexes); - index_zmm[1] = zmm_vector::loadu(indexes + 8); - index_zmm[2] = zmm_vector::loadu(indexes + 16); - index_zmm[3] = zmm_vector::loadu(indexes + 24); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - } - else { - key_zmm[0] = sort_zmm_64bit(key_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3]); - } + zmm_t zmm[8]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 8); + zmm[2] = vtype::loadu(arr + 16); + zmm[3] = vtype::loadu(arr + 24); + zmm[0] = sort_zmm_64bit(zmm[0]); + zmm[1] = sort_zmm_64bit(zmm[1]); + zmm[2] = sort_zmm_64bit(zmm[2]); + zmm[3] = sort_zmm_64bit(zmm[3]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; // N-32 >= 1 @@ -1117,123 +602,57 @@ sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) load_mask2 = (combined_mask >> 8) & 0xFF; load_mask3 = (combined_mask >> 16) & 0xFF; load_mask4 = (combined_mask >> 24) & 0xFF; - key_zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 32); - key_zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 40); - key_zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 48); - key_zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 56); - - if (indexes) { - index_zmm[4] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask1, indexes + 32); - index_zmm[5] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask2, indexes + 40); - index_zmm[6] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask3, indexes + 48); - index_zmm[7] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask4, indexes + 56); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( - key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( - key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - zmm_vector::storeu(indexes, index_zmm[0]); - zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::storeu(indexes + 16, index_zmm[2]); - zmm_vector::storeu(indexes + 24, index_zmm[3]); - zmm_vector::mask_storeu( - indexes + 32, load_mask1, index_zmm[4]); - zmm_vector::mask_storeu( - indexes + 40, load_mask2, index_zmm[5]); - zmm_vector::mask_storeu( - indexes + 48, load_mask3, index_zmm[6]); - 
zmm_vector::mask_storeu( - indexes + 56, load_mask4, index_zmm[7]); - } - else { - key_zmm[4] = sort_zmm_64bit(key_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7]); - bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1]); - bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3]); - bitonic_merge_two_zmm_64bit(key_zmm[4], key_zmm[5]); - bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7]); - bitonic_merge_four_zmm_64bit(key_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm); - } - - vtype::storeu(keys, key_zmm[0]); - vtype::storeu(keys + 8, key_zmm[1]); - vtype::storeu(keys + 16, key_zmm[2]); - vtype::storeu(keys + 24, key_zmm[3]); - vtype::mask_storeu(keys + 32, load_mask1, key_zmm[4]); - vtype::mask_storeu(keys + 40, load_mask2, key_zmm[5]); - vtype::mask_storeu(keys + 48, load_mask3, key_zmm[6]); - vtype::mask_storeu(keys + 56, load_mask4, key_zmm[7]); + zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 32); + zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 40); + zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 48); + zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 56); + zmm[4] = sort_zmm_64bit(zmm[4]); + zmm[5] = sort_zmm_64bit(zmm[5]); + zmm[6] = sort_zmm_64bit(zmm[6]); + zmm[7] = sort_zmm_64bit(zmm[7]); + bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); + bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); + bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); + bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); + bitonic_merge_four_zmm_64bit(zmm); + bitonic_merge_four_zmm_64bit(zmm + 4); + bitonic_merge_eight_zmm_64bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 8, zmm[1]); + vtype::storeu(arr + 16, zmm[2]); + vtype::storeu(arr + 24, zmm[3]); + vtype::mask_storeu(arr + 32, load_mask1, zmm[4]); + vtype::mask_storeu(arr + 40, load_mask2, zmm[5]); + vtype::mask_storeu(arr + 48, load_mask3, zmm[6]); + vtype::mask_storeu(arr + 56, load_mask4, zmm[7]); } template -X86_SIMD_SORT_INLINE void -sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) +X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N) { if (N <= 64) { - sort_64_64bit(keys, indexes, N); + sort_64_64bit(arr, N); return; } using zmm_t = typename vtype::zmm_t; - using index_t = zmm_vector::zmm_t; using opmask_t = typename vtype::opmask_t; - zmm_t key_zmm[16]; - index_t index_zmm[16]; - - key_zmm[0] = vtype::loadu(keys); - key_zmm[1] = vtype::loadu(keys + 8); - key_zmm[2] = vtype::loadu(keys + 16); - key_zmm[3] = vtype::loadu(keys + 24); - key_zmm[4] = vtype::loadu(keys + 32); - key_zmm[5] = vtype::loadu(keys + 40); - key_zmm[6] = vtype::loadu(keys + 48); - key_zmm[7] = vtype::loadu(keys + 56); - if (indexes != NULL) { - index_zmm[0] = zmm_vector::loadu(indexes); - index_zmm[1] = zmm_vector::loadu(indexes + 8); - index_zmm[2] = zmm_vector::loadu(indexes + 16); - index_zmm[3] = zmm_vector::loadu(indexes + 24); - index_zmm[4] = zmm_vector::loadu(indexes + 32); - index_zmm[5] = zmm_vector::loadu(indexes + 40); - index_zmm[6] = zmm_vector::loadu(indexes + 48); - index_zmm[7] = zmm_vector::loadu(indexes + 56); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - 
key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - } - else { - key_zmm[0] = sort_zmm_64bit(key_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3]); - key_zmm[4] = sort_zmm_64bit(key_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7]); - } - + zmm_t zmm[16]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 8); + zmm[2] = vtype::loadu(arr + 16); + zmm[3] = vtype::loadu(arr + 24); + zmm[4] = vtype::loadu(arr + 32); + zmm[5] = vtype::loadu(arr + 40); + zmm[6] = vtype::loadu(arr + 48); + zmm[7] = vtype::loadu(arr + 56); + zmm[0] = sort_zmm_64bit(zmm[0]); + zmm[1] = sort_zmm_64bit(zmm[1]); + zmm[2] = sort_zmm_64bit(zmm[2]); + zmm[3] = sort_zmm_64bit(zmm[3]); + zmm[4] = sort_zmm_64bit(zmm[4]); + zmm[5] = sort_zmm_64bit(zmm[5]); + zmm[6] = sort_zmm_64bit(zmm[6]); + zmm[7] = sort_zmm_64bit(zmm[7]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; @@ -1249,142 +668,63 @@ sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) load_mask7 = (combined_mask >> 48) & 0xFF; load_mask8 = (combined_mask >> 56) & 0xFF; } - key_zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 64); - key_zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 72); - key_zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 80); - key_zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 88); - key_zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, keys + 96); - key_zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, keys + 104); - key_zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, keys + 112); - key_zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, keys + 120); - - if (indexes) { - index_zmm[8] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask1, indexes + 64); - index_zmm[9] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask2, indexes + 72); - index_zmm[10] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask3, indexes + 80); - index_zmm[11] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask4, indexes + 88); - index_zmm[12] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask5, indexes + 96); - index_zmm[13] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask6, indexes + 104); - index_zmm[14] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask7, indexes + 112); - index_zmm[15] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask8, indexes + 120); - key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); - key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); - key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); - key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); - key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); - key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); - key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); - key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); - - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( - key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( - key_zmm[6], key_zmm[7], index_zmm[6], 
index_zmm[7]); - bitonic_merge_two_zmm_64bit( - key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); - bitonic_merge_two_zmm_64bit( - key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); - bitonic_merge_two_zmm_64bit( - key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); - bitonic_merge_two_zmm_64bit( - key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); - zmm_vector::storeu(indexes, index_zmm[0]); - zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::storeu(indexes + 16, index_zmm[2]); - zmm_vector::storeu(indexes + 24, index_zmm[3]); - zmm_vector::storeu(indexes + 32, index_zmm[4]); - zmm_vector::storeu(indexes + 40, index_zmm[5]); - zmm_vector::storeu(indexes + 48, index_zmm[6]); - zmm_vector::storeu(indexes + 56, index_zmm[7]); - zmm_vector::mask_storeu( - indexes + 64, load_mask1, index_zmm[8]); - zmm_vector::mask_storeu( - indexes + 72, load_mask2, index_zmm[9]); - zmm_vector::mask_storeu( - indexes + 80, load_mask3, index_zmm[10]); - zmm_vector::mask_storeu( - indexes + 88, load_mask4, index_zmm[11]); - zmm_vector::mask_storeu( - indexes + 96, load_mask5, index_zmm[12]); - zmm_vector::mask_storeu( - indexes + 104, load_mask6, index_zmm[13]); - zmm_vector::mask_storeu( - indexes + 112, load_mask7, index_zmm[14]); - zmm_vector::mask_storeu( - indexes + 120, load_mask8, index_zmm[15]); - } - else { - key_zmm[8] = sort_zmm_64bit(key_zmm[8]); - key_zmm[9] = sort_zmm_64bit(key_zmm[9]); - key_zmm[10] = sort_zmm_64bit(key_zmm[10]); - key_zmm[11] = sort_zmm_64bit(key_zmm[11]); - key_zmm[12] = sort_zmm_64bit(key_zmm[12]); - key_zmm[13] = sort_zmm_64bit(key_zmm[13]); - key_zmm[14] = sort_zmm_64bit(key_zmm[14]); - key_zmm[15] = sort_zmm_64bit(key_zmm[15]); - bitonic_merge_two_zmm_64bit(key_zmm[0], key_zmm[1]); - bitonic_merge_two_zmm_64bit(key_zmm[2], key_zmm[3]); - bitonic_merge_two_zmm_64bit(key_zmm[4], key_zmm[5]); - bitonic_merge_two_zmm_64bit(key_zmm[6], key_zmm[7]); - bitonic_merge_two_zmm_64bit(key_zmm[8], key_zmm[9]); - bitonic_merge_two_zmm_64bit(key_zmm[10], key_zmm[11]); - bitonic_merge_two_zmm_64bit(key_zmm[12], key_zmm[13]); - bitonic_merge_two_zmm_64bit(key_zmm[14], key_zmm[15]); - bitonic_merge_four_zmm_64bit(key_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4); - bitonic_merge_four_zmm_64bit(key_zmm + 8); - bitonic_merge_four_zmm_64bit(key_zmm + 12); - bitonic_merge_eight_zmm_64bit(key_zmm); - bitonic_merge_eight_zmm_64bit(key_zmm + 8); - bitonic_merge_sixteen_zmm_64bit(key_zmm); - } - vtype::storeu(keys, key_zmm[0]); - vtype::storeu(keys + 8, key_zmm[1]); - vtype::storeu(keys + 16, key_zmm[2]); - vtype::storeu(keys + 24, key_zmm[3]); - vtype::storeu(keys + 32, key_zmm[4]); - vtype::storeu(keys + 40, key_zmm[5]); - vtype::storeu(keys + 48, key_zmm[6]); - vtype::storeu(keys + 56, key_zmm[7]); - vtype::mask_storeu(keys + 64, load_mask1, key_zmm[8]); - vtype::mask_storeu(keys + 72, load_mask2, key_zmm[9]); - vtype::mask_storeu(keys + 80, load_mask3, key_zmm[10]); - vtype::mask_storeu(keys + 88, load_mask4, key_zmm[11]); - vtype::mask_storeu(keys + 96, load_mask5, key_zmm[12]); - vtype::mask_storeu(keys + 104, load_mask6, key_zmm[13]); - vtype::mask_storeu(keys + 
112, load_mask7, key_zmm[14]); - vtype::mask_storeu(keys + 120, load_mask8, key_zmm[15]); + zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64); + zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 72); + zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 80); + zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 88); + zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 96); + zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 104); + zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 112); + zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 120); + zmm[8] = sort_zmm_64bit(zmm[8]); + zmm[9] = sort_zmm_64bit(zmm[9]); + zmm[10] = sort_zmm_64bit(zmm[10]); + zmm[11] = sort_zmm_64bit(zmm[11]); + zmm[12] = sort_zmm_64bit(zmm[12]); + zmm[13] = sort_zmm_64bit(zmm[13]); + zmm[14] = sort_zmm_64bit(zmm[14]); + zmm[15] = sort_zmm_64bit(zmm[15]); + bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); + bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); + bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); + bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); + bitonic_merge_two_zmm_64bit(zmm[8], zmm[9]); + bitonic_merge_two_zmm_64bit(zmm[10], zmm[11]); + bitonic_merge_two_zmm_64bit(zmm[12], zmm[13]); + bitonic_merge_two_zmm_64bit(zmm[14], zmm[15]); + bitonic_merge_four_zmm_64bit(zmm); + bitonic_merge_four_zmm_64bit(zmm + 4); + bitonic_merge_four_zmm_64bit(zmm + 8); + bitonic_merge_four_zmm_64bit(zmm + 12); + bitonic_merge_eight_zmm_64bit(zmm); + bitonic_merge_eight_zmm_64bit(zmm + 8); + bitonic_merge_sixteen_zmm_64bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 8, zmm[1]); + vtype::storeu(arr + 16, zmm[2]); + vtype::storeu(arr + 24, zmm[3]); + vtype::storeu(arr + 32, zmm[4]); + vtype::storeu(arr + 40, zmm[5]); + vtype::storeu(arr + 48, zmm[6]); + vtype::storeu(arr + 56, zmm[7]); + vtype::mask_storeu(arr + 64, load_mask1, zmm[8]); + vtype::mask_storeu(arr + 72, load_mask2, zmm[9]); + vtype::mask_storeu(arr + 80, load_mask3, zmm[10]); + vtype::mask_storeu(arr + 88, load_mask4, zmm[11]); + vtype::mask_storeu(arr + 96, load_mask5, zmm[12]); + vtype::mask_storeu(arr + 104, load_mask6, zmm[13]); + vtype::mask_storeu(arr + 112, load_mask7, zmm[14]); + vtype::mask_storeu(arr + 120, load_mask8, zmm[15]); } template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *keys, - uint64_t *indexes, +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, const int64_t left, const int64_t right) { // median of 8 int64_t size = (right - left) / 8; using zmm_t = typename vtype::zmm_t; - using index_t = zmm_vector::zmm_t; __m512i rand_index = _mm512_set_epi64(left + size, left + 2 * size, left + 3 * size, @@ -1393,92 +733,40 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *keys, left + 6 * size, left + 7 * size, left + 8 * size); - zmm_t key_vec = vtype::template i64gather(rand_index, keys); - - index_t index_vec; - zmm_t sort; - if (indexes) { - index_vec = zmm_vector::template i64gather( - rand_index, indexes); - sort = sort_zmm_64bit(key_vec, index_vec); - } - else { - //index_vec=vtype::template i64gather(rand_index, indexes); - sort = sort_zmm_64bit(key_vec); - } + zmm_t rand_vec = vtype::template i64gather(rand_index, arr); // pivot will never be a nan, since there are no nan's! 
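
The pivot selection above is a median-of-8: gather eight evenly spaced keys, sort them inside one register, and return the middle lane. A plain scalar sketch of the same idea (hypothetical helper, not in the patch):

#include <algorithm>
#include <array>
#include <cstdint>

// Same eight probe points as the i64gather in get_pivot_64bit: left + k*size
// for k = 1..8, with size = (right - left) / 8.
template <typename T>
T pivot_median_of_8(const T *arr, int64_t left, int64_t right)
{
    int64_t size = (right - left) / 8;
    std::array<T, 8> sample;
    for (int k = 1; k <= 8; ++k)
        sample[k - 1] = arr[left + k * size];
    std::sort(sample.begin(), sample.end());
    return sample[4]; // mirrors ((type_t *)&sort)[4] above
}
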
- + zmm_t sort = sort_zmm_64bit(rand_vec); return ((type_t *)&sort)[4]; } template -inline void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) -{ - int64_t i = idx; - while (true) { - int64_t j = 2 * i + 1; - if (j >= size || j < 0) { break; } - int k = j + 1; - if (k < size && keys[j] < keys[k]) { j = k; } - if (keys[j] < keys[i]) { break; } - std::swap(keys[i], keys[j]); - std::swap(indexes[i], indexes[j]); - i = j; - } -} -template -inline void heap_sort(type_t *keys, uint64_t *indexes, int64_t size) -{ - for (int64_t i = size / 2 - 1; i >= 0; i--) { - heapify(keys, indexes, i, size); - } - for (int64_t i = size - 1; i > 0; i--) { - std::swap(keys[0], keys[i]); - std::swap(indexes[0], indexes[i]); - heapify(keys, indexes, 0, i); - } -} - -template -inline void qsort_64bit_(type_t *keys, - uint64_t *indexes, - int64_t left, - int64_t right, - int64_t max_iters) +static void +qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress */ if (max_iters <= 0) { - if (indexes) - heap_sort(keys + left, indexes + left, right - left + 1); - else - std::sort(keys + left, keys + right + 1); + std::sort(arr + left, arr + right + 1); return; } /* * Base case: use bitonic networks to sort arrays <= 128 */ if (right + 1 - left <= 128) { - if (indexes) - sort_128_64bit( - keys + left, indexes + left, (int32_t)(right + 1 - left)); - else - sort_128_64bit( - keys + left, (uint64_t *)NULL, (int32_t)(right + 1 - left)); + sort_128_64bit(arr + left, (int32_t)(right + 1 - left)); return; } - type_t pivot = get_pivot_64bit(keys, indexes, left, right); + type_t pivot = get_pivot_64bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); int64_t pivot_index = partition_avx512( - keys, indexes, left, right + 1, pivot, &smallest, &biggest); + arr, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) - qsort_64bit_( - keys, indexes, left, pivot_index - 1, max_iters - 1); + qsort_64bit_(arr, left, pivot_index - 1, max_iters - 1); if (pivot != biggest) - qsort_64bit_(keys, indexes, pivot_index, right, max_iters - 1); + qsort_64bit_(arr, pivot_index, right, max_iters - 1); } X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize) @@ -1507,34 +795,31 @@ replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) } template <> -inline void -avx512_qsort(int64_t *keys, uint64_t *indexes, int64_t arrsize) +void avx512_qsort(int64_t *arr, int64_t arrsize) { if (arrsize > 1) { qsort_64bit_, int64_t>( - keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } template <> -inline void -avx512_qsort(uint64_t *keys, uint64_t *indexes, int64_t arrsize) +void avx512_qsort(uint64_t *arr, int64_t arrsize) { if (arrsize > 1) { qsort_64bit_, uint64_t>( - keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } template <> -inline void -avx512_qsort(double *keys, uint64_t *indexes, int64_t arrsize) +void avx512_qsort(double *arr, int64_t arrsize) { if (arrsize > 1) { - int64_t nan_count = replace_nan_with_inf(keys, arrsize); + int64_t nan_count = replace_nan_with_inf(arr, arrsize); qsort_64bit_, double>( - keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - replace_inf_with_nan(keys, arrsize, nan_count); + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); } } #endif // 
AVX512_QSORT_64BIT diff --git a/src/avx512-common-keyvaluesort.h b/src/avx512-common-keyvaluesort.h new file mode 100644 index 00000000..f4c642b3 --- /dev/null +++ b/src/avx512-common-keyvaluesort.h @@ -0,0 +1,291 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * Copyright (C) 2021 Serge Sans Paille + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * Serge Sans Paille + * ****************************************************************/ + +#ifndef AVX512_QSORT_COMMON_KV +#define AVX512_QSORT_COMMON_KV + +/* + * Quicksort using AVX-512. The ideas and code are based on these two research + * papers [1] and [2]. On a high level, the idea is to vectorize quicksort + * partitioning using AVX-512 compressstore instructions. If the array size is + * < 128, then use Bitonic sorting network implemented on 512-bit registers. + * The precise network definitions depend on the dtype and are defined in + * separate files: avx512-16bit-qsort.hpp, avx512-32bit-qsort.hpp and + * avx512-64bit-qsort.hpp. Article [4] is a good resource for bitonic sorting + * network. The core implementations of the vectorized qsort functions + * avx512_qsort(T*, int64_t) are modified versions of avx2 quicksort + * presented in the paper [2] and source code associated with that paper [3]. + * + * [1] Fast and Robust Vectorized In-Place Sorting of Primitive Types + * https://drops.dagstuhl.de/opus/volltexte/2021/13775/ + * + * [2] A Novel Hybrid Quicksort Algorithm Vectorized using AVX-512 on Intel + * Skylake https://arxiv.org/pdf/1704.08579.pdf + * + * [3] https://github.com/simd-sorting/fast-and-robust: SPDX-License-Identifier: MIT + * + * [4] http://mitp-content-server.mit.edu:18180/books/content/sectbyfn?collid=books_pres_0&fn=Chapter%2027.pdf&id=8030 + * + */ + +#include +#include +#include +#include +#include + +#define X86_SIMD_SORT_INFINITY std::numeric_limits::infinity() +#define X86_SIMD_SORT_INFINITYF std::numeric_limits::infinity() +#define X86_SIMD_SORT_INFINITYH 0x7c00 +#define X86_SIMD_SORT_NEGINFINITYH 0xfc00 +#define X86_SIMD_SORT_MAX_UINT16 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT16 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT16 std::numeric_limits::min() +#define X86_SIMD_SORT_MAX_UINT32 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT32 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT32 std::numeric_limits::min() +#define X86_SIMD_SORT_MAX_UINT64 std::numeric_limits::max() +#define X86_SIMD_SORT_MAX_INT64 std::numeric_limits::max() +#define X86_SIMD_SORT_MIN_INT64 std::numeric_limits::min() +#define ZMM_MAX_DOUBLE _mm512_set1_pd(X86_SIMD_SORT_INFINITY) +#define ZMM_MAX_UINT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_UINT64) +#define ZMM_MAX_INT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_INT64) +#define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF) +#define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32) +#define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32) +#define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH) +#define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16) +#define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16) +#define SHUFFLE_MASK(a, b, c, d) (a << 6) | (b << 4) | (c << 2) | d + +#ifdef _MSC_VER +#define X86_SIMD_SORT_INLINE static inline +#define X86_SIMD_SORT_FINLINE static __forceinline +#elif defined(__CYGWIN__) +/* + * Force inline in cygwin to work around a compiler bug. 
See + * https://github.com/numpy/numpy/pull/22315#issuecomment-1267757584 + */ +#define X86_SIMD_SORT_INLINE static __attribute__((always_inline)) +#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) +#elif defined(__GNUC__) +#define X86_SIMD_SORT_INLINE static inline +#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) +#else +#define X86_SIMD_SORT_INLINE static +#define X86_SIMD_SORT_FINLINE static +#endif + +template +struct zmm_kv_vector; + +template +inline void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); + +using index_t = __m512i; +//using index_type = zmm_kv_vector; + +template > +static void COEX(mm_t &key1, mm_t &key2, index_t &index1, index_t &index2) +{ + //COEX(key1,key2); + mm_t key_t1 = vtype::min(key1, key2); + mm_t key_t2 = vtype::max(key1, key2); + + index_t index_t1 + = index_type::mask_mov(index2, vtype::eq(key_t1, key1), index1); + index_t index_t2 + = index_type::mask_mov(index1, vtype::eq(key_t1, key1), index2); + + key1 = key_t1; + key2 = key_t2; + index1 = index_t1; + index2 = index_t2; +} +template > +static inline zmm_t cmp_merge(zmm_t in1, + zmm_t in2, + index_t &indexes1, + index_t indexes2, + opmask_t mask) +{ + zmm_t tmp_keys = cmp_merge(in1, in2, mask); + indexes1 = index_type::mask_mov( + indexes2, vtype::eq(tmp_keys, in1), indexes1); + return tmp_keys; // 0 -> min, 1 -> max +} +/* + * Parition one ZMM register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template > +static inline int32_t partition_vec(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + const zmm_t keys_vec, + const index_t indexes_vec, + const zmm_t pivot_vec, + zmm_t *smallest_vec, + zmm_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t gt_mask = vtype::ge(keys_vec, pivot_vec); + int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); + vtype::mask_compressstoreu( + keys + left, vtype::knot_opmask(gt_mask), keys_vec); + vtype::mask_compressstoreu( + keys + right - amount_gt_pivot, gt_mask, keys_vec); + index_type::mask_compressstoreu( + indexes + left, index_type::knot_opmask(gt_mask), indexes_vec); + index_type::mask_compressstoreu( + indexes + right - amount_gt_pivot, gt_mask, indexes_vec); + *smallest_vec = vtype::min(keys_vec, *smallest_vec); + *biggest_vec = vtype::max(keys_vec, *biggest_vec); + return amount_gt_pivot; +} +/* + * Parition an array based on the pivot and returns the index of the + * last element that is less than equal to the pivot. 
+ */ +template > +static inline int64_t partition_avx512(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, keys[left]); + *biggest = std::max(*biggest, keys[left]); + if (keys[left] > pivot) { + right--; + std::swap(keys[left], keys[right]); + std::swap(indexes[left], indexes[right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using zmm_t = typename vtype::zmm_t; + zmm_t pivot_vec = vtype::set1(pivot); + zmm_t min_vec = vtype::set1(*smallest); + zmm_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + zmm_t keys_vec = vtype::loadu(keys + left); + int32_t amount_gt_pivot; + + index_t indexes_vec = index_type::loadu(indexes + left); + amount_gt_pivot = partition_vec(keys, + indexes, + left, + left + vtype::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return left + (vtype::numlanes - amount_gt_pivot); + } + + // first and last vtype::numlanes values are partitioned at the end + zmm_t keys_vec_left = vtype::loadu(keys + left); + zmm_t keys_vec_right = vtype::loadu(keys + (right - vtype::numlanes)); + index_t indexes_vec_left; + index_t indexes_vec_right; + indexes_vec_left = index_type::loadu(indexes + left); + indexes_vec_right = index_type::loadu(indexes + (right - vtype::numlanes)); + + // store points of the vectors + int64_t r_store = right - vtype::numlanes; + int64_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + zmm_t keys_vec; + index_t indexes_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= vtype::numlanes; + keys_vec = vtype::loadu(keys + right); + indexes_vec = index_type::loadu(indexes + right); + } + else { + keys_vec = vtype::loadu(keys + left); + indexes_vec = index_type::loadu(indexes + left); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_gt_pivot; + + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + r_store + vtype::numlanes, + keys_vec, + indexes_vec, + pivot_vec, + &min_vec, + &max_vec); + r_store -= amount_gt_pivot; + l_store += (vtype::numlanes - amount_gt_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_gt_pivot; + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + r_store + vtype::numlanes, + keys_vec_left, + indexes_vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec(keys, + indexes, + l_store, + l_store + vtype::numlanes, + keys_vec_right, + indexes_vec_right, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} +#endif // AVX512_QSORT_COMMON_KV diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 99ed4076..d1f6cbb4 100644 --- a/src/avx512-common-qsort.h 
+++ b/src/avx512-common-qsort.h @@ -85,10 +85,7 @@ template struct zmm_vector; template -inline void avx512_qsort(T *keys, uint64_t *indexes, int64_t arrsize); - -using index_t = __m512i; -//using index_type = zmm_vector; +void avx512_qsort(T *arr, int64_t arrsize); template bool comparison_func(const T &a, const T &b) @@ -106,25 +103,7 @@ static void COEX(mm_t &a, mm_t &b) a = vtype::min(a, b); b = vtype::max(temp, b); } -template > -static void COEX(mm_t &key1, mm_t &key2, index_t &index1, index_t &index2) -{ - //COEX(key1,key2); - mm_t key_t1 = vtype::min(key1, key2); - mm_t key_t2 = vtype::max(key1, key2); - - index_t index_t1 - = index_type::mask_mov(index2, vtype::eq(key_t1, key1), index1); - index_t index_t2 - = index_type::mask_mov(index1, vtype::eq(key_t1, key1), index2); - key1 = key_t1; - key2 = key_t2; - index1 = index_t1; - index2 = index_t2; -} template @@ -134,21 +113,7 @@ static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2, opmask_t mask) zmm_t max = vtype::max(in2, in1); return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max } -template > -static inline zmm_t cmp_merge(zmm_t in1, - zmm_t in2, - index_t &indexes1, - index_t indexes2, - opmask_t mask) -{ - zmm_t tmp_keys = cmp_merge(in1, in2, mask); - indexes1 = index_type::mask_mov( - indexes2, vtype::eq(tmp_keys, in1), indexes1); - return tmp_keys; // 0 -> min, 1 -> max -} + /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. @@ -173,35 +138,7 @@ static inline int32_t partition_vec(type_t *arr, *biggest_vec = vtype::max(curr_vec, *biggest_vec); return amount_gt_pivot; } -template > -static inline int32_t partition_vec(type_t *keys, - uint64_t *indexes, - int64_t left, - int64_t right, - const zmm_t keys_vec, - const index_t indexes_vec, - const zmm_t pivot_vec, - zmm_t *smallest_vec, - zmm_t *biggest_vec) -{ - /* which elements are larger than the pivot */ - typename vtype::opmask_t gt_mask = vtype::ge(keys_vec, pivot_vec); - int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); - vtype::mask_compressstoreu( - keys + left, vtype::knot_opmask(gt_mask), keys_vec); - vtype::mask_compressstoreu( - keys + right - amount_gt_pivot, gt_mask, keys_vec); - index_type::mask_compressstoreu( - indexes + left, index_type::knot_opmask(gt_mask), indexes_vec); - index_type::mask_compressstoreu( - indexes + right - amount_gt_pivot, gt_mask, indexes_vec); - *smallest_vec = vtype::min(keys_vec, *smallest_vec); - *biggest_vec = vtype::max(keys_vec, *biggest_vec); - return amount_gt_pivot; -} + /* * Parition an array based on the pivot and returns the index of the * last element that is less than equal to the pivot. 
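
The partition routines touched in the surrounding hunks all reduce to one compress-store step per vector: compare against the pivot, popcount the mask, and compress the two lane groups onto the left and right store points. A minimal standalone sketch for int64_t keys (illustrative only, not part of the patch):

#include <immintrin.h>
#include <cstdint>

// Partition eight int64_t keys around pivot_vec: keys < pivot are compressed
// to arr[left], keys >= pivot to arr[right - count]. Returns count.
static inline int32_t partition_one_vec_i64(int64_t *arr,
                                            int64_t left,
                                            int64_t right,
                                            __m512i keys,
                                            __m512i pivot_vec)
{
    __mmask8 ge_mask = _mm512_cmp_epi64_mask(keys, pivot_vec, _MM_CMPINT_NLT);
    int32_t amount_ge_pivot = _mm_popcnt_u32((uint32_t)ge_mask);
    _mm512_mask_compressstoreu_epi64(arr + left, _knot_mask8(ge_mask), keys);
    _mm512_mask_compressstoreu_epi64(
            arr + right - amount_ge_pivot, ge_mask, keys);
    return amount_ge_pivot;
}
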
@@ -307,172 +244,4 @@ static inline int64_t partition_avx512(type_t *arr, *biggest = vtype::reducemax(max_vec); return l_store; } - -template > -static inline int64_t partition_avx512(type_t *keys, - uint64_t *indexes, - int64_t left, - int64_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { - *smallest = std::min(*smallest, keys[left]); - *biggest = std::max(*biggest, keys[left]); - if (keys[left] > pivot) { - right--; - std::swap(keys[left], keys[right]); - if (indexes) std::swap(indexes[left], indexes[right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using zmm_t = typename vtype::zmm_t; - zmm_t pivot_vec = vtype::set1(pivot); - zmm_t min_vec = vtype::set1(*smallest); - zmm_t max_vec = vtype::set1(*biggest); - - if (right - left == vtype::numlanes) { - zmm_t keys_vec = vtype::loadu(keys + left); - int32_t amount_gt_pivot; - if (indexes) { - index_t indexes_vec = index_type::loadu(indexes + left); - amount_gt_pivot = partition_vec(keys, - indexes, - left, - left + vtype::numlanes, - keys_vec, - indexes_vec, - pivot_vec, - &min_vec, - &max_vec); - } - else { - amount_gt_pivot = partition_vec(keys, - left, - left + vtype::numlanes, - keys_vec, - pivot_vec, - &min_vec, - &max_vec); - } - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return left + (vtype::numlanes - amount_gt_pivot); - } - - // first and last vtype::numlanes values are partitioned at the end - zmm_t keys_vec_left = vtype::loadu(keys + left); - zmm_t keys_vec_right = vtype::loadu(keys + (right - vtype::numlanes)); - index_t indexes_vec_left; - index_t indexes_vec_right; - if (indexes) { - indexes_vec_left = index_type::loadu(indexes + left); - indexes_vec_right - = index_type::loadu(indexes + (right - vtype::numlanes)); - } - - // store points of the vectors - int64_t r_store = right - vtype::numlanes; - int64_t l_store = left; - // indices for loading the elements - left += vtype::numlanes; - right -= vtype::numlanes; - while (right - left != 0) { - zmm_t keys_vec; - index_t indexes_vec; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= vtype::numlanes; - keys_vec = vtype::loadu(keys + right); - if (indexes) indexes_vec = index_type::loadu(indexes + right); - } - else { - keys_vec = vtype::loadu(keys + left); - if (indexes) indexes_vec = index_type::loadu(indexes + left); - left += vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - int32_t amount_gt_pivot; - if (indexes) - amount_gt_pivot = partition_vec(keys, - indexes, - l_store, - r_store + vtype::numlanes, - keys_vec, - indexes_vec, - pivot_vec, - &min_vec, - &max_vec); - else - amount_gt_pivot = partition_vec(keys, - l_store, - r_store + vtype::numlanes, - keys_vec, - pivot_vec, - &min_vec, - &max_vec); - - r_store -= amount_gt_pivot; - l_store += (vtype::numlanes - amount_gt_pivot); - } - - /* partition and save vec_left and vec_right */ - int32_t amount_gt_pivot; - if (indexes) { - amount_gt_pivot = partition_vec(keys, - indexes, - l_store, - r_store + vtype::numlanes, - keys_vec_left, - indexes_vec_left, - pivot_vec, - &min_vec, - &max_vec); - l_store += 
(vtype::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec(keys, - indexes, - l_store, - l_store + vtype::numlanes, - keys_vec_right, - indexes_vec_right, - pivot_vec, - &min_vec, - &max_vec); - } - else { - amount_gt_pivot = partition_vec(keys, - l_store, - r_store + vtype::numlanes, - keys_vec_left, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec(keys, - l_store, - l_store + vtype::numlanes, - keys_vec_right, - pivot_vec, - &min_vec, - &max_vec); - } - l_store += (vtype::numlanes - amount_gt_pivot); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} #endif // AVX512_QSORT_COMMON diff --git a/src/avx512_64bit_keyvaluesort.hpp b/src/avx512_64bit_keyvaluesort.hpp new file mode 100644 index 00000000..49cedca6 --- /dev/null +++ b/src/avx512_64bit_keyvaluesort.hpp @@ -0,0 +1,1276 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * ****************************************************************/ + +#ifndef AVX512_QSORT_64BIT_KV +#define AVX512_QSORT_64BIT_KV + +#include "avx512-common-keyvaluesort.h" + +/* + * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ +// ZMM 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 + +template <> +struct zmm_kv_vector { + using type_t = int64_t; + using zmm_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT64; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT64; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi64(type_max()); + } // TODO: this should broadcast bits as is? 
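
The NETWORK_64BIT_* macros defined at the top of this new header encode permutation indices for _mm512_permutexvar_epi64; in particular NETWORK_64BIT_2 passed to _mm512_set_epi64 yields the full lane reversal used as rev_index in every merge step below. A small standalone check (illustrative only, not part of the patch):

#include <immintrin.h>
#include <cassert>
#include <cstdint>

int main()
{
    // Lanes 0..7 hold the values 0..7 (the first argument of set_epi64 is lane 7).
    __m512i v = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0);
    // rev_index = _mm512_set_epi64(NETWORK_64BIT_2), i.e. 0, 1, 2, 3, 4, 5, 6, 7.
    __m512i rev_index = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
    __m512i r = _mm512_permutexvar_epi64(rev_index, v);
    alignas(64) int64_t out[8];
    _mm512_store_epi64(out, r);
    for (int i = 0; i < 8; ++i)
        assert(out[i] == 7 - i); // lanes come back in reverse order
    return 0;
}
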
+ + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi64(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epi64(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epi64(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epi64(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epi64(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi64(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; +template <> +struct zmm_kv_vector { + using type_t = uint64_t; + using zmm_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT64; + } + static type_t type_min() + { + return 0; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi64(type_max()); + } + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi64(index, base, scale); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epu64(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t 
min(zmm_t x, zmm_t y) + { + return _mm512_min_epu64(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epu64(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epu64(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi64(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; +template <> +struct zmm_kv_vector { + using type_t = double; + using zmm_t = __m512d; + using ymm_t = __m512d; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITY; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITY; + } + static zmm_t zmm_max() + { + return _mm512_set1_pd(type_max()); + } + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OS); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_pd(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_pd(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_pd(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_pd(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_pd(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_pd(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_pd(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_pd(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_pd(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_pd(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_pd(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_pd(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_pd(mem, x); + } +}; + +/* + * Assumes zmm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ + +template ::zmm_t> +X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_t &index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + key_zmm = cmp_merge( + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_kv_vector::template shuffle( + index_zmm), + 0xAA); + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), + index_zmm, + zmm_kv_vector::permutexvar( + _mm512_set_epi64(NETWORK_64BIT_1), index_zmm), + 0xCC); + key_zmm = cmp_merge( + 
key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_kv_vector::template shuffle( + index_zmm), + 0xAA); + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(rev_index, key_zmm), + index_zmm, + zmm_kv_vector::permutexvar(rev_index, index_zmm), + 0xF0); + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + index_zmm, + zmm_kv_vector::permutexvar( + _mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + 0xCC); + key_zmm = cmp_merge( + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_kv_vector::template shuffle( + index_zmm), + 0xAA); + return key_zmm; +} +// Assumes zmm is bitonic and performs a recursive half cleaner +template ::zmm_t> +X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit( + zmm_t key_zmm, zmm_kv_vector::zmm_t &index_zmm) +{ + + // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), + index_zmm, + zmm_kv_vector::permutexvar( + _mm512_set_epi64(NETWORK_64BIT_4), index_zmm), + 0xF0); + // 2) half_cleaner[4] + key_zmm = cmp_merge( + key_zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + index_zmm, + zmm_kv_vector::permutexvar( + _mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + 0xCC); + // 3) half_cleaner[1] + key_zmm = cmp_merge( + key_zmm, + vtype::template shuffle(key_zmm), + index_zmm, + zmm_kv_vector::template shuffle( + index_zmm), + 0xAA); + return key_zmm; +} +// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, + zmm_t &key_zmm2, + index_t &index_zmm1, + index_t &index_zmm2) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + // 1) First step of a merging network: coex of zmm1 and zmm2 reversed + key_zmm2 = vtype::permutexvar(rev_index, key_zmm2); + index_zmm2 = zmm_kv_vector::permutexvar(rev_index, index_zmm2); + + zmm_t key_zmm3 = vtype::min(key_zmm1, key_zmm2); + zmm_t key_zmm4 = vtype::max(key_zmm1, key_zmm2); + + index_t index_zmm3 = zmm_kv_vector::mask_mov( + index_zmm2, vtype::eq(key_zmm3, key_zmm1), index_zmm1); + index_t index_zmm4 = zmm_kv_vector::mask_mov( + index_zmm1, vtype::eq(key_zmm3, key_zmm1), index_zmm2); + + // 2) Recursive half cleaner for each + key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); + index_zmm1 = index_zmm3; + index_zmm2 = index_zmm4; +} +// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive +// half cleaner +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, + index_t *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + // 1) First step of a merging network + zmm_t key_zmm2r = vtype::permutexvar(rev_index, key_zmm[2]); + zmm_t key_zmm3r = vtype::permutexvar(rev_index, key_zmm[3]); + index_t index_zmm2r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[2]); + index_t index_zmm3r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[3]); + + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm3r); + zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm2r); + zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm3r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); + + index_t index_zmm_t1 = zmm_kv_vector::mask_mov( + index_zmm3r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_t index_zmm_m1 = zmm_kv_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), 
index_zmm3r); + index_t index_zmm_t2 = zmm_kv_vector::mask_mov( + index_zmm2r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_t index_zmm_m2 = zmm_kv_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); + + // 2) Recursive half clearer: 16 + zmm_t key_zmm_t3 = vtype::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t4 = vtype::permutexvar(rev_index, key_zmm_m1); + index_t index_zmm_t3 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t4 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m1); + + zmm_t key_zmm0 = vtype::min(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm1 = vtype::max(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm2 = vtype::min(key_zmm_t3, key_zmm_t4); + zmm_t key_zmm3 = vtype::max(key_zmm_t3, key_zmm_t4); + + index_t index_zmm0 = zmm_kv_vector::mask_mov( + index_zmm_t2, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t1); + index_t index_zmm1 = zmm_kv_vector::mask_mov( + index_zmm_t1, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t2); + index_t index_zmm2 = zmm_kv_vector::mask_mov( + index_zmm_t4, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t3); + index_t index_zmm3 = zmm_kv_vector::mask_mov( + index_zmm_t3, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t4); + + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + + index_zmm[0] = index_zmm0; + index_zmm[1] = index_zmm1; + index_zmm[2] = index_zmm2; + index_zmm[3] = index_zmm3; +} +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, + index_t *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t key_zmm4r = vtype::permutexvar(rev_index, key_zmm[4]); + zmm_t key_zmm5r = vtype::permutexvar(rev_index, key_zmm[5]); + zmm_t key_zmm6r = vtype::permutexvar(rev_index, key_zmm[6]); + zmm_t key_zmm7r = vtype::permutexvar(rev_index, key_zmm[7]); + index_t index_zmm4r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[4]); + index_t index_zmm5r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[5]); + index_t index_zmm6r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[6]); + index_t index_zmm7r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[7]); + + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm7r); + zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm6r); + zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm5r); + zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm4r); + + zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm7r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm6r); + zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm5r); + zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); + + index_t index_zmm_t1 = zmm_kv_vector::mask_mov( + index_zmm7r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_t index_zmm_m1 = zmm_kv_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); + index_t index_zmm_t2 = zmm_kv_vector::mask_mov( + index_zmm6r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_t index_zmm_m2 = zmm_kv_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); + index_t index_zmm_t3 = zmm_kv_vector::mask_mov( + index_zmm5r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_t index_zmm_m3 = zmm_kv_vector::mask_mov( + index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); + index_t index_zmm_t4 = zmm_kv_vector::mask_mov( 
+ index_zmm4r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_t index_zmm_m4 = zmm_kv_vector::mask_mov( + index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); + + zmm_t key_zmm_t5 = vtype::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t6 = vtype::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t7 = vtype::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t8 = vtype::permutexvar(rev_index, key_zmm_m1); + index_t index_zmm_t5 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m4); + index_t index_zmm_t6 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m3); + index_t index_zmm_t7 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t8 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + + index_zmm[0] = index_zmm_t1; + index_zmm[1] = index_zmm_t2; + index_zmm[2] = index_zmm_t3; + index_zmm[3] = index_zmm_t4; + index_zmm[4] = index_zmm_t5; + index_zmm[5] = index_zmm_t6; + index_zmm[6] = index_zmm_t7; + index_zmm[7] = index_zmm_t8; +} +template ::zmm_t> +X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, + index_t *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t key_zmm8r = vtype::permutexvar(rev_index, key_zmm[8]); + zmm_t key_zmm9r = vtype::permutexvar(rev_index, key_zmm[9]); + zmm_t key_zmm10r = vtype::permutexvar(rev_index, key_zmm[10]); + zmm_t key_zmm11r = vtype::permutexvar(rev_index, key_zmm[11]); + zmm_t key_zmm12r = vtype::permutexvar(rev_index, key_zmm[12]); + zmm_t key_zmm13r = vtype::permutexvar(rev_index, key_zmm[13]); + zmm_t key_zmm14r = vtype::permutexvar(rev_index, key_zmm[14]); + zmm_t key_zmm15r = vtype::permutexvar(rev_index, key_zmm[15]); + + index_t index_zmm8r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[8]); + index_t index_zmm9r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[9]); + index_t index_zmm10r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[10]); + index_t index_zmm11r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[11]); + index_t index_zmm12r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[12]); + index_t index_zmm13r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[13]); + index_t index_zmm14r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[14]); + index_t index_zmm15r + = zmm_kv_vector::permutexvar(rev_index, index_zmm[15]); + + zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm15r); + zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm14r); + zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm13r); + zmm_t 
key_zmm_t4 = vtype::min(key_zmm[3], key_zmm12r); + zmm_t key_zmm_t5 = vtype::min(key_zmm[4], key_zmm11r); + zmm_t key_zmm_t6 = vtype::min(key_zmm[5], key_zmm10r); + zmm_t key_zmm_t7 = vtype::min(key_zmm[6], key_zmm9r); + zmm_t key_zmm_t8 = vtype::min(key_zmm[7], key_zmm8r); + + zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm15r); + zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm14r); + zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm13r); + zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm12r); + zmm_t key_zmm_m5 = vtype::max(key_zmm[4], key_zmm11r); + zmm_t key_zmm_m6 = vtype::max(key_zmm[5], key_zmm10r); + zmm_t key_zmm_m7 = vtype::max(key_zmm[6], key_zmm9r); + zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); + + index_t index_zmm_t1 = zmm_kv_vector::mask_mov( + index_zmm15r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_t index_zmm_m1 = zmm_kv_vector::mask_mov( + index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); + index_t index_zmm_t2 = zmm_kv_vector::mask_mov( + index_zmm14r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_t index_zmm_m2 = zmm_kv_vector::mask_mov( + index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); + index_t index_zmm_t3 = zmm_kv_vector::mask_mov( + index_zmm13r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_t index_zmm_m3 = zmm_kv_vector::mask_mov( + index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); + index_t index_zmm_t4 = zmm_kv_vector::mask_mov( + index_zmm12r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_t index_zmm_m4 = zmm_kv_vector::mask_mov( + index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); + + index_t index_zmm_t5 = zmm_kv_vector::mask_mov( + index_zmm11r, vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); + index_t index_zmm_m5 = zmm_kv_vector::mask_mov( + index_zmm[4], vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); + index_t index_zmm_t6 = zmm_kv_vector::mask_mov( + index_zmm10r, vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); + index_t index_zmm_m6 = zmm_kv_vector::mask_mov( + index_zmm[5], vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); + index_t index_zmm_t7 = zmm_kv_vector::mask_mov( + index_zmm9r, vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); + index_t index_zmm_m7 = zmm_kv_vector::mask_mov( + index_zmm[6], vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); + index_t index_zmm_t8 = zmm_kv_vector::mask_mov( + index_zmm8r, vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); + index_t index_zmm_m8 = zmm_kv_vector::mask_mov( + index_zmm[7], vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); + + zmm_t key_zmm_t9 = vtype::permutexvar(rev_index, key_zmm_m8); + zmm_t key_zmm_t10 = vtype::permutexvar(rev_index, key_zmm_m7); + zmm_t key_zmm_t11 = vtype::permutexvar(rev_index, key_zmm_m6); + zmm_t key_zmm_t12 = vtype::permutexvar(rev_index, key_zmm_m5); + zmm_t key_zmm_t13 = vtype::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t14 = vtype::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t15 = vtype::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t16 = vtype::permutexvar(rev_index, key_zmm_m1); + index_t index_zmm_t9 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m8); + index_t index_zmm_t10 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m7); + index_t index_zmm_t11 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m6); + index_t index_zmm_t12 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m5); + index_t index_zmm_t13 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m4); + index_t index_zmm_t14 + = 
zmm_kv_vector::permutexvar(rev_index, index_zmm_m3); + index_t index_zmm_t15 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m2); + index_t index_zmm_t16 + = zmm_kv_vector::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); + COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); + COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); + COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); + COEX(key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); + COEX(key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); + COEX(key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); + COEX(key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); + COEX(key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); + COEX(key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); + COEX(key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); + COEX(key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); + // + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, index_zmm_t10); + key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, index_zmm_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, index_zmm_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, index_zmm_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, index_zmm_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, index_zmm_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, index_zmm_t16); + + index_zmm[0] = index_zmm_t1; + index_zmm[1] = index_zmm_t2; + index_zmm[2] = index_zmm_t3; + index_zmm[3] = index_zmm_t4; + index_zmm[4] = index_zmm_t5; + index_zmm[5] = index_zmm_t6; + index_zmm[6] = index_zmm_t7; + index_zmm[7] = index_zmm_t8; + index_zmm[8] = index_zmm_t9; + index_zmm[9] = index_zmm_t10; + index_zmm[10] = index_zmm_t11; + index_zmm[11] = index_zmm_t12; + index_zmm[12] = index_zmm_t13; + index_zmm[13] = index_zmm_t14; + index_zmm[14] = index_zmm_t15; + index_zmm[15] = index_zmm_t16; +} +template +X86_SIMD_SORT_INLINE void +sort_8_64bit(type_t *keys, uint64_t *indexes, int32_t N) +{ + typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; + typename vtype::zmm_t key_zmm + = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys); + + 
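+ // load_mask keeps only the N valid lanes (e.g. N = 3 gives (1 << 3) - 1 = 0b00000111);
+ // lanes >= N are filled with zmm_max() sentinels, so they settle at the top of the
+ // register after sorting and the masked stores below never write them back.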
zmm_kv_vector::zmm_t index_zmm + = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask, indexes); + vtype::mask_storeu( + keys, load_mask, sort_zmm_64bit(key_zmm, index_zmm)); + zmm_kv_vector::mask_storeu(indexes, load_mask, index_zmm); +} + +template +X86_SIMD_SORT_INLINE void +sort_16_64bit(type_t *keys, uint64_t *indexes, int32_t N) +{ + if (N <= 8) { + sort_8_64bit(keys, indexes, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using index_t = zmm_kv_vector::zmm_t; + + typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; + + zmm_t key_zmm1 = vtype::loadu(keys); + zmm_t key_zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys + 8); + + index_t index_zmm1 = zmm_kv_vector::loadu(indexes); + index_t index_zmm2 = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask, indexes + 8); + + key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); + key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); + bitonic_merge_two_zmm_64bit( + key_zmm1, key_zmm2, index_zmm1, index_zmm2); + + zmm_kv_vector::storeu(indexes, index_zmm1); + zmm_kv_vector::mask_storeu(indexes + 8, load_mask, index_zmm2); + + vtype::storeu(keys, key_zmm1); + vtype::mask_storeu(keys + 8, load_mask, key_zmm2); +} + +template +X86_SIMD_SORT_INLINE void +sort_32_64bit(type_t *keys, uint64_t *indexes, int32_t N) +{ + if (N <= 16) { + sort_16_64bit(keys, indexes, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + using index_t = zmm_kv_vector::zmm_t; + zmm_t key_zmm[4]; + index_t index_zmm[4]; + + key_zmm[0] = vtype::loadu(keys); + key_zmm[1] = vtype::loadu(keys + 8); + + index_zmm[0] = zmm_kv_vector::loadu(indexes); + index_zmm[1] = zmm_kv_vector::loadu(indexes + 8); + + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + key_zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 16); + key_zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 24); + + index_zmm[2] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask1, indexes + 16); + index_zmm[3] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask2, indexes + 24); + + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + + zmm_kv_vector::storeu(indexes, index_zmm[0]); + zmm_kv_vector::storeu(indexes + 8, index_zmm[1]); + zmm_kv_vector::mask_storeu( + indexes + 16, load_mask1, index_zmm[2]); + zmm_kv_vector::mask_storeu( + indexes + 24, load_mask2, index_zmm[3]); + + vtype::storeu(keys, key_zmm[0]); + vtype::storeu(keys + 8, key_zmm[1]); + vtype::mask_storeu(keys + 16, load_mask1, key_zmm[2]); + vtype::mask_storeu(keys + 24, load_mask2, key_zmm[3]); +} + +template +X86_SIMD_SORT_INLINE void +sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) +{ + if (N <= 32) { + sort_32_64bit(keys, indexes, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + using index_t = zmm_kv_vector::zmm_t; + zmm_t key_zmm[8]; + index_t index_zmm[8]; + + key_zmm[0] = vtype::loadu(keys); + 
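+ // N > 32 on this path, so key_zmm[0..3] and index_zmm[0..3] can be filled with plain
+ // unmasked loads; only the tail registers further down need the load masks derived
+ // from (1 << (N - 32)) - 1.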
key_zmm[1] = vtype::loadu(keys + 8); + key_zmm[2] = vtype::loadu(keys + 16); + key_zmm[3] = vtype::loadu(keys + 24); + + index_zmm[0] = zmm_kv_vector::loadu(indexes); + index_zmm[1] = zmm_kv_vector::loadu(indexes + 8); + index_zmm[2] = zmm_kv_vector::loadu(indexes + 16); + index_zmm[3] = zmm_kv_vector::loadu(indexes + 24); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + // N-32 >= 1 + uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + key_zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 32); + key_zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 40); + key_zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 48); + key_zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 56); + + index_zmm[4] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask1, indexes + 32); + index_zmm[5] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask2, indexes + 40); + index_zmm[6] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask3, indexes + 48); + index_zmm[7] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask4, indexes + 56); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit( + key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + + zmm_kv_vector::storeu(indexes, index_zmm[0]); + zmm_kv_vector::storeu(indexes + 8, index_zmm[1]); + zmm_kv_vector::storeu(indexes + 16, index_zmm[2]); + zmm_kv_vector::storeu(indexes + 24, index_zmm[3]); + zmm_kv_vector::mask_storeu( + indexes + 32, load_mask1, index_zmm[4]); + zmm_kv_vector::mask_storeu( + indexes + 40, load_mask2, index_zmm[5]); + zmm_kv_vector::mask_storeu( + indexes + 48, load_mask3, index_zmm[6]); + zmm_kv_vector::mask_storeu( + indexes + 56, load_mask4, index_zmm[7]); + + vtype::storeu(keys, key_zmm[0]); + vtype::storeu(keys + 8, key_zmm[1]); + vtype::storeu(keys + 16, key_zmm[2]); + vtype::storeu(keys + 24, key_zmm[3]); + vtype::mask_storeu(keys + 32, load_mask1, key_zmm[4]); + vtype::mask_storeu(keys + 40, load_mask2, key_zmm[5]); + vtype::mask_storeu(keys + 48, load_mask3, key_zmm[6]); + vtype::mask_storeu(keys + 56, load_mask4, key_zmm[7]); +} + +template +X86_SIMD_SORT_INLINE void +sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) +{ + if (N <= 64) { + sort_64_64bit(keys, indexes, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using index_t = zmm_kv_vector::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t key_zmm[16]; + index_t 
index_zmm[16]; + + key_zmm[0] = vtype::loadu(keys); + key_zmm[1] = vtype::loadu(keys + 8); + key_zmm[2] = vtype::loadu(keys + 16); + key_zmm[3] = vtype::loadu(keys + 24); + key_zmm[4] = vtype::loadu(keys + 32); + key_zmm[5] = vtype::loadu(keys + 40); + key_zmm[6] = vtype::loadu(keys + 48); + key_zmm[7] = vtype::loadu(keys + 56); + + index_zmm[0] = zmm_kv_vector::loadu(indexes); + index_zmm[1] = zmm_kv_vector::loadu(indexes + 8); + index_zmm[2] = zmm_kv_vector::loadu(indexes + 16); + index_zmm[3] = zmm_kv_vector::loadu(indexes + 24); + index_zmm[4] = zmm_kv_vector::loadu(indexes + 32); + index_zmm[5] = zmm_kv_vector::loadu(indexes + 40); + index_zmm[6] = zmm_kv_vector::loadu(indexes + 48); + index_zmm[7] = zmm_kv_vector::loadu(indexes + 56); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; + opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; + if (N != 128) { + uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + load_mask5 = (combined_mask >> 32) & 0xFF; + load_mask6 = (combined_mask >> 40) & 0xFF; + load_mask7 = (combined_mask >> 48) & 0xFF; + load_mask8 = (combined_mask >> 56) & 0xFF; + } + key_zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 64); + key_zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 72); + key_zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 80); + key_zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 88); + key_zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, keys + 96); + key_zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, keys + 104); + key_zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, keys + 112); + key_zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, keys + 120); + + index_zmm[8] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask1, indexes + 64); + index_zmm[9] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask2, indexes + 72); + index_zmm[10] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask3, indexes + 80); + index_zmm[11] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask4, indexes + 88); + index_zmm[12] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask5, indexes + 96); + index_zmm[13] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask6, indexes + 104); + index_zmm[14] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask7, indexes + 112); + index_zmm[15] = zmm_kv_vector::mask_loadu( + zmm_kv_vector::zmm_max(), load_mask8, indexes + 120); + key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); + key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); + key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); + key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); + key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); + 
key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); + key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); + key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit( + key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); + bitonic_merge_two_zmm_64bit( + key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); + bitonic_merge_two_zmm_64bit( + key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); + bitonic_merge_two_zmm_64bit( + key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); + bitonic_merge_two_zmm_64bit( + key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); + zmm_kv_vector::storeu(indexes, index_zmm[0]); + zmm_kv_vector::storeu(indexes + 8, index_zmm[1]); + zmm_kv_vector::storeu(indexes + 16, index_zmm[2]); + zmm_kv_vector::storeu(indexes + 24, index_zmm[3]); + zmm_kv_vector::storeu(indexes + 32, index_zmm[4]); + zmm_kv_vector::storeu(indexes + 40, index_zmm[5]); + zmm_kv_vector::storeu(indexes + 48, index_zmm[6]); + zmm_kv_vector::storeu(indexes + 56, index_zmm[7]); + zmm_kv_vector::mask_storeu( + indexes + 64, load_mask1, index_zmm[8]); + zmm_kv_vector::mask_storeu( + indexes + 72, load_mask2, index_zmm[9]); + zmm_kv_vector::mask_storeu( + indexes + 80, load_mask3, index_zmm[10]); + zmm_kv_vector::mask_storeu( + indexes + 88, load_mask4, index_zmm[11]); + zmm_kv_vector::mask_storeu( + indexes + 96, load_mask5, index_zmm[12]); + zmm_kv_vector::mask_storeu( + indexes + 104, load_mask6, index_zmm[13]); + zmm_kv_vector::mask_storeu( + indexes + 112, load_mask7, index_zmm[14]); + zmm_kv_vector::mask_storeu( + indexes + 120, load_mask8, index_zmm[15]); + + vtype::storeu(keys, key_zmm[0]); + vtype::storeu(keys + 8, key_zmm[1]); + vtype::storeu(keys + 16, key_zmm[2]); + vtype::storeu(keys + 24, key_zmm[3]); + vtype::storeu(keys + 32, key_zmm[4]); + vtype::storeu(keys + 40, key_zmm[5]); + vtype::storeu(keys + 48, key_zmm[6]); + vtype::storeu(keys + 56, key_zmm[7]); + vtype::mask_storeu(keys + 64, load_mask1, key_zmm[8]); + vtype::mask_storeu(keys + 72, load_mask2, key_zmm[9]); + vtype::mask_storeu(keys + 80, load_mask3, key_zmm[10]); + vtype::mask_storeu(keys + 88, load_mask4, key_zmm[11]); + vtype::mask_storeu(keys + 96, load_mask5, key_zmm[12]); + vtype::mask_storeu(keys + 104, load_mask6, key_zmm[13]); + vtype::mask_storeu(keys + 112, load_mask7, key_zmm[14]); + vtype::mask_storeu(keys + 120, load_mask8, key_zmm[15]); +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *keys, + uint64_t *indexes, + const int64_t left, + const int64_t right) +{ + // median of 8 + int64_t size = (right - left) / 8; + using zmm_t = typename vtype::zmm_t; + using index_t = zmm_kv_vector::zmm_t; + __m512i rand_index = _mm512_set_epi64(left + size, + left + 2 * size, + left + 3 * size, + left + 4 * size, + left + 5 * size, + left + 6 * size, + left + 7 * size, + left + 8 * size); + zmm_t key_vec 
= vtype::template i64gather(rand_index, keys); + + index_t index_vec; + zmm_t sort; + index_vec = zmm_kv_vector::template i64gather( + rand_index, indexes); + sort = sort_zmm_64bit(key_vec, index_vec); + // pivot will never be a nan, since there are no nan's! + + return ((type_t *)&sort)[4]; +} + +template +inline void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) +{ + int64_t i = idx; + while (true) { + int64_t j = 2 * i + 1; + if (j >= size || j < 0) { break; } + int k = j + 1; + if (k < size && keys[j] < keys[k]) { j = k; } + if (keys[j] < keys[i]) { break; } + std::swap(keys[i], keys[j]); + std::swap(indexes[i], indexes[j]); + i = j; + } +} +template +inline void heap_sort(type_t *keys, uint64_t *indexes, int64_t size) +{ + for (int64_t i = size / 2 - 1; i >= 0; i--) { + heapify(keys, indexes, i, size); + } + for (int64_t i = size - 1; i > 0; i--) { + std::swap(keys[0], keys[i]); + std::swap(indexes[0], indexes[i]); + heapify(keys, indexes, 0, i); + } +} + +template +inline void qsort_64bit_(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + int64_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + heap_sort(keys + left, indexes + left, right - left + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 128 + */ + if (right + 1 - left <= 128) { + + sort_128_64bit( + keys + left, indexes + left, (int32_t)(right + 1 - left)); + return; + } + + type_t pivot = get_pivot_64bit(keys, indexes, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512( + keys, indexes, left, right + 1, pivot, &smallest, &biggest); + if (pivot != smallest) + qsort_64bit_( + keys, indexes, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) + qsort_64bit_(keys, indexes, pivot_index, right, max_iters - 1); +} + +X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf_kv(double *arr, + int64_t arrsize) +{ + int64_t nan_count = 0; + __mmask8 loadmask = 0xFF; + while (arrsize > 0) { + if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } + __m512d in_zmm = _mm512_maskz_loadu_pd(loadmask, arr); + __mmask8 nanmask = _mm512_cmp_pd_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); + nan_count += _mm_popcnt_u32((int32_t)nanmask); + _mm512_mask_storeu_pd(arr, nanmask, ZMM_MAX_DOUBLE); + arr += 8; + arrsize -= 8; + } + return nan_count; +} + +X86_SIMD_SORT_INLINE void +replace_inf_with_nan_kv(double *arr, int64_t arrsize, int64_t nan_count) +{ + for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { + arr[ii] = std::nan("1"); + nan_count -= 1; + } +} + +template <> +inline void +avx512_qsort_kv(int64_t *keys, uint64_t *indexes, int64_t arrsize) +{ + if (arrsize > 1) { + qsort_64bit_, int64_t>( + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +inline void +avx512_qsort_kv(uint64_t *keys, uint64_t *indexes, int64_t arrsize) +{ + if (arrsize > 1) { + qsort_64bit_, uint64_t>( + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +inline void +avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize) +{ + if (arrsize > 1) { + int64_t nan_count = replace_nan_with_inf(keys, arrsize); + qsort_64bit_, double>( + keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan_kv(keys, arrsize, nan_count); + } +} +#endif // AVX512_QSORT_64BIT_KV diff --git a/tests/test_all.cpp b/tests/test_all.cpp index 90219d10..a8bc7411 100644 --- a/tests/test_all.cpp 
+++ b/tests/test_all.cpp @@ -6,6 +6,7 @@ #include "avx512-16bit-qsort.hpp" #include "avx512-32bit-qsort.hpp" #include "avx512-64bit-qsort.hpp" +#include "avx512_64bit_keyvaluesort.hpp" #include "cpuinfo.h" #include "rand_array.h" #include @@ -34,7 +35,7 @@ TYPED_TEST_P(avx512_sort, test_arrsizes) sortedarr = arr; /* Sort with std::sort for comparison */ std::sort(sortedarr.begin(), sortedarr.end()); - avx512_qsort(arr.data(), NULL, arr.size()); + avx512_qsort(arr.data(), arr.size()); ASSERT_EQ(sortedarr, arr); arr.clear(); sortedarr.clear(); @@ -89,7 +90,7 @@ TEST(TestKeyValueSort, KeyValueSort) } /* Sort with std::sort for comparison */ std::sort(sortedarr.begin(), sortedarr.end(), compare); - avx512_qsort(keys.data(), values.data(), keys.size()); + avx512_qsort_kv(keys.data(), values.data(), keys.size()); //ASSERT_EQ(sortedarr, arr); for (size_t i = 0; i < keys.size(); i++) { ASSERT_EQ(keys[i], sortedarr[i].key); From 17b534c3cc56bc74b9599075068481b54152b09a Mon Sep 17 00:00:00 2001 From: ruclz Date: Thu, 23 Feb 2023 14:01:26 +0800 Subject: [PATCH 11/16] Create key-value sort hpp file. --- ...vx512_64bit_keyvaluesort.hpp => avx512-64bit-keyvaluesort.hpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/{avx512_64bit_keyvaluesort.hpp => avx512-64bit-keyvaluesort.hpp} (100%) diff --git a/src/avx512_64bit_keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp similarity index 100% rename from src/avx512_64bit_keyvaluesort.hpp rename to src/avx512-64bit-keyvaluesort.hpp From 4927fe274d860c829a0044c4fc4233a6a6f6fcfe Mon Sep 17 00:00:00 2001 From: ruclz Date: Thu, 23 Feb 2023 14:16:54 +0800 Subject: [PATCH 12/16] Sorry, I forgot to merge the test_all.cpp in the last commit. --- tests/test_all.cpp | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/test_all.cpp b/tests/test_all.cpp index e3e7eb39..35330fa8 100644 --- a/tests/test_all.cpp +++ b/tests/test_all.cpp @@ -5,8 +5,8 @@ #include "avx512-16bit-qsort.hpp" #include "avx512-32bit-qsort.hpp" +#include "avx512-64bit-keyvaluesort.hpp" #include "avx512-64bit-qsort.hpp" -#include "avx512_64bit_keyvaluesort.hpp" #include "cpuinfo.h" #include "rand_array.h" #include @@ -35,11 +35,7 @@ TYPED_TEST_P(avx512_sort, test_arrsizes) sortedarr = arr; /* Sort with std::sort for comparison */ std::sort(sortedarr.begin(), sortedarr.end()); -<<<<<<< HEAD avx512_qsort(arr.data(), arr.size()); -======= - avx512_qsort(arr.data(), NULL, arr.size()); ->>>>>>> 8873ea16fd047997ae2bf766b85c98125077a74a ASSERT_EQ(sortedarr, arr); arr.clear(); sortedarr.clear(); @@ -94,11 +90,7 @@ TEST(TestKeyValueSort, KeyValueSort) } /* Sort with std::sort for comparison */ std::sort(sortedarr.begin(), sortedarr.end(), compare); -<<<<<<< HEAD avx512_qsort_kv(keys.data(), values.data(), keys.size()); -======= - avx512_qsort(keys.data(), values.data(), keys.size()); ->>>>>>> 8873ea16fd047997ae2bf766b85c98125077a74a //ASSERT_EQ(sortedarr, arr); for (size_t i = 0; i < keys.size(); i++) { ASSERT_EQ(keys[i], sortedarr[i].key); From e73aa3f4ba55f626b5b87e4f9d5ef0ab2b553b00 Mon Sep 17 00:00:00 2001 From: ruclz Date: Thu, 2 Mar 2023 12:34:29 +0800 Subject: [PATCH 13/16] Delete duplicated code and add benchmarking code for key-value sort --- benchmarks/bench-tgl.out | 57 +- benchmarks/bench.hpp | 53 ++ benchmarks/main.cpp | 69 ++- src/avx512-16bit-qsort.hpp | 4 +- src/avx512-32bit-qsort.hpp | 6 +- src/avx512-64bit-common.h | 351 +++++++++++++ src/avx512-64bit-keyvaluesort.hpp | 829 +++++++++--------------------- src/avx512-64bit-qsort.hpp | 336 
+----------- src/avx512-common-keyvaluesort.h | 63 +-- tests/meson.build | 30 +- tests/test_all.cpp | 41 +- utils/rand_array.h | 42 +- 12 files changed, 819 insertions(+), 1062 deletions(-) create mode 100644 src/avx512-64bit-common.h diff --git a/benchmarks/bench-tgl.out b/benchmarks/bench-tgl.out index a1e0bcf2..9ebb3dc0 100644 --- a/benchmarks/bench-tgl.out +++ b/benchmarks/bench-tgl.out @@ -1,28 +1,29 @@ -|-----------------+-------------+------------+-----------------+-----------+----------| -| Array data type | typeid name | array size | avx512_qsort | std::sort | speed up | -|-----------------+-------------+------------+-----------------+-----------+----------| -| uniform random | uint32_t | 10000 | 115697 | 1579118 | 13.6 | -| uniform random | uint32_t | 100000 | 1786812 | 19973203 | 11.2 | -| uniform random | uint32_t | 1000000 | 22536966 | 233470422 | 10.4 | -| uniform random | int32_t | 10000 | 95591 | 1569108 | 16.4 | -| uniform random | int32_t | 100000 | 1790362 | 19785007 | 11.1 | -| uniform random | int32_t | 1000000 | 22874571 | 233358497 | 10.2 | -| uniform random | float | 10000 | 113316 | 1668407 | 14.7 | -| uniform random | float | 100000 | 1920018 | 21815024 | 11.4 | -| uniform random | float | 1000000 | 24776954 | 256867990 | 10.4 | -| uniform random | uint64_t | 10000 | 233501 | 1537649 | 6.6 | -| uniform random | uint64_t | 100000 | 3991372 | 19559859 | 4.9 | -| uniform random | uint64_t | 1000000 | 49818870 | 232687666 | 4.7 | -| uniform random | int64_t | 10000 | 228000 | 1445131 | 6.3 | -| uniform random | int64_t | 100000 | 3892092 | 18917322 | 4.9 | -| uniform random | int64_t | 1000000 | 48957088 | 235100259 | 4.8 | -| uniform random | double | 10000 | 180307 | 1702801 | 9.4 | -| uniform random | double | 100000 | 3596886 | 21849587 | 6.1 | -| uniform random | double | 1000000 | 47724381 | 258014177 | 5.4 | -| uniform random | uint16_t | 10000 | 84732 | 1548275 | 18.3 | -| uniform random | uint16_t | 100000 | 1406417 | 19632858 | 14.0 | -| uniform random | uint16_t | 1000000 | 17119960 | 214085305 | 12.5 | -| uniform random | int16_t | 10000 | 84703 | 1547726 | 18.3 | -| uniform random | int16_t | 100000 | 1442726 | 19705242 | 13.7 | -| uniform random | int16_t | 1000000 | 20210224 | 212137465 | 10.5 | -|-----------------+-------------+------------+-----------------+-----------+----------| +| -----------------+-------------+------------+-----------------+-----------+---------- | + | Array data type | typeid name | array size + | avx512_qsort | std::sort | speed up | + | -----------------+-------------+------------+-----------------+-----------+---------- | + | uniform random | uint32_t | 10000 | 115697 | 1579118 | 13.6 | + | uniform random | uint32_t | 100000 | 1786812 | 19973203 | 11.2 | + | uniform random | uint32_t | 1000000 | 22536966 | 233470422 | 10.4 | + | uniform random | int32_t | 10000 | 95591 | 1569108 | 16.4 | + | uniform random | int32_t | 100000 | 1790362 | 19785007 | 11.1 | + | uniform random | int32_t | 1000000 | 22874571 | 233358497 | 10.2 | + | uniform random | float | 10000 | 113316 | 1668407 | 14.7 | + | uniform random | float | 100000 | 1920018 | 21815024 | 11.4 | + | uniform random | float | 1000000 | 24776954 | 256867990 | 10.4 | + | uniform random | uint64_t | 10000 | 233501 | 1537649 | 6.6 | + | uniform random | uint64_t | 100000 | 3991372 | 19559859 | 4.9 | + | uniform random | uint64_t | 1000000 | 49818870 | 232687666 | 4.7 | + | uniform random | int64_t | 10000 | 228000 | 1445131 | 6.3 | + | uniform random | int64_t | 100000 | 
3892092 | 18917322 | 4.9 | + | uniform random | int64_t | 1000000 | 48957088 | 235100259 | 4.8 | + | uniform random | double | 10000 | 180307 | 1702801 | 9.4 | + | uniform random | double | 100000 | 3596886 | 21849587 | 6.1 | + | uniform random | double | 1000000 | 47724381 | 258014177 | 5.4 | + | uniform random | uint16_t | 10000 | 84732 | 1548275 | 18.3 | + | uniform random | uint16_t | 100000 | 1406417 | 19632858 | 14.0 | + | uniform random | uint16_t | 1000000 | 17119960 | 214085305 | 12.5 | + | uniform random | int16_t | 10000 | 84703 | 1547726 | 18.3 | + | uniform random | int16_t | 100000 | 1442726 | 19705242 | 13.7 | + | uniform random | int16_t | 1000000 | 20210224 | 212137465 | 10.5 | + | -----------------+-------------+------------+-----------------+-----------+---------- | diff --git a/benchmarks/bench.hpp b/benchmarks/bench.hpp index 0837a6ca..073c7a39 100644 --- a/benchmarks/bench.hpp +++ b/benchmarks/bench.hpp @@ -5,12 +5,19 @@ #include "avx512-16bit-qsort.hpp" #include "avx512-32bit-qsort.hpp" +#include "avx512-64bit-keyvaluesort.hpp" #include "avx512-64bit-qsort.hpp" #include #include #include #include +template +struct sorted_t { + K key; + V value; +}; + static inline uint64_t cycles_start(void) { unsigned a, d; @@ -72,3 +79,49 @@ std::tuple bench_sort(const std::vector arr, / lastfew; return std::make_tuple(avx_sort, std_sort); } +template +std::tuple +bench_sort_kv(const std::vector keys, + const std::vector values, + const std::vector> sortedaar, + const uint64_t iters, + const uint64_t lastfew) +{ + + std::vector keys_bckup = keys; + std::vector values_bckup = values; + std::vector> sortedaar_bckup = sortedaar; + + std::vector runtimes1, runtimes2; + uint64_t start(0), end(0); + for (uint64_t ii = 0; ii < iters; ++ii) { + start = cycles_start(); + avx512_qsort_kv( + keys_bckup.data(), values_bckup.data(), keys_bckup.size()); + end = cycles_end(); + runtimes1.emplace_back(end - start); + keys_bckup = keys; + values_bckup = values; + } + uint64_t avx_sort = std::accumulate(runtimes1.end() - lastfew, + runtimes1.end(), + (uint64_t)0) + / lastfew; + + for (uint64_t ii = 0; ii < iters; ++ii) { + start = cycles_start(); + std::sort(sortedaar_bckup.begin(), + sortedaar_bckup.end(), + [](sorted_t a, sorted_t b) { + return a.key < b.key; + }); + end = cycles_end(); + runtimes2.emplace_back(end - start); + sortedaar_bckup = sortedaar; + } + uint64_t std_sort = std::accumulate(runtimes2.end() - lastfew, + runtimes2.end(), + (uint64_t)0) + / lastfew; + return std::make_tuple(avx_sort, std_sort); +} diff --git a/benchmarks/main.cpp b/benchmarks/main.cpp index 5340b881..b8cf95bb 100644 --- a/benchmarks/main.cpp +++ b/benchmarks/main.cpp @@ -22,7 +22,7 @@ template +void run_bench_kv(const std::string datatype) +{ + std::streamsize ss = std::cout.precision(); + std::cout << std::fixed; + std::cout << std::setprecision(1); + std::vector array_sizes = {10000, 100000, 1000000}; + for (auto size : array_sizes) { + std::vector keys; + std::vector values; + std::vector> sortedarr; + + if (datatype.find("kv_uniform") != std::string::npos) { + keys = get_uniform_rand_array(size); + } + else if (datatype.find("kv_reverse") != std::string::npos) { + for (int ii = 0; ii < size; ++ii) { + //arr.emplace_back((T)(size - ii)); + keys.emplace_back((K)(size - ii)); + } + } + else if (datatype.find("kv_ordered") != std::string::npos) { + for (int ii = 0; ii < size; ++ii) { + keys.emplace_back((ii)); + } + } + else if (datatype.find("kv_limited") != std::string::npos) { + keys = 
get_uniform_rand_array(size, (K)10, (K)0); + } + else { + std::cout << "Skipping unrecognized array type: " << datatype + << std::endl; + return; + } + values = get_uniform_rand_array(size); + for (size_t i = 0; i < keys.size(); i++) { + sorted_t tmp_s; + tmp_s.key = keys[i]; + tmp_s.value = values[i]; + sortedarr.emplace_back(tmp_s); + } + + auto out = bench_sort_kv(keys, values, sortedarr, 20, 10); + printLine(' ', + datatype, + typeid(K).name(), + sizeof(K), + size, + std::get<0>(out), + std::get<1>(out), + (float)std::get<1>(out) / std::get<0>(out)); + } + std::cout << std::setprecision(ss); +} void bench_all(const std::string datatype) { if (cpu_has_avx512bw()) { @@ -97,7 +151,15 @@ void bench_all(const std::string datatype) } } } +void bench_all_kv(const std::string datatype) +{ + if (cpu_has_avx512bw()) { + run_bench_kv(datatype); + run_bench_kv(datatype); + run_bench_kv(datatype); + } +} int main(/*int argc, char *argv[]*/) { printLine(' ', @@ -113,6 +175,11 @@ int main(/*int argc, char *argv[]*/) bench_all("reverse"); bench_all("ordered"); bench_all("limitedrange"); + + bench_all_kv("kv_uniform random"); + bench_all_kv("kv_reverse"); + bench_all_kv("kv_ordered"); + bench_all_kv("kv_limitedrange"); printLine('-', "", "", "", "", "", "", ""); return 0; } diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index b2b4cb1c..b7130e2f 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -686,7 +686,7 @@ replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count) } template <> -inline void avx512_qsort(int16_t *arr, int64_t arrsize) +void avx512_qsort(int16_t *arr, int64_t arrsize) { if (arrsize > 1) { qsort_16bit_, int16_t>( @@ -695,7 +695,7 @@ inline void avx512_qsort(int16_t *arr, int64_t arrsize) } template <> -inline void avx512_qsort(uint16_t *arr, int64_t arrsize) +void avx512_qsort(uint16_t *arr, int64_t arrsize) { if (arrsize > 1) { qsort_16bit_, uint16_t>( diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 457df984..1cbba00b 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -682,7 +682,7 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) } template <> -inline void avx512_qsort(int32_t *arr, int64_t arrsize) +void avx512_qsort(int32_t *arr, int64_t arrsize) { if (arrsize > 1) { qsort_32bit_, int32_t>( @@ -691,7 +691,7 @@ inline void avx512_qsort(int32_t *arr, int64_t arrsize) } template <> -inline void avx512_qsort(uint32_t *arr, int64_t arrsize) +void avx512_qsort(uint32_t *arr, int64_t arrsize) { if (arrsize > 1) { qsort_32bit_, uint32_t>( @@ -700,7 +700,7 @@ inline void avx512_qsort(uint32_t *arr, int64_t arrsize) } template <> -inline void avx512_qsort(float *arr, int64_t arrsize) +void avx512_qsort(float *arr, int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf(arr, arrsize); diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h new file mode 100644 index 00000000..7c5ddc5a --- /dev/null +++ b/src/avx512-64bit-common.h @@ -0,0 +1,351 @@ +#ifndef AVX512_64BIT_COMMOM +#define AVX512_64BIT_COMMOM +#include "avx512-common-qsort.h" + +#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 + +template <> +struct zmm_vector { + using type_t = int64_t; + using zmm_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t 
type_max() + { + return X86_SIMD_SORT_MAX_INT64; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT64; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi64(type_max()); + } // TODO: this should broadcast bits as is? + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi64(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epi64(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epi64(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epi64(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epi64(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi64(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; +template <> +struct zmm_vector { + using type_t = uint64_t; + using zmm_t = __m512i; + using ymm_t = __m512i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT64; + } + static type_t type_min() + { + return 0; + } + static zmm_t zmm_max() + { + return _mm512_set1_epi64(type_max()); + } + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); + } + + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi64(index, base, scale); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_si512(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_epu64(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_epi64(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_epi64(x, mask, mem); + } + static 
zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_epi64(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_epi64(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_epu64(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_epi64(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_epu64(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_epu64(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_epi64(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + __m512d temp = _mm512_castsi512_pd(zmm); + return _mm512_castpd_si512( + _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_si512(mem, x); + } +}; +template <> +struct zmm_vector { + using type_t = double; + using zmm_t = __m512d; + using ymm_t = __m512d; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITY; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITY; + } + static zmm_t zmm_max() + { + return _mm512_set1_pd(type_max()); + } + + static zmm_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4, + type_t v5, + type_t v6, + type_t v7, + type_t v8) + { + return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_pd(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm512_loadu_pd(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm512_max_pd(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_compressstoreu_pd(mem, mask, x); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm512_mask_loadu_pd(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm512_mask_mov_pd(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm512_mask_storeu_pd(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm512_min_pd(x, y); + } + static zmm_t permutexvar(__m512i idx, zmm_t zmm) + { + return _mm512_permutexvar_pd(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + return _mm512_reduce_max_pd(v); + } + static type_t reducemin(zmm_t v) + { + return _mm512_reduce_min_pd(v); + } + static zmm_t set1(type_t v) + { + return _mm512_set1_pd(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); + } + static void storeu(void *mem, zmm_t x) + { + return _mm512_storeu_pd(mem, x); + } +}; +X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize) +{ + int64_t nan_count = 0; + __mmask8 loadmask = 0xFF; + while (arrsize > 0) { + if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } + __m512d in_zmm = _mm512_maskz_loadu_pd(loadmask, arr); + __mmask8 nanmask = _mm512_cmp_pd_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); + nan_count += _mm_popcnt_u32((int32_t)nanmask); + _mm512_mask_storeu_pd(arr, nanmask, ZMM_MAX_DOUBLE); + arr += 8; + 
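+ // step to the next block of 8 doubles; the final partial block is covered by the
+ // narrowed loadmask computed at the top of the loop, e.g. 5 leftover doubles give
+ // (1 << 5) - 1 = 0x1F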
arrsize -= 8; + } + return nan_count; +} + +X86_SIMD_SORT_INLINE void +replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) +{ + for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { + arr[ii] = std::nan("1"); + nan_count -= 1; + } +} + +#endif \ No newline at end of file diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 49cedca6..39781884 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -9,384 +9,51 @@ #include "avx512-common-keyvaluesort.h" -/* - * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic - * sorting network (see - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) - */ -// ZMM 7, 6, 5, 4, 3, 2, 1, 0 -#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 -#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 -#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 -#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 - -template <> -struct zmm_kv_vector { - using type_t = int64_t; - using zmm_t = __m512i; - using ymm_t = __m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT64; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT64; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi64(type_max()); - } // TODO: this should broadcast bits as is? - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); - } - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi64(index, base, scale); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epi64(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epi64(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epi64(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epi64(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi64(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; -template <> -struct zmm_kv_vector { - using type_t = uint64_t; - using zmm_t = __m512i; - using ymm_t = __m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return 
X86_SIMD_SORT_MAX_UINT64; - } - static type_t type_min() - { - return 0; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi64(type_max()); - } - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi64(index, base, scale); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epu64(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epu64(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epu64(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epu64(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi64(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; -template <> -struct zmm_kv_vector { - using type_t = double; - using zmm_t = __m512d; - using ymm_t = __m512d; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_INFINITY; - } - static type_t type_min() - { - return -X86_SIMD_SORT_INFINITY; - } - static zmm_t zmm_max() - { - return _mm512_set1_pd(type_max()); - } - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t eq(zmm_t x, zmm_t y) - { - return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OS); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); - } - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_pd(index, base, scale); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_pd(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_pd(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_pd(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_pd(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_pd(x, mask, y); - } - 
static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_pd(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_pd(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_pd(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_pd(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_pd(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_pd(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_pd(mem, x); - } -}; - -/* - * Assumes zmm is random and performs a full sorting network defined in - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg - */ - template ::zmm_t> -X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_t &index_zmm) + typename index_type = zmm_vector::zmm_t> +X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); key_zmm = cmp_merge( key_zmm, vtype::template shuffle(key_zmm), index_zmm, - zmm_kv_vector::template shuffle( + zmm_vector::template shuffle( index_zmm), 0xAA); key_zmm = cmp_merge( key_zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), index_zmm, - zmm_kv_vector::permutexvar( - _mm512_set_epi64(NETWORK_64BIT_1), index_zmm), + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), + index_zmm), 0xCC); key_zmm = cmp_merge( key_zmm, vtype::template shuffle(key_zmm), index_zmm, - zmm_kv_vector::template shuffle( + zmm_vector::template shuffle( index_zmm), 0xAA); key_zmm = cmp_merge( key_zmm, vtype::permutexvar(rev_index, key_zmm), index_zmm, - zmm_kv_vector::permutexvar(rev_index, index_zmm), + zmm_vector::permutexvar(rev_index, index_zmm), 0xF0); key_zmm = cmp_merge( key_zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), index_zmm, - zmm_kv_vector::permutexvar( - _mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), + index_zmm), 0xCC); key_zmm = cmp_merge( key_zmm, vtype::template shuffle(key_zmm), index_zmm, - zmm_kv_vector::template shuffle( + zmm_vector::template shuffle( index_zmm), 0xAA); return key_zmm; @@ -394,9 +61,9 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_t &index_zmm) // Assumes zmm is bitonic and performs a recursive half cleaner template ::zmm_t> -X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit( - zmm_t key_zmm, zmm_kv_vector::zmm_t &index_zmm) + typename index_type = zmm_vector::zmm_t> +X86_SIMD_SORT_INLINE zmm_t +bitonic_merge_zmm_64bit(zmm_t key_zmm, zmm_vector::zmm_t &index_zmm) { // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 @@ -404,23 +71,23 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit( key_zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), index_zmm, - zmm_kv_vector::permutexvar( - _mm512_set_epi64(NETWORK_64BIT_4), index_zmm), + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), + index_zmm), 0xF0); // 2) half_cleaner[4] key_zmm = cmp_merge( key_zmm, vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), index_zmm, - zmm_kv_vector::permutexvar( - _mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), + index_zmm), 0xCC); // 3) half_cleaner[1] key_zmm = cmp_merge( key_zmm, vtype::template shuffle(key_zmm), index_zmm, - 
zmm_kv_vector::template shuffle( + zmm_vector::template shuffle( index_zmm), 0xAA); return key_zmm; @@ -428,23 +95,23 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit( // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner template ::zmm_t> + typename index_type = zmm_vector::zmm_t> X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, zmm_t &key_zmm2, - index_t &index_zmm1, - index_t &index_zmm2) + index_type &index_zmm1, + index_type &index_zmm2) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); // 1) First step of a merging network: coex of zmm1 and zmm2 reversed key_zmm2 = vtype::permutexvar(rev_index, key_zmm2); - index_zmm2 = zmm_kv_vector::permutexvar(rev_index, index_zmm2); + index_zmm2 = zmm_vector::permutexvar(rev_index, index_zmm2); zmm_t key_zmm3 = vtype::min(key_zmm1, key_zmm2); zmm_t key_zmm4 = vtype::max(key_zmm1, key_zmm2); - index_t index_zmm3 = zmm_kv_vector::mask_mov( + index_type index_zmm3 = zmm_vector::mask_mov( index_zmm2, vtype::eq(key_zmm3, key_zmm1), index_zmm1); - index_t index_zmm4 = zmm_kv_vector::mask_mov( + index_type index_zmm4 = zmm_vector::mask_mov( index_zmm1, vtype::eq(key_zmm3, key_zmm1), index_zmm2); // 2) Recursive half cleaner for each @@ -457,53 +124,53 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, // half cleaner template ::zmm_t> + typename index_type = zmm_vector::zmm_t> X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, - index_t *index_zmm) + index_type *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); // 1) First step of a merging network zmm_t key_zmm2r = vtype::permutexvar(rev_index, key_zmm[2]); zmm_t key_zmm3r = vtype::permutexvar(rev_index, key_zmm[3]); - index_t index_zmm2r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[2]); - index_t index_zmm3r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[3]); + index_type index_zmm2r + = zmm_vector::permutexvar(rev_index, index_zmm[2]); + index_type index_zmm3r + = zmm_vector::permutexvar(rev_index, index_zmm[3]); zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm3r); zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm2r); zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm3r); zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); - index_t index_zmm_t1 = zmm_kv_vector::mask_mov( + index_type index_zmm_t1 = zmm_vector::mask_mov( index_zmm3r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_t index_zmm_m1 = zmm_kv_vector::mask_mov( + index_type index_zmm_m1 = zmm_vector::mask_mov( index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); - index_t index_zmm_t2 = zmm_kv_vector::mask_mov( + index_type index_zmm_t2 = zmm_vector::mask_mov( index_zmm2r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_t index_zmm_m2 = zmm_kv_vector::mask_mov( + index_type index_zmm_m2 = zmm_vector::mask_mov( index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); // 2) Recursive half clearer: 16 zmm_t key_zmm_t3 = vtype::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t4 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t3 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t4 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m1); + index_type index_zmm_t3 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t4 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); zmm_t key_zmm0 = vtype::min(key_zmm_t1, key_zmm_t2); zmm_t key_zmm1 = vtype::max(key_zmm_t1, key_zmm_t2); zmm_t key_zmm2 = 
vtype::min(key_zmm_t3, key_zmm_t4); zmm_t key_zmm3 = vtype::max(key_zmm_t3, key_zmm_t4); - index_t index_zmm0 = zmm_kv_vector::mask_mov( + index_type index_zmm0 = zmm_vector::mask_mov( index_zmm_t2, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t1); - index_t index_zmm1 = zmm_kv_vector::mask_mov( + index_type index_zmm1 = zmm_vector::mask_mov( index_zmm_t1, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t2); - index_t index_zmm2 = zmm_kv_vector::mask_mov( + index_type index_zmm2 = zmm_vector::mask_mov( index_zmm_t4, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t3); - index_t index_zmm3 = zmm_kv_vector::mask_mov( + index_type index_zmm3 = zmm_vector::mask_mov( index_zmm_t3, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t4); key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); @@ -518,23 +185,23 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, } template ::zmm_t> + typename index_type = zmm_vector::zmm_t> X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, - index_t *index_zmm) + index_type *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); zmm_t key_zmm4r = vtype::permutexvar(rev_index, key_zmm[4]); zmm_t key_zmm5r = vtype::permutexvar(rev_index, key_zmm[5]); zmm_t key_zmm6r = vtype::permutexvar(rev_index, key_zmm[6]); zmm_t key_zmm7r = vtype::permutexvar(rev_index, key_zmm[7]); - index_t index_zmm4r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[4]); - index_t index_zmm5r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[5]); - index_t index_zmm6r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[6]); - index_t index_zmm7r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[7]); + index_type index_zmm4r + = zmm_vector::permutexvar(rev_index, index_zmm[4]); + index_type index_zmm5r + = zmm_vector::permutexvar(rev_index, index_zmm[5]); + index_type index_zmm6r + = zmm_vector::permutexvar(rev_index, index_zmm[6]); + index_type index_zmm7r + = zmm_vector::permutexvar(rev_index, index_zmm[7]); zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm7r); zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm6r); @@ -546,35 +213,35 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm5r); zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); - index_t index_zmm_t1 = zmm_kv_vector::mask_mov( + index_type index_zmm_t1 = zmm_vector::mask_mov( index_zmm7r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_t index_zmm_m1 = zmm_kv_vector::mask_mov( + index_type index_zmm_m1 = zmm_vector::mask_mov( index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); - index_t index_zmm_t2 = zmm_kv_vector::mask_mov( + index_type index_zmm_t2 = zmm_vector::mask_mov( index_zmm6r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_t index_zmm_m2 = zmm_kv_vector::mask_mov( + index_type index_zmm_m2 = zmm_vector::mask_mov( index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); - index_t index_zmm_t3 = zmm_kv_vector::mask_mov( + index_type index_zmm_t3 = zmm_vector::mask_mov( index_zmm5r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); - index_t index_zmm_m3 = zmm_kv_vector::mask_mov( + index_type index_zmm_m3 = zmm_vector::mask_mov( index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); - index_t index_zmm_t4 = zmm_kv_vector::mask_mov( + index_type index_zmm_t4 = zmm_vector::mask_mov( index_zmm4r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); - index_t index_zmm_m4 = zmm_kv_vector::mask_mov( + index_type index_zmm_m4 = zmm_vector::mask_mov( 
index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); zmm_t key_zmm_t5 = vtype::permutexvar(rev_index, key_zmm_m4); zmm_t key_zmm_t6 = vtype::permutexvar(rev_index, key_zmm_m3); zmm_t key_zmm_t7 = vtype::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t8 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t5 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m4); - index_t index_zmm_t6 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m3); - index_t index_zmm_t7 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t8 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m1); + index_type index_zmm_t5 + = zmm_vector::permutexvar(rev_index, index_zmm_m4); + index_type index_zmm_t6 + = zmm_vector::permutexvar(rev_index, index_zmm_m3); + index_type index_zmm_t7 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t8 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); @@ -604,9 +271,9 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, } template ::zmm_t> + typename index_type = zmm_vector::zmm_t> X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, - index_t *index_zmm) + index_type *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); zmm_t key_zmm8r = vtype::permutexvar(rev_index, key_zmm[8]); @@ -618,22 +285,22 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, zmm_t key_zmm14r = vtype::permutexvar(rev_index, key_zmm[14]); zmm_t key_zmm15r = vtype::permutexvar(rev_index, key_zmm[15]); - index_t index_zmm8r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[8]); - index_t index_zmm9r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[9]); - index_t index_zmm10r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[10]); - index_t index_zmm11r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[11]); - index_t index_zmm12r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[12]); - index_t index_zmm13r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[13]); - index_t index_zmm14r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[14]); - index_t index_zmm15r - = zmm_kv_vector::permutexvar(rev_index, index_zmm[15]); + index_type index_zmm8r + = zmm_vector::permutexvar(rev_index, index_zmm[8]); + index_type index_zmm9r + = zmm_vector::permutexvar(rev_index, index_zmm[9]); + index_type index_zmm10r + = zmm_vector::permutexvar(rev_index, index_zmm[10]); + index_type index_zmm11r + = zmm_vector::permutexvar(rev_index, index_zmm[11]); + index_type index_zmm12r + = zmm_vector::permutexvar(rev_index, index_zmm[12]); + index_type index_zmm13r + = zmm_vector::permutexvar(rev_index, index_zmm[13]); + index_type index_zmm14r + = zmm_vector::permutexvar(rev_index, index_zmm[14]); + index_type index_zmm15r + = zmm_vector::permutexvar(rev_index, index_zmm[15]); zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm15r); zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm14r); @@ -653,38 +320,38 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, zmm_t key_zmm_m7 = vtype::max(key_zmm[6], key_zmm9r); zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); - index_t index_zmm_t1 = zmm_kv_vector::mask_mov( + index_type index_zmm_t1 = zmm_vector::mask_mov( index_zmm15r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_t index_zmm_m1 = zmm_kv_vector::mask_mov( + index_type index_zmm_m1 = zmm_vector::mask_mov( 
index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); - index_t index_zmm_t2 = zmm_kv_vector::mask_mov( + index_type index_zmm_t2 = zmm_vector::mask_mov( index_zmm14r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_t index_zmm_m2 = zmm_kv_vector::mask_mov( + index_type index_zmm_m2 = zmm_vector::mask_mov( index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); - index_t index_zmm_t3 = zmm_kv_vector::mask_mov( + index_type index_zmm_t3 = zmm_vector::mask_mov( index_zmm13r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); - index_t index_zmm_m3 = zmm_kv_vector::mask_mov( + index_type index_zmm_m3 = zmm_vector::mask_mov( index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); - index_t index_zmm_t4 = zmm_kv_vector::mask_mov( + index_type index_zmm_t4 = zmm_vector::mask_mov( index_zmm12r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); - index_t index_zmm_m4 = zmm_kv_vector::mask_mov( + index_type index_zmm_m4 = zmm_vector::mask_mov( index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); - index_t index_zmm_t5 = zmm_kv_vector::mask_mov( + index_type index_zmm_t5 = zmm_vector::mask_mov( index_zmm11r, vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); - index_t index_zmm_m5 = zmm_kv_vector::mask_mov( + index_type index_zmm_m5 = zmm_vector::mask_mov( index_zmm[4], vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); - index_t index_zmm_t6 = zmm_kv_vector::mask_mov( + index_type index_zmm_t6 = zmm_vector::mask_mov( index_zmm10r, vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); - index_t index_zmm_m6 = zmm_kv_vector::mask_mov( + index_type index_zmm_m6 = zmm_vector::mask_mov( index_zmm[5], vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); - index_t index_zmm_t7 = zmm_kv_vector::mask_mov( + index_type index_zmm_t7 = zmm_vector::mask_mov( index_zmm9r, vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); - index_t index_zmm_m7 = zmm_kv_vector::mask_mov( + index_type index_zmm_m7 = zmm_vector::mask_mov( index_zmm[6], vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); - index_t index_zmm_t8 = zmm_kv_vector::mask_mov( + index_type index_zmm_t8 = zmm_vector::mask_mov( index_zmm8r, vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); - index_t index_zmm_m8 = zmm_kv_vector::mask_mov( + index_type index_zmm_m8 = zmm_vector::mask_mov( index_zmm[7], vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); zmm_t key_zmm_t9 = vtype::permutexvar(rev_index, key_zmm_m8); @@ -695,22 +362,22 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, zmm_t key_zmm_t14 = vtype::permutexvar(rev_index, key_zmm_m3); zmm_t key_zmm_t15 = vtype::permutexvar(rev_index, key_zmm_m2); zmm_t key_zmm_t16 = vtype::permutexvar(rev_index, key_zmm_m1); - index_t index_zmm_t9 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m8); - index_t index_zmm_t10 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m7); - index_t index_zmm_t11 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m6); - index_t index_zmm_t12 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m5); - index_t index_zmm_t13 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m4); - index_t index_zmm_t14 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m3); - index_t index_zmm_t15 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m2); - index_t index_zmm_t16 - = zmm_kv_vector::permutexvar(rev_index, index_zmm_m1); + index_type index_zmm_t9 + = zmm_vector::permutexvar(rev_index, index_zmm_m8); + index_type index_zmm_t10 + = zmm_vector::permutexvar(rev_index, index_zmm_m7); + index_type index_zmm_t11 + = 
zmm_vector::permutexvar(rev_index, index_zmm_m6); + index_type index_zmm_t12 + = zmm_vector::permutexvar(rev_index, index_zmm_m5); + index_type index_zmm_t13 + = zmm_vector::permutexvar(rev_index, index_zmm_m4); + index_type index_zmm_t14 + = zmm_vector::permutexvar(rev_index, index_zmm_m3); + index_type index_zmm_t15 + = zmm_vector::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t16 + = zmm_vector::permutexvar(rev_index, index_zmm_m1); COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); @@ -781,12 +448,11 @@ sort_8_64bit(type_t *keys, uint64_t *indexes, int32_t N) typename vtype::zmm_t key_zmm = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys); - zmm_kv_vector::zmm_t index_zmm - = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask, indexes); + zmm_vector::zmm_t index_zmm = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask, indexes); vtype::mask_storeu( keys, load_mask, sort_zmm_64bit(key_zmm, index_zmm)); - zmm_kv_vector::mask_storeu(indexes, load_mask, index_zmm); + zmm_vector::mask_storeu(indexes, load_mask, index_zmm); } template @@ -798,24 +464,24 @@ sort_16_64bit(type_t *keys, uint64_t *indexes, int32_t N) return; } using zmm_t = typename vtype::zmm_t; - using index_t = zmm_kv_vector::zmm_t; + using index_type = zmm_vector::zmm_t; typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; zmm_t key_zmm1 = vtype::loadu(keys); zmm_t key_zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys + 8); - index_t index_zmm1 = zmm_kv_vector::loadu(indexes); - index_t index_zmm2 = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask, indexes + 8); + index_type index_zmm1 = zmm_vector::loadu(indexes); + index_type index_zmm2 = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask, indexes + 8); key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); bitonic_merge_two_zmm_64bit( key_zmm1, key_zmm2, index_zmm1, index_zmm2); - zmm_kv_vector::storeu(indexes, index_zmm1); - zmm_kv_vector::mask_storeu(indexes + 8, load_mask, index_zmm2); + zmm_vector::storeu(indexes, index_zmm1); + zmm_vector::mask_storeu(indexes + 8, load_mask, index_zmm2); vtype::storeu(keys, key_zmm1); vtype::mask_storeu(keys + 8, load_mask, key_zmm2); @@ -831,15 +497,15 @@ sort_32_64bit(type_t *keys, uint64_t *indexes, int32_t N) } using zmm_t = typename vtype::zmm_t; using opmask_t = typename vtype::opmask_t; - using index_t = zmm_kv_vector::zmm_t; + using index_type = zmm_vector::zmm_t; zmm_t key_zmm[4]; - index_t index_zmm[4]; + index_type index_zmm[4]; key_zmm[0] = vtype::loadu(keys); key_zmm[1] = vtype::loadu(keys + 8); - index_zmm[0] = zmm_kv_vector::loadu(indexes); - index_zmm[1] = zmm_kv_vector::loadu(indexes + 8); + index_zmm[0] = zmm_vector::loadu(indexes); + index_zmm[1] = zmm_vector::loadu(indexes + 8); key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); @@ -851,10 +517,10 @@ sort_32_64bit(type_t *keys, uint64_t *indexes, int32_t N) key_zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 16); key_zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 24); - index_zmm[2] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask1, indexes + 16); - index_zmm[3] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask2, indexes + 24); + index_zmm[2] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 16); + index_zmm[3] = 
zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 24); key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); @@ -865,12 +531,10 @@ sort_32_64bit(type_t *keys, uint64_t *indexes, int32_t N) key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - zmm_kv_vector::storeu(indexes, index_zmm[0]); - zmm_kv_vector::storeu(indexes + 8, index_zmm[1]); - zmm_kv_vector::mask_storeu( - indexes + 16, load_mask1, index_zmm[2]); - zmm_kv_vector::mask_storeu( - indexes + 24, load_mask2, index_zmm[3]); + zmm_vector::storeu(indexes, index_zmm[0]); + zmm_vector::storeu(indexes + 8, index_zmm[1]); + zmm_vector::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); + zmm_vector::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); vtype::storeu(keys, key_zmm[0]); vtype::storeu(keys + 8, key_zmm[1]); @@ -888,19 +552,19 @@ sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) } using zmm_t = typename vtype::zmm_t; using opmask_t = typename vtype::opmask_t; - using index_t = zmm_kv_vector::zmm_t; + using index_type = zmm_vector::zmm_t; zmm_t key_zmm[8]; - index_t index_zmm[8]; + index_type index_zmm[8]; key_zmm[0] = vtype::loadu(keys); key_zmm[1] = vtype::loadu(keys + 8); key_zmm[2] = vtype::loadu(keys + 16); key_zmm[3] = vtype::loadu(keys + 24); - index_zmm[0] = zmm_kv_vector::loadu(indexes); - index_zmm[1] = zmm_kv_vector::loadu(indexes + 8); - index_zmm[2] = zmm_kv_vector::loadu(indexes + 16); - index_zmm[3] = zmm_kv_vector::loadu(indexes + 24); + index_zmm[0] = zmm_vector::loadu(indexes); + index_zmm[1] = zmm_vector::loadu(indexes + 8); + index_zmm[2] = zmm_vector::loadu(indexes + 16); + index_zmm[3] = zmm_vector::loadu(indexes + 24); key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); @@ -919,14 +583,14 @@ sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) key_zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 48); key_zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 56); - index_zmm[4] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask1, indexes + 32); - index_zmm[5] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask2, indexes + 40); - index_zmm[6] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask3, indexes + 48); - index_zmm[7] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask4, indexes + 56); + index_zmm[4] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 32); + index_zmm[5] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 40); + index_zmm[6] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask3, indexes + 48); + index_zmm[7] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask4, indexes + 56); key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); @@ -944,18 +608,14 @@ sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - zmm_kv_vector::storeu(indexes, index_zmm[0]); - zmm_kv_vector::storeu(indexes + 8, index_zmm[1]); - zmm_kv_vector::storeu(indexes + 16, index_zmm[2]); - zmm_kv_vector::storeu(indexes + 24, index_zmm[3]); - zmm_kv_vector::mask_storeu( - indexes + 32, load_mask1, 
index_zmm[4]); - zmm_kv_vector::mask_storeu( - indexes + 40, load_mask2, index_zmm[5]); - zmm_kv_vector::mask_storeu( - indexes + 48, load_mask3, index_zmm[6]); - zmm_kv_vector::mask_storeu( - indexes + 56, load_mask4, index_zmm[7]); + zmm_vector::storeu(indexes, index_zmm[0]); + zmm_vector::storeu(indexes + 8, index_zmm[1]); + zmm_vector::storeu(indexes + 16, index_zmm[2]); + zmm_vector::storeu(indexes + 24, index_zmm[3]); + zmm_vector::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); + zmm_vector::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); + zmm_vector::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); + zmm_vector::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); vtype::storeu(keys, key_zmm[0]); vtype::storeu(keys + 8, key_zmm[1]); @@ -976,10 +636,10 @@ sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) return; } using zmm_t = typename vtype::zmm_t; - using index_t = zmm_kv_vector::zmm_t; + using index_type = zmm_vector::zmm_t; using opmask_t = typename vtype::opmask_t; zmm_t key_zmm[16]; - index_t index_zmm[16]; + index_type index_zmm[16]; key_zmm[0] = vtype::loadu(keys); key_zmm[1] = vtype::loadu(keys + 8); @@ -990,14 +650,14 @@ sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) key_zmm[6] = vtype::loadu(keys + 48); key_zmm[7] = vtype::loadu(keys + 56); - index_zmm[0] = zmm_kv_vector::loadu(indexes); - index_zmm[1] = zmm_kv_vector::loadu(indexes + 8); - index_zmm[2] = zmm_kv_vector::loadu(indexes + 16); - index_zmm[3] = zmm_kv_vector::loadu(indexes + 24); - index_zmm[4] = zmm_kv_vector::loadu(indexes + 32); - index_zmm[5] = zmm_kv_vector::loadu(indexes + 40); - index_zmm[6] = zmm_kv_vector::loadu(indexes + 48); - index_zmm[7] = zmm_kv_vector::loadu(indexes + 56); + index_zmm[0] = zmm_vector::loadu(indexes); + index_zmm[1] = zmm_vector::loadu(indexes + 8); + index_zmm[2] = zmm_vector::loadu(indexes + 16); + index_zmm[3] = zmm_vector::loadu(indexes + 24); + index_zmm[4] = zmm_vector::loadu(indexes + 32); + index_zmm[5] = zmm_vector::loadu(indexes + 40); + index_zmm[6] = zmm_vector::loadu(indexes + 48); + index_zmm[7] = zmm_vector::loadu(indexes + 56); key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); @@ -1031,22 +691,22 @@ sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) key_zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, keys + 112); key_zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, keys + 120); - index_zmm[8] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask1, indexes + 64); - index_zmm[9] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask2, indexes + 72); - index_zmm[10] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask3, indexes + 80); - index_zmm[11] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask4, indexes + 88); - index_zmm[12] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask5, indexes + 96); - index_zmm[13] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask6, indexes + 104); - index_zmm[14] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask7, indexes + 112); - index_zmm[15] = zmm_kv_vector::mask_loadu( - zmm_kv_vector::zmm_max(), load_mask8, indexes + 120); + index_zmm[8] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask1, indexes + 64); + index_zmm[9] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask2, indexes + 72); + index_zmm[10] = zmm_vector::mask_loadu( 
+ zmm_vector::zmm_max(), load_mask3, indexes + 80); + index_zmm[11] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask4, indexes + 88); + index_zmm[12] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask5, indexes + 96); + index_zmm[13] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask6, indexes + 104); + index_zmm[14] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask7, indexes + 112); + index_zmm[15] = zmm_vector::mask_loadu( + zmm_vector::zmm_max(), load_mask8, indexes + 120); key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); @@ -1079,30 +739,22 @@ sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); - zmm_kv_vector::storeu(indexes, index_zmm[0]); - zmm_kv_vector::storeu(indexes + 8, index_zmm[1]); - zmm_kv_vector::storeu(indexes + 16, index_zmm[2]); - zmm_kv_vector::storeu(indexes + 24, index_zmm[3]); - zmm_kv_vector::storeu(indexes + 32, index_zmm[4]); - zmm_kv_vector::storeu(indexes + 40, index_zmm[5]); - zmm_kv_vector::storeu(indexes + 48, index_zmm[6]); - zmm_kv_vector::storeu(indexes + 56, index_zmm[7]); - zmm_kv_vector::mask_storeu( - indexes + 64, load_mask1, index_zmm[8]); - zmm_kv_vector::mask_storeu( - indexes + 72, load_mask2, index_zmm[9]); - zmm_kv_vector::mask_storeu( - indexes + 80, load_mask3, index_zmm[10]); - zmm_kv_vector::mask_storeu( - indexes + 88, load_mask4, index_zmm[11]); - zmm_kv_vector::mask_storeu( - indexes + 96, load_mask5, index_zmm[12]); - zmm_kv_vector::mask_storeu( - indexes + 104, load_mask6, index_zmm[13]); - zmm_kv_vector::mask_storeu( - indexes + 112, load_mask7, index_zmm[14]); - zmm_kv_vector::mask_storeu( - indexes + 120, load_mask8, index_zmm[15]); + zmm_vector::storeu(indexes, index_zmm[0]); + zmm_vector::storeu(indexes + 8, index_zmm[1]); + zmm_vector::storeu(indexes + 16, index_zmm[2]); + zmm_vector::storeu(indexes + 24, index_zmm[3]); + zmm_vector::storeu(indexes + 32, index_zmm[4]); + zmm_vector::storeu(indexes + 40, index_zmm[5]); + zmm_vector::storeu(indexes + 48, index_zmm[6]); + zmm_vector::storeu(indexes + 56, index_zmm[7]); + zmm_vector::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); + zmm_vector::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); + zmm_vector::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); + zmm_vector::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); + zmm_vector::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); + zmm_vector::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); + zmm_vector::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); + zmm_vector::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); vtype::storeu(keys, key_zmm[0]); vtype::storeu(keys + 8, key_zmm[1]); @@ -1131,7 +783,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *keys, // median of 8 int64_t size = (right - left) / 8; using zmm_t = typename vtype::zmm_t; - using index_t = zmm_kv_vector::zmm_t; + using index_type = zmm_vector::zmm_t; __m512i rand_index = _mm512_set_epi64(left + size, left + 2 * size, left + 3 * size, @@ -1142,9 +794,9 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *keys, left + 8 * size); zmm_t key_vec = vtype::template i64gather(rand_index, keys); - index_t index_vec; + index_type index_vec; zmm_t sort; - index_vec = zmm_kv_vector::template i64gather( + 
index_vec = zmm_vector::template i64gather( rand_index, indexes); sort = sort_zmm_64bit(key_vec, index_vec); // pivot will never be a nan, since there are no nan's! @@ -1153,7 +805,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *keys, } template -inline void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) +void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) { int64_t i = idx; while (true) { @@ -1168,7 +820,7 @@ inline void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) } } template -inline void heap_sort(type_t *keys, uint64_t *indexes, int64_t size) +void heap_sort(type_t *keys, uint64_t *indexes, int64_t size) { for (int64_t i = size / 2 - 1; i >= 0; i--) { heapify(keys, indexes, i, size); @@ -1180,17 +832,23 @@ inline void heap_sort(type_t *keys, uint64_t *indexes, int64_t size) } } +template +struct sortkv_t { + T key; + uint64_t value; +}; template -inline void qsort_64bit_(type_t *keys, - uint64_t *indexes, - int64_t left, - int64_t right, - int64_t max_iters) +void qsort_64bit_(type_t *keys, + uint64_t *indexes, + int64_t left, + int64_t right, + int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress */ if (max_iters <= 0) { + //std::sort(keys+left,keys+right+1); heap_sort(keys + left, indexes + left, right - left + 1); return; } @@ -1209,68 +867,43 @@ inline void qsort_64bit_(type_t *keys, type_t biggest = vtype::type_min(); int64_t pivot_index = partition_avx512( keys, indexes, left, right + 1, pivot, &smallest, &biggest); - if (pivot != smallest) + if (pivot != smallest) { qsort_64bit_( keys, indexes, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) - qsort_64bit_(keys, indexes, pivot_index, right, max_iters - 1); -} - -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf_kv(double *arr, - int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask8 loadmask = 0xFF; - while (arrsize > 0) { - if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } - __m512d in_zmm = _mm512_maskz_loadu_pd(loadmask, arr); - __mmask8 nanmask = _mm512_cmp_pd_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm512_mask_storeu_pd(arr, nanmask, ZMM_MAX_DOUBLE); - arr += 8; - arrsize -= 8; } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan_kv(double *arr, int64_t arrsize, int64_t nan_count) -{ - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = std::nan("1"); - nan_count -= 1; + if (pivot != biggest) { + qsort_64bit_(keys, indexes, pivot_index, right, max_iters - 1); } } template <> -inline void -avx512_qsort_kv(int64_t *keys, uint64_t *indexes, int64_t arrsize) +void avx512_qsort_kv(int64_t *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { - qsort_64bit_, int64_t>( + qsort_64bit_, int64_t>( keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } template <> -inline void -avx512_qsort_kv(uint64_t *keys, uint64_t *indexes, int64_t arrsize) +void avx512_qsort_kv(uint64_t *keys, + uint64_t *indexes, + int64_t arrsize) { if (arrsize > 1) { - qsort_64bit_, uint64_t>( + qsort_64bit_, uint64_t>( keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } template <> -inline void -avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize) +void avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf(keys, arrsize); - qsort_64bit_, double>( + qsort_64bit_, double>( keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - 
replace_inf_with_nan_kv(keys, arrsize, nan_count); + replace_inf_with_nan(keys, arrsize, nan_count); } } #endif // AVX512_QSORT_64BIT_KV diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 30e43fb5..7ba9b52f 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -7,7 +7,7 @@ #ifndef AVX512_QSORT_64BIT #define AVX512_QSORT_64BIT -#include "avx512-common-qsort.h" +#include "avx512-64bit-common.h" /* * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic @@ -15,315 +15,6 @@ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) */ // ZMM 7, 6, 5, 4, 3, 2, 1, 0 -#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 -#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 -#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 -#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 - -template <> -struct zmm_vector { - using type_t = int64_t; - using zmm_t = __m512i; - using ymm_t = __m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT64; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT64; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi64(type_max()); - } // TODO: this should broadcast bits as is? - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); - } - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi64(index, base, scale); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epi64(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epi64(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epi64(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epi64(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi64(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; -template <> -struct zmm_vector { - using type_t = uint64_t; - using zmm_t = __m512i; - using ymm_t = __m512i; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_UINT64; - } - static type_t type_min() - { - return 0; - } - static zmm_t zmm_max() - { - return _mm512_set1_epi64(type_max()); - } - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - 
type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); - } - - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_epi64(index, base, scale); - } - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_si512(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_epu64(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_epi64(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_epi64(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_epi64(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_epi64(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_epu64(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_epi64(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_epu64(v); - } - static type_t reducemin(zmm_t v) - { - return _mm512_reduce_min_epu64(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_epi64(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - __m512d temp = _mm512_castsi512_pd(zmm); - return _mm512_castpd_si512( - _mm512_shuffle_pd(temp, temp, (_MM_PERM_ENUM)mask)); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_si512(mem, x); - } -}; -template <> -struct zmm_vector { - using type_t = double; - using zmm_t = __m512d; - using ymm_t = __m512d; - using opmask_t = __mmask8; - static const uint8_t numlanes = 8; - - static type_t type_max() - { - return X86_SIMD_SORT_INFINITY; - } - static type_t type_min() - { - return -X86_SIMD_SORT_INFINITY; - } - static zmm_t zmm_max() - { - return _mm512_set1_pd(type_max()); - } - - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) - { - return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static opmask_t knot_opmask(opmask_t x) - { - return _knot_mask8(x); - } - static opmask_t ge(zmm_t x, zmm_t y) - { - return _mm512_cmp_pd_mask(x, y, _CMP_GE_OQ); - } - template - static zmm_t i64gather(__m512i index, void const *base) - { - return _mm512_i64gather_pd(index, base, scale); - } - static zmm_t loadu(void const *mem) - { - return _mm512_loadu_pd(mem); - } - static zmm_t max(zmm_t x, zmm_t y) - { - return _mm512_max_pd(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_compressstoreu_pd(mem, mask, x); - } - static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) - { - return _mm512_mask_loadu_pd(x, mask, mem); - } - static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) - { - return _mm512_mask_mov_pd(x, mask, y); - } - static void mask_storeu(void *mem, opmask_t mask, zmm_t x) - { - return _mm512_mask_storeu_pd(mem, mask, x); - } - static zmm_t min(zmm_t x, zmm_t y) - { - return _mm512_min_pd(x, y); - } - static zmm_t permutexvar(__m512i idx, zmm_t zmm) - { - return _mm512_permutexvar_pd(idx, zmm); - } - static type_t reducemax(zmm_t v) - { - return _mm512_reduce_max_pd(v); - } - static type_t reducemin(zmm_t v) - { - return 
_mm512_reduce_min_pd(v); - } - static zmm_t set1(type_t v) - { - return _mm512_set1_pd(v); - } - template - static zmm_t shuffle(zmm_t zmm) - { - return _mm512_shuffle_pd(zmm, zmm, (_MM_PERM_ENUM)mask); - } - static void storeu(void *mem, zmm_t x) - { - return _mm512_storeu_pd(mem, x); - } -}; /* * Assumes zmm is random and performs a full sorting network defined in @@ -765,31 +456,6 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) qsort_64bit_(arr, pivot_index, right, max_iters - 1); } -X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize) -{ - int64_t nan_count = 0; - __mmask8 loadmask = 0xFF; - while (arrsize > 0) { - if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } - __m512d in_zmm = _mm512_maskz_loadu_pd(loadmask, arr); - __mmask8 nanmask = _mm512_cmp_pd_mask(in_zmm, in_zmm, _CMP_NEQ_UQ); - nan_count += _mm_popcnt_u32((int32_t)nanmask); - _mm512_mask_storeu_pd(arr, nanmask, ZMM_MAX_DOUBLE); - arr += 8; - arrsize -= 8; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) -{ - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = std::nan("1"); - nan_count -= 1; - } -} - template <> void avx512_qsort(int64_t *arr, int64_t arrsize) { diff --git a/src/avx512-common-keyvaluesort.h b/src/avx512-common-keyvaluesort.h index f4c642b3..114203e0 100644 --- a/src/avx512-common-keyvaluesort.h +++ b/src/avx512-common-keyvaluesort.h @@ -33,66 +33,17 @@ * */ -#include -#include -#include -#include -#include - -#define X86_SIMD_SORT_INFINITY std::numeric_limits::infinity() -#define X86_SIMD_SORT_INFINITYF std::numeric_limits::infinity() -#define X86_SIMD_SORT_INFINITYH 0x7c00 -#define X86_SIMD_SORT_NEGINFINITYH 0xfc00 -#define X86_SIMD_SORT_MAX_UINT16 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT16 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT16 std::numeric_limits::min() -#define X86_SIMD_SORT_MAX_UINT32 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT32 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT32 std::numeric_limits::min() -#define X86_SIMD_SORT_MAX_UINT64 std::numeric_limits::max() -#define X86_SIMD_SORT_MAX_INT64 std::numeric_limits::max() -#define X86_SIMD_SORT_MIN_INT64 std::numeric_limits::min() -#define ZMM_MAX_DOUBLE _mm512_set1_pd(X86_SIMD_SORT_INFINITY) -#define ZMM_MAX_UINT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_UINT64) -#define ZMM_MAX_INT64 _mm512_set1_epi64(X86_SIMD_SORT_MAX_INT64) -#define ZMM_MAX_FLOAT _mm512_set1_ps(X86_SIMD_SORT_INFINITYF) -#define ZMM_MAX_UINT _mm512_set1_epi32(X86_SIMD_SORT_MAX_UINT32) -#define ZMM_MAX_INT _mm512_set1_epi32(X86_SIMD_SORT_MAX_INT32) -#define YMM_MAX_HALF _mm256_set1_epi16(X86_SIMD_SORT_INFINITYH) -#define ZMM_MAX_UINT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_UINT16) -#define ZMM_MAX_INT16 _mm512_set1_epi16(X86_SIMD_SORT_MAX_INT16) -#define SHUFFLE_MASK(a, b, c, d) (a << 6) | (b << 4) | (c << 2) | d - -#ifdef _MSC_VER -#define X86_SIMD_SORT_INLINE static inline -#define X86_SIMD_SORT_FINLINE static __forceinline -#elif defined(__CYGWIN__) -/* - * Force inline in cygwin to work around a compiler bug. 
See - * https://github.com/numpy/numpy/pull/22315#issuecomment-1267757584 - */ -#define X86_SIMD_SORT_INLINE static __attribute__((always_inline)) -#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) -#elif defined(__GNUC__) -#define X86_SIMD_SORT_INLINE static inline -#define X86_SIMD_SORT_FINLINE static __attribute__((always_inline)) -#else -#define X86_SIMD_SORT_INLINE static -#define X86_SIMD_SORT_FINLINE static -#endif - -template -struct zmm_kv_vector; +#include "avx512-64bit-common.h" template -inline void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); +void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); using index_t = __m512i; -//using index_type = zmm_kv_vector; +//using index_type = zmm_vector; template > + typename index_type = zmm_vector> static void COEX(mm_t &key1, mm_t &key2, index_t &index1, index_t &index2) { //COEX(key1,key2); @@ -112,7 +63,7 @@ static void COEX(mm_t &key1, mm_t &key2, index_t &index1, index_t &index2) template > + typename index_type = zmm_vector> static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2, index_t &indexes1, @@ -131,7 +82,7 @@ static inline zmm_t cmp_merge(zmm_t in1, template > + typename index_type = zmm_vector> static inline int32_t partition_vec(type_t *keys, uint64_t *indexes, int64_t left, @@ -163,7 +114,7 @@ static inline int32_t partition_vec(type_t *keys, */ template > + typename index_type = zmm_vector> static inline int64_t partition_avx512(type_t *keys, uint64_t *indexes, int64_t left, diff --git a/tests/meson.build b/tests/meson.build index 7d51ba26..40cd4685 100644 --- a/tests/meson.build +++ b/tests/meson.build @@ -1,19 +1,15 @@ libtests = [] -if cc.has_argument('-march=icelake-client') - libtests += static_library( - 'tests_', - files( - 'test_all.cpp', - ), - dependencies : gtest_dep, - include_directories : [ - src, - utils, - ], - cpp_args : [ - '-O3', - '-march=icelake-client', - ], - ) -endif + if cc.has_argument('-march=icelake-client') libtests + += static_library('tests_', files('test_all.cpp', ), dependencies + : gtest_dep, include_directories + : + [ + src, + utils, + ], + cpp_args + : [ + '-O3', + '-march=icelake-client', + ], ) endif diff --git a/tests/test_all.cpp b/tests/test_all.cpp index 35330fa8..ff65fff0 100644 --- a/tests/test_all.cpp +++ b/tests/test_all.cpp @@ -58,40 +58,48 @@ using Types = testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512_sort, Types); +template struct sorted_t { - uint64_t key; - uint64_t value; + K key; + K value; }; - -bool compare(sorted_t a, sorted_t b) +template +bool compare(sorted_t a, sorted_t b) { return a.key == b.key ? 
a.value < b.value : a.key < b.key; } -TEST(TestKeyValueSort, KeyValueSort) + +template +class TestKeyValueSort : public ::testing::Test { +}; + +TYPED_TEST_SUITE_P(TestKeyValueSort); + +TYPED_TEST_P(TestKeyValueSort, KeyValueSort) { std::vector keysizes; for (int64_t ii = 0; ii < 1024; ++ii) { - keysizes.push_back((uint64_t)ii); + keysizes.push_back((TypeParam)ii); } - std::vector keys; + std::vector keys; std::vector values; - std::vector sortedarr; + std::vector> sortedarr; for (size_t ii = 0; ii < keysizes.size(); ++ii) { /* Random array */ - keys = get_uniform_rand_array_key(keysizes[ii]); - //keys = get_uniform_rand_array(keysizes[ii]); + keys = get_uniform_rand_array_key(keysizes[ii]); values = get_uniform_rand_array(keysizes[ii]); for (size_t i = 0; i < keys.size(); i++) { - sorted_t tmp_s; + sorted_t tmp_s; tmp_s.key = keys[i]; tmp_s.value = values[i]; sortedarr.emplace_back(tmp_s); } /* Sort with std::sort for comparison */ - std::sort(sortedarr.begin(), sortedarr.end(), compare); - avx512_qsort_kv(keys.data(), values.data(), keys.size()); - //ASSERT_EQ(sortedarr, arr); + std::sort(sortedarr.begin(), + sortedarr.end(), + compare); + avx512_qsort_kv(keys.data(), values.data(), keys.size()); for (size_t i = 0; i < keys.size(); i++) { ASSERT_EQ(keys[i], sortedarr[i].key); ASSERT_EQ(values[i], sortedarr[i].value); @@ -101,3 +109,8 @@ TEST(TestKeyValueSort, KeyValueSort) sortedarr.clear(); } } + +REGISTER_TYPED_TEST_SUITE_P(TestKeyValueSort, KeyValueSort); + +using TypesKv = testing::Types; +INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefixKv, TestKeyValueSort, TypesKv); \ No newline at end of file diff --git a/utils/rand_array.h b/utils/rand_array.h index 42e0f99d..efa78881 100644 --- a/utils/rand_array.h +++ b/utils/rand_array.h @@ -42,20 +42,21 @@ static std::vector get_uniform_rand_array( } return arr; } - -static std::vector -get_uniform_rand_array_key(int64_t arrsize, - uint64_t max = std::numeric_limits::max(), - uint64_t min = std::numeric_limits::min()) +template +static std::vector get_uniform_rand_array_key( + int64_t arrsize, + T max = std::numeric_limits::max(), + T min = std::numeric_limits::min(), + typename std::enable_if::value>::type * = 0) { - std::vector arr; + std::vector arr; std::random_device r; std::default_random_engine e1(r()); - std::uniform_int_distribution uniform_dist(min, max); + std::uniform_int_distribution uniform_dist(min, max); for (int64_t ii = 0; ii < arrsize; ++ii) { while (true) { - uint64_t tmp = uniform_dist(e1); + T tmp = uniform_dist(e1); auto iter = std::find(arr.begin(), arr.end(), tmp); if (iter == arr.end()) { arr.emplace_back(tmp); @@ -65,3 +66,28 @@ get_uniform_rand_array_key(int64_t arrsize, } return arr; } + +template +static std::vector get_uniform_rand_array_key( + int64_t arrsize, + T max = std::numeric_limits::max(), + T min = std::numeric_limits::min(), + typename std::enable_if::value>::type * = 0) +{ + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(min, max); + std::vector arr; + for (int64_t ii = 0; ii < arrsize; ++ii) { + + while (true) { + T tmp = dis(gen); + auto iter = std::find(arr.begin(), arr.end(), tmp); + if (iter == arr.end()) { + arr.emplace_back(tmp); + break; + } + } + } + return arr; +} \ No newline at end of file From b73ba45e1e127dcd6a1d43b08162c2f8cb35f798 Mon Sep 17 00:00:00 2001 From: ruclz Date: Thu, 9 Mar 2023 15:57:42 +0800 Subject: [PATCH 14/16] Move the get_povit_64bit function and sort_zmm_64bit function to avx512-64bit-common.h, modify the bench-tgl.out to origin file. 
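Throughout the key-value merge networks in this series, keys are combined with vtype::min/vtype::max while the paired index vectors are carried along by mask_mov() under the vtype::eq(key_min, key1) predicate, so every index stays attached to its key. A minimal scalar model of that compare-exchange step, for illustration only (coex_kv_scalar is a hypothetical name, not something these patches add):

// Scalar model of the key/value compare-exchange (COEX) used by the
// bitonic merge networks: order the keys with min/max, then select the
// matching indexes with the same "key_min == key1" predicate the vector
// code passes to mask_mov().
#include <cstdint>

template <typename K>
void coex_kv_scalar(K &key1, K &key2, uint64_t &idx1, uint64_t &idx2)
{
    K key_min = key1 < key2 ? key1 : key2;
    K key_max = key1 < key2 ? key2 : key1;
    // mask_mov(idx2, eq(key_min, key1), idx1): take idx1 wherever key1 won
    uint64_t idx_min = (key_min == key1) ? idx1 : idx2;
    uint64_t idx_max = (key_min == key1) ? idx2 : idx1;
    key1 = key_min;
    key2 = key_max;
    idx1 = idx_min;
    idx2 = idx_max;
}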
--- benchmarks/bench-tgl.out | 57 +++++++++++++++---------------- src/avx512-64bit-common.h | 47 +++++++++++++++++++++++++ src/avx512-64bit-keyvaluesort.hpp | 32 +---------------- src/avx512-64bit-qsort.hpp | 48 -------------------------- 4 files changed, 76 insertions(+), 108 deletions(-) diff --git a/benchmarks/bench-tgl.out b/benchmarks/bench-tgl.out index 9ebb3dc0..1bb03936 100644 --- a/benchmarks/bench-tgl.out +++ b/benchmarks/bench-tgl.out @@ -1,29 +1,28 @@ -| -----------------+-------------+------------+-----------------+-----------+---------- | - | Array data type | typeid name | array size - | avx512_qsort | std::sort | speed up | - | -----------------+-------------+------------+-----------------+-----------+---------- | - | uniform random | uint32_t | 10000 | 115697 | 1579118 | 13.6 | - | uniform random | uint32_t | 100000 | 1786812 | 19973203 | 11.2 | - | uniform random | uint32_t | 1000000 | 22536966 | 233470422 | 10.4 | - | uniform random | int32_t | 10000 | 95591 | 1569108 | 16.4 | - | uniform random | int32_t | 100000 | 1790362 | 19785007 | 11.1 | - | uniform random | int32_t | 1000000 | 22874571 | 233358497 | 10.2 | - | uniform random | float | 10000 | 113316 | 1668407 | 14.7 | - | uniform random | float | 100000 | 1920018 | 21815024 | 11.4 | - | uniform random | float | 1000000 | 24776954 | 256867990 | 10.4 | - | uniform random | uint64_t | 10000 | 233501 | 1537649 | 6.6 | - | uniform random | uint64_t | 100000 | 3991372 | 19559859 | 4.9 | - | uniform random | uint64_t | 1000000 | 49818870 | 232687666 | 4.7 | - | uniform random | int64_t | 10000 | 228000 | 1445131 | 6.3 | - | uniform random | int64_t | 100000 | 3892092 | 18917322 | 4.9 | - | uniform random | int64_t | 1000000 | 48957088 | 235100259 | 4.8 | - | uniform random | double | 10000 | 180307 | 1702801 | 9.4 | - | uniform random | double | 100000 | 3596886 | 21849587 | 6.1 | - | uniform random | double | 1000000 | 47724381 | 258014177 | 5.4 | - | uniform random | uint16_t | 10000 | 84732 | 1548275 | 18.3 | - | uniform random | uint16_t | 100000 | 1406417 | 19632858 | 14.0 | - | uniform random | uint16_t | 1000000 | 17119960 | 214085305 | 12.5 | - | uniform random | int16_t | 10000 | 84703 | 1547726 | 18.3 | - | uniform random | int16_t | 100000 | 1442726 | 19705242 | 13.7 | - | uniform random | int16_t | 1000000 | 20210224 | 212137465 | 10.5 | - | -----------------+-------------+------------+-----------------+-----------+---------- | +|-----------------+-------------+------------+-----------------+-----------+----------| +| Array data type | typeid name | array size | avx512_qsort | std::sort | speed up | +|-----------------+-------------+------------+-----------------+-----------+----------| +| uniform random | uint32_t | 10000 | 115697 | 1579118 | 13.6 | +| uniform random | uint32_t | 100000 | 1786812 | 19973203 | 11.2 | +| uniform random | uint32_t | 1000000 | 22536966 | 233470422 | 10.4 | +| uniform random | int32_t | 10000 | 95591 | 1569108 | 16.4 | +| uniform random | int32_t | 100000 | 1790362 | 19785007 | 11.1 | +| uniform random | int32_t | 1000000 | 22874571 | 233358497 | 10.2 | +| uniform random | float | 10000 | 113316 | 1668407 | 14.7 | +| uniform random | float | 100000 | 1920018 | 21815024 | 11.4 | +| uniform random | float | 1000000 | 24776954 | 256867990 | 10.4 | +| uniform random | uint64_t | 10000 | 233501 | 1537649 | 6.6 | +| uniform random | uint64_t | 100000 | 3991372 | 19559859 | 4.9 | +| uniform random | uint64_t | 1000000 | 49818870 | 232687666 | 4.7 | +| uniform random | int64_t | 10000 | 
228000 | 1445131 | 6.3 | +| uniform random | int64_t | 100000 | 3892092 | 18917322 | 4.9 | +| uniform random | int64_t | 1000000 | 48957088 | 235100259 | 4.8 | +| uniform random | double | 10000 | 180307 | 1702801 | 9.4 | +| uniform random | double | 100000 | 3596886 | 21849587 | 6.1 | +| uniform random | double | 1000000 | 47724381 | 258014177 | 5.4 | +| uniform random | uint16_t | 10000 | 84732 | 1548275 | 18.3 | +| uniform random | uint16_t | 100000 | 1406417 | 19632858 | 14.0 | +| uniform random | uint16_t | 1000000 | 17119960 | 214085305 | 12.5 | +| uniform random | int16_t | 10000 | 84703 | 1547726 | 18.3 | +| uniform random | int16_t | 100000 | 1442726 | 19705242 | 13.7 | +| uniform random | int16_t | 1000000 | 20210224 | 212137465 | 10.5 | +|-----------------+-------------+------------+-----------------+-----------+----------| \ No newline at end of file diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 7c5ddc5a..487062e4 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -347,5 +347,52 @@ replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) nan_count -= 1; } } +/* + * Assumes zmm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + zmm = cmp_merge( + zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), zmm), + 0xCC); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + zmm = cmp_merge(zmm, vtype::permutexvar(rev_index, zmm), 0xF0); + zmm = cmp_merge( + zmm, + vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), + 0xCC); + zmm = cmp_merge( + zmm, vtype::template shuffle(zmm), 0xAA); + return zmm; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, + const int64_t left, + const int64_t right) +{ + // median of 8 + int64_t size = (right - left) / 8; + using zmm_t = typename vtype::zmm_t; + __m512i rand_index = _mm512_set_epi64(left + size, + left + 2 * size, + left + 3 * size, + left + 4 * size, + left + 5 * size, + left + 6 * size, + left + 7 * size, + left + 8 * size); + zmm_t rand_vec = vtype::template i64gather(rand_index, arr); + // pivot will never be a nan, since there are no nan's! 
+ zmm_t sort = sort_zmm_64bit(rand_vec); + return ((type_t *)&sort)[4]; +} #endif \ No newline at end of file diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 39781884..9cdfbcd9 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -774,36 +774,6 @@ sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) vtype::mask_storeu(keys + 120, load_mask8, key_zmm[15]); } -template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *keys, - uint64_t *indexes, - const int64_t left, - const int64_t right) -{ - // median of 8 - int64_t size = (right - left) / 8; - using zmm_t = typename vtype::zmm_t; - using index_type = zmm_vector::zmm_t; - __m512i rand_index = _mm512_set_epi64(left + size, - left + 2 * size, - left + 3 * size, - left + 4 * size, - left + 5 * size, - left + 6 * size, - left + 7 * size, - left + 8 * size); - zmm_t key_vec = vtype::template i64gather(rand_index, keys); - - index_type index_vec; - zmm_t sort; - index_vec = zmm_vector::template i64gather( - rand_index, indexes); - sort = sort_zmm_64bit(key_vec, index_vec); - // pivot will never be a nan, since there are no nan's! - - return ((type_t *)&sort)[4]; -} - template void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) { @@ -862,7 +832,7 @@ void qsort_64bit_(type_t *keys, return; } - type_t pivot = get_pivot_64bit(keys, indexes, left, right); + type_t pivot = get_pivot_64bit(keys, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); int64_t pivot_index = partition_avx512( diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 7ba9b52f..62000549 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -16,32 +16,6 @@ */ // ZMM 7, 6, 5, 4, 3, 2, 1, 0 -/* - * Assumes zmm is random and performs a full sorting network defined in - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg - */ -template -X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - zmm = cmp_merge( - zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), zmm), - 0xCC); - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - zmm = cmp_merge(zmm, vtype::permutexvar(rev_index, zmm), 0xF0); - zmm = cmp_merge( - zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), - 0xCC); - zmm = cmp_merge( - zmm, vtype::template shuffle(zmm), 0xAA); - return zmm; -} - // Assumes zmm is bitonic and performs a recursive half cleaner template X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) @@ -404,28 +378,6 @@ X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N) vtype::mask_storeu(arr + 120, load_mask8, zmm[15]); } -template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, - const int64_t left, - const int64_t right) -{ - // median of 8 - int64_t size = (right - left) / 8; - using zmm_t = typename vtype::zmm_t; - __m512i rand_index = _mm512_set_epi64(left + size, - left + 2 * size, - left + 3 * size, - left + 4 * size, - left + 5 * size, - left + 6 * size, - left + 7 * size, - left + 8 * size); - zmm_t rand_vec = vtype::template i64gather(rand_index, arr); - // pivot will never be a nan, since there are no nan's! 
- zmm_t sort = sort_zmm_64bit(rand_vec); - return ((type_t *)&sort)[4]; -} - template static void qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) From c083424653c54874066841158d3f67000f21b3ec Mon Sep 17 00:00:00 2001 From: ruclz Date: Fri, 10 Mar 2023 09:41:48 +0800 Subject: [PATCH 15/16] Fix the authors name. --- benchmarks/bench.hpp | 1 + src/avx512-64bit-common.h | 7 +++++++ src/avx512-64bit-keyvaluesort.hpp | 3 ++- src/avx512-common-keyvaluesort.h | 6 ++---- utils/rand_array.h | 1 - 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/benchmarks/bench.hpp b/benchmarks/bench.hpp index 073c7a39..d54f61ac 100644 --- a/benchmarks/bench.hpp +++ b/benchmarks/bench.hpp @@ -79,6 +79,7 @@ std::tuple bench_sort(const std::vector arr, / lastfew; return std::make_tuple(avx_sort, std_sort); } + template std::tuple bench_sort_kv(const std::vector keys, diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 487062e4..5c3aeb5e 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -1,3 +1,10 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Liu Zhuan + * Tang Xi + * ****************************************************************/ + #ifndef AVX512_64BIT_COMMOM #define AVX512_64BIT_COMMOM #include "avx512-common-qsort.h" diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 9cdfbcd9..8140be97 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -1,7 +1,8 @@ /******************************************************************* * Copyright (C) 2022 Intel Corporation * SPDX-License-Identifier: BSD-3-Clause - * Authors: Raghuveer Devulapalli + * Authors: Liu Zhuan + * Tang Xi * ****************************************************************/ #ifndef AVX512_QSORT_64BIT_KV diff --git a/src/avx512-common-keyvaluesort.h b/src/avx512-common-keyvaluesort.h index 114203e0..f2821072 100644 --- a/src/avx512-common-keyvaluesort.h +++ b/src/avx512-common-keyvaluesort.h @@ -2,8 +2,8 @@ * Copyright (C) 2022 Intel Corporation * Copyright (C) 2021 Serge Sans Paille * SPDX-License-Identifier: BSD-3-Clause - * Authors: Raghuveer Devulapalli - * Serge Sans Paille + * Authors: Liu Zhuan + * Tang Xi * ****************************************************************/ #ifndef AVX512_QSORT_COMMON_KV @@ -39,14 +39,12 @@ template void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize); using index_t = __m512i; -//using index_type = zmm_vector; template > static void COEX(mm_t &key1, mm_t &key2, index_t &index1, index_t &index2) { - //COEX(key1,key2); mm_t key_t1 = vtype::min(key1, key2); mm_t key_t2 = vtype::max(key1, key2); diff --git a/utils/rand_array.h b/utils/rand_array.h index efa78881..804226f7 100644 --- a/utils/rand_array.h +++ b/utils/rand_array.h @@ -36,7 +36,6 @@ static std::vector get_uniform_rand_array( std::mt19937 gen(rd()); std::uniform_real_distribution dis(min, max); std::vector arr; - //std::cout< Date: Mon, 13 Mar 2023 09:38:41 +0800 Subject: [PATCH 16/16] Change the author of avx512-64bit-common.h file to Rahu. 
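get_pivot_64bit(), which the patch above moves into avx512-64bit-common.h, picks a pivot as the median of eight samples taken at evenly spaced offsets in [left, right]: it gathers them into one ZMM register, runs the 8-lane sorting network, and reads back lane 4. A rough scalar equivalent, written only to illustrate the idea (the helper name is mine, not part of the patch):

// Scalar sketch of median-of-8 pivot selection: sample eight evenly
// spaced elements, sort them, and return the fifth (what the vector
// code reads back as ((type_t *)&sort)[4]).
#include <algorithm>
#include <cstdint>

template <typename T>
T pivot_median_of_8(const T *arr, int64_t left, int64_t right)
{
    int64_t size = (right - left) / 8;
    T samples[8];
    for (int i = 0; i < 8; ++i) {
        // same sample points as the _mm512_set_epi64 gather index
        samples[i] = arr[left + (i + 1) * size];
    }
    std::sort(samples, samples + 8);
    return samples[4];
}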
--- src/avx512-64bit-common.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 5c3aeb5e..32a4731e 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -1,8 +1,7 @@ /******************************************************************* * Copyright (C) 2022 Intel Corporation * SPDX-License-Identifier: BSD-3-Clause - * Authors: Liu Zhuan - * Tang Xi + * Authors: Raghuveer Devulapalli * ****************************************************************/ #ifndef AVX512_64BIT_COMMOM
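The net effect of the series is a public entry point avx512_qsort_kv<T>(keys, indexes, arrsize) for T in {int64_t, uint64_t, double}: it sorts the key array in place and applies the same permutation to the parallel uint64_t value array. A minimal caller might look like this (illustrative only; it assumes src/avx512-64bit-keyvaluesort.hpp from this series is on the include path and that the target supports AVX-512):

// Example driver for the key-value sort introduced by this series.
#include <cstdint>
#include <cstdio>
#include <vector>
#include "avx512-64bit-keyvaluesort.hpp"

int main()
{
    std::vector<uint64_t> keys = {42, 7, 19, 8, 3};
    std::vector<uint64_t> vals = {0, 1, 2, 3, 4}; // original positions

    // Sorts keys in place and moves vals with them.
    avx512_qsort_kv<uint64_t>(keys.data(), vals.data(), (int64_t)keys.size());

    for (size_t i = 0; i < keys.size(); ++i)
        std::printf("key=%llu value=%llu\n",
                    (unsigned long long)keys[i],
                    (unsigned long long)vals[i]);
    return 0;
}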