diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 7fc8acf3..75ae7fb1 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -407,4 +407,4 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, return ((type_t *)&sort)[4]; } -#endif \ No newline at end of file +#endif diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 8ed66e14..4c75c481 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -8,95 +8,90 @@ #ifndef AVX512_QSORT_64BIT_KV #define AVX512_QSORT_64BIT_KV -#include "avx512-common-keyvaluesort.h" +#include "avx512-64bit-common.h" -template ::zmm_t> +template X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::template shuffle(key_zmm), + vtype1::template shuffle(key_zmm), index_zmm, - zmm_vector::template shuffle( - index_zmm), + vtype2::template shuffle(index_zmm), 0xAA); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), - index_zmm), + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), index_zmm), 0xCC); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::template shuffle(key_zmm), + vtype1::template shuffle(key_zmm), index_zmm, - zmm_vector::template shuffle( - index_zmm), + vtype2::template shuffle(index_zmm), 0xAA); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::permutexvar(rev_index, key_zmm), + vtype1::permutexvar(rev_index, key_zmm), index_zmm, - zmm_vector::permutexvar(rev_index, index_zmm), + vtype2::permutexvar(rev_index, index_zmm), 0xF0); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), - index_zmm), + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), 0xCC); - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::template shuffle(key_zmm), + vtype1::template shuffle(key_zmm), index_zmm, - zmm_vector::template shuffle( - index_zmm), + vtype2::template shuffle(index_zmm), 0xAA); return key_zmm; } // Assumes zmm is bitonic and performs a recursive half cleaner -template ::zmm_t> -X86_SIMD_SORT_INLINE zmm_t -bitonic_merge_zmm_64bit(zmm_t key_zmm, zmm_vector::zmm_t &index_zmm) +template +X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm, + index_type &index_zmm) { // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), - index_zmm), + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), index_zmm), 0xF0); // 2) half_cleaner[4] - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), index_zmm, - zmm_vector::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), - index_zmm), + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), 0xCC); 
// 3) half_cleaner[1] - key_zmm = cmp_merge( + key_zmm = cmp_merge( key_zmm, - vtype::template shuffle(key_zmm), + vtype1::template shuffle(key_zmm), index_zmm, - zmm_vector::template shuffle( - index_zmm), + vtype2::template shuffle(index_zmm), 0xAA); return key_zmm; } // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template ::zmm_t> +template X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, zmm_t &key_zmm2, index_type &index_zmm1, @@ -104,162 +99,161 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - key_zmm2 = vtype::permutexvar(rev_index, key_zmm2); - index_zmm2 = zmm_vector::permutexvar(rev_index, index_zmm2); + key_zmm2 = vtype1::permutexvar(rev_index, key_zmm2); + index_zmm2 = vtype2::permutexvar(rev_index, index_zmm2); - zmm_t key_zmm3 = vtype::min(key_zmm1, key_zmm2); - zmm_t key_zmm4 = vtype::max(key_zmm1, key_zmm2); + zmm_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); + zmm_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); - index_type index_zmm3 = zmm_vector::mask_mov( - index_zmm2, vtype::eq(key_zmm3, key_zmm1), index_zmm1); - index_type index_zmm4 = zmm_vector::mask_mov( - index_zmm1, vtype::eq(key_zmm3, key_zmm1), index_zmm2); + index_type index_zmm3 = vtype2::mask_mov( + index_zmm2, vtype1::eq(key_zmm3, key_zmm1), index_zmm1); + index_type index_zmm4 = vtype2::mask_mov( + index_zmm1, vtype1::eq(key_zmm3, key_zmm1), index_zmm2); // 2) Recursive half cleaner for each - key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); + key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); index_zmm1 = index_zmm3; index_zmm2 = index_zmm4; } // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive // half cleaner -template ::zmm_t> +template X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, index_type *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); // 1) First step of a merging network - zmm_t key_zmm2r = vtype::permutexvar(rev_index, key_zmm[2]); - zmm_t key_zmm3r = vtype::permutexvar(rev_index, key_zmm[3]); - index_type index_zmm2r - = zmm_vector::permutexvar(rev_index, index_zmm[2]); - index_type index_zmm3r - = zmm_vector::permutexvar(rev_index, index_zmm[3]); - - zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm3r); - zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm2r); - zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm3r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm2r); - - index_type index_zmm_t1 = zmm_vector::mask_mov( - index_zmm3r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = zmm_vector::mask_mov( - index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); - index_type index_zmm_t2 = zmm_vector::mask_mov( - index_zmm2r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = zmm_vector::mask_mov( - index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); + zmm_t key_zmm2r = vtype1::permutexvar(rev_index, key_zmm[2]); + zmm_t key_zmm3r = vtype1::permutexvar(rev_index, key_zmm[3]); + index_type index_zmm2r = vtype2::permutexvar(rev_index, index_zmm[2]); + index_type index_zmm3r = vtype2::permutexvar(rev_index, index_zmm[3]); + + zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm3r); + zmm_t key_zmm_t2 = 
vtype1::min(key_zmm[1], key_zmm2r); + zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); + zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); + + index_type index_zmm_t1 = vtype2::mask_mov( + index_zmm3r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = vtype2::mask_mov( + index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); + index_type index_zmm_t2 = vtype2::mask_mov( + index_zmm2r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = vtype2::mask_mov( + index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); // 2) Recursive half cleaner: 16 - zmm_t key_zmm_t3 = vtype::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t4 = vtype::permutexvar(rev_index, key_zmm_m1); - index_type index_zmm_t3 - = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_type index_zmm_t4 - = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - zmm_t key_zmm0 = vtype::min(key_zmm_t1, key_zmm_t2); - zmm_t key_zmm1 = vtype::max(key_zmm_t1, key_zmm_t2); - zmm_t key_zmm2 = vtype::min(key_zmm_t3, key_zmm_t4); - zmm_t key_zmm3 = vtype::max(key_zmm_t3, key_zmm_t4); - - index_type index_zmm0 = zmm_vector::mask_mov( - index_zmm_t2, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t1); - index_type index_zmm1 = zmm_vector::mask_mov( - index_zmm_t1, vtype::eq(key_zmm0, key_zmm_t1), index_zmm_t2); - index_type index_zmm2 = zmm_vector::mask_mov( - index_zmm_t4, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t3); - index_type index_zmm3 = zmm_vector::mask_mov( - index_zmm_t3, vtype::eq(key_zmm2, key_zmm_t3), index_zmm_t4); - - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t4 = vtype1::permutexvar(rev_index, key_zmm_m1); + index_type index_zmm_t3 = vtype2::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t4 = vtype2::permutexvar(rev_index, index_zmm_m1); + + zmm_t key_zmm0 = vtype1::min(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm1 = vtype1::max(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm2 = vtype1::min(key_zmm_t3, key_zmm_t4); + zmm_t key_zmm3 = vtype1::max(key_zmm_t3, key_zmm_t4); + + index_type index_zmm0 = vtype2::mask_mov( + index_zmm_t2, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t1); + index_type index_zmm1 = vtype2::mask_mov( + index_zmm_t1, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t2); + index_type index_zmm2 = vtype2::mask_mov( + index_zmm_t4, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t3); + index_type index_zmm3 = vtype2::mask_mov( + index_zmm_t3, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t4); + + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); index_zmm[0] = index_zmm0; index_zmm[1] = index_zmm1; index_zmm[2] = index_zmm2; index_zmm[3] = index_zmm3; } -template ::zmm_t> + +template X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, index_type *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t key_zmm4r = vtype::permutexvar(rev_index, key_zmm[4]); - zmm_t key_zmm5r = vtype::permutexvar(rev_index, key_zmm[5]); - zmm_t key_zmm6r = vtype::permutexvar(rev_index, key_zmm[6]); - zmm_t key_zmm7r = 
vtype::permutexvar(rev_index, key_zmm[7]); - index_type index_zmm4r - = zmm_vector::permutexvar(rev_index, index_zmm[4]); - index_type index_zmm5r - = zmm_vector::permutexvar(rev_index, index_zmm[5]); - index_type index_zmm6r - = zmm_vector::permutexvar(rev_index, index_zmm[6]); - index_type index_zmm7r - = zmm_vector::permutexvar(rev_index, index_zmm[7]); - - zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm7r); - zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm6r); - zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm5r); - zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm4r); - - zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm7r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm6r); - zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm5r); - zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm4r); - - index_type index_zmm_t1 = zmm_vector::mask_mov( - index_zmm7r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = zmm_vector::mask_mov( - index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); - index_type index_zmm_t2 = zmm_vector::mask_mov( - index_zmm6r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = zmm_vector::mask_mov( - index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); - index_type index_zmm_t3 = zmm_vector::mask_mov( - index_zmm5r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); - index_type index_zmm_m3 = zmm_vector::mask_mov( - index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); - index_type index_zmm_t4 = zmm_vector::mask_mov( - index_zmm4r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); - index_type index_zmm_m4 = zmm_vector::mask_mov( - index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); - - zmm_t key_zmm_t5 = vtype::permutexvar(rev_index, key_zmm_m4); - zmm_t key_zmm_t6 = vtype::permutexvar(rev_index, key_zmm_m3); - zmm_t key_zmm_t7 = vtype::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t8 = vtype::permutexvar(rev_index, key_zmm_m1); - index_type index_zmm_t5 - = zmm_vector::permutexvar(rev_index, index_zmm_m4); - index_type index_zmm_t6 - = zmm_vector::permutexvar(rev_index, index_zmm_m3); - index_type index_zmm_t7 - = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_type index_zmm_t8 - = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + zmm_t key_zmm4r = vtype1::permutexvar(rev_index, key_zmm[4]); + zmm_t key_zmm5r = vtype1::permutexvar(rev_index, key_zmm[5]); + zmm_t key_zmm6r = vtype1::permutexvar(rev_index, key_zmm[6]); + zmm_t key_zmm7r = 
vtype1::permutexvar(rev_index, key_zmm[7]); + index_type index_zmm4r = vtype2::permutexvar(rev_index, index_zmm[4]); + index_type index_zmm5r = vtype2::permutexvar(rev_index, index_zmm[5]); + index_type index_zmm6r = vtype2::permutexvar(rev_index, index_zmm[6]); + index_type index_zmm7r = vtype2::permutexvar(rev_index, index_zmm[7]); + + zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm7r); + zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm6r); + zmm_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm5r); + zmm_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm4r); + + zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); + zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); + zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); + zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); + + index_type index_zmm_t1 = vtype2::mask_mov( + index_zmm7r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = vtype2::mask_mov( + index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); + index_type index_zmm_t2 = vtype2::mask_mov( + index_zmm6r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = vtype2::mask_mov( + index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); + index_type index_zmm_t3 = vtype2::mask_mov( + index_zmm5r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_type index_zmm_m3 = vtype2::mask_mov( + index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); + index_type index_zmm_t4 = vtype2::mask_mov( + index_zmm4r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_type index_zmm_m4 = vtype2::mask_mov( + index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); + + zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t7 = vtype1::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t8 = vtype1::permutexvar(rev_index, key_zmm_m1); + index_type index_zmm_t5 = vtype2::permutexvar(rev_index, index_zmm_m4); + index_type index_zmm_t6 = vtype2::permutexvar(rev_index, index_zmm_m3); + index_type index_zmm_t7 = vtype2::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t8 = vtype2::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + key_zmm[0] + = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] + = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] + = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] + = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] + = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] + = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] + = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] + = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); index_zmm[0] = index_zmm_t1; index_zmm[1] = index_zmm_t2; @@ -270,159 +264,170 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, index_zmm[6] = index_zmm_t7; index_zmm[7] = index_zmm_t8; } -template ::zmm_t> + +template X86_SIMD_SORT_INLINE void 
bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, index_type *index_zmm) { const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t key_zmm8r = vtype::permutexvar(rev_index, key_zmm[8]); - zmm_t key_zmm9r = vtype::permutexvar(rev_index, key_zmm[9]); - zmm_t key_zmm10r = vtype::permutexvar(rev_index, key_zmm[10]); - zmm_t key_zmm11r = vtype::permutexvar(rev_index, key_zmm[11]); - zmm_t key_zmm12r = vtype::permutexvar(rev_index, key_zmm[12]); - zmm_t key_zmm13r = vtype::permutexvar(rev_index, key_zmm[13]); - zmm_t key_zmm14r = vtype::permutexvar(rev_index, key_zmm[14]); - zmm_t key_zmm15r = vtype::permutexvar(rev_index, key_zmm[15]); - - index_type index_zmm8r - = zmm_vector::permutexvar(rev_index, index_zmm[8]); - index_type index_zmm9r - = zmm_vector::permutexvar(rev_index, index_zmm[9]); - index_type index_zmm10r - = zmm_vector::permutexvar(rev_index, index_zmm[10]); - index_type index_zmm11r - = zmm_vector::permutexvar(rev_index, index_zmm[11]); - index_type index_zmm12r - = zmm_vector::permutexvar(rev_index, index_zmm[12]); - index_type index_zmm13r - = zmm_vector::permutexvar(rev_index, index_zmm[13]); - index_type index_zmm14r - = zmm_vector::permutexvar(rev_index, index_zmm[14]); - index_type index_zmm15r - = zmm_vector::permutexvar(rev_index, index_zmm[15]); - - zmm_t key_zmm_t1 = vtype::min(key_zmm[0], key_zmm15r); - zmm_t key_zmm_t2 = vtype::min(key_zmm[1], key_zmm14r); - zmm_t key_zmm_t3 = vtype::min(key_zmm[2], key_zmm13r); - zmm_t key_zmm_t4 = vtype::min(key_zmm[3], key_zmm12r); - zmm_t key_zmm_t5 = vtype::min(key_zmm[4], key_zmm11r); - zmm_t key_zmm_t6 = vtype::min(key_zmm[5], key_zmm10r); - zmm_t key_zmm_t7 = vtype::min(key_zmm[6], key_zmm9r); - zmm_t key_zmm_t8 = vtype::min(key_zmm[7], key_zmm8r); - - zmm_t key_zmm_m1 = vtype::max(key_zmm[0], key_zmm15r); - zmm_t key_zmm_m2 = vtype::max(key_zmm[1], key_zmm14r); - zmm_t key_zmm_m3 = vtype::max(key_zmm[2], key_zmm13r); - zmm_t key_zmm_m4 = vtype::max(key_zmm[3], key_zmm12r); - zmm_t key_zmm_m5 = vtype::max(key_zmm[4], key_zmm11r); - zmm_t key_zmm_m6 = vtype::max(key_zmm[5], key_zmm10r); - zmm_t key_zmm_m7 = vtype::max(key_zmm[6], key_zmm9r); - zmm_t key_zmm_m8 = vtype::max(key_zmm[7], key_zmm8r); - - index_type index_zmm_t1 = zmm_vector::mask_mov( - index_zmm15r, vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = zmm_vector::mask_mov( - index_zmm[0], vtype::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); - index_type index_zmm_t2 = zmm_vector::mask_mov( - index_zmm14r, vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = zmm_vector::mask_mov( - index_zmm[1], vtype::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); - index_type index_zmm_t3 = zmm_vector::mask_mov( - index_zmm13r, vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); - index_type index_zmm_m3 = zmm_vector::mask_mov( - index_zmm[2], vtype::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); - index_type index_zmm_t4 = zmm_vector::mask_mov( - index_zmm12r, vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); - index_type index_zmm_m4 = zmm_vector::mask_mov( - index_zmm[3], vtype::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); - - index_type index_zmm_t5 = zmm_vector::mask_mov( - index_zmm11r, vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); - index_type index_zmm_m5 = zmm_vector::mask_mov( - index_zmm[4], vtype::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); - index_type index_zmm_t6 = zmm_vector::mask_mov( - index_zmm10r, vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); - index_type index_zmm_m6 = zmm_vector::mask_mov( - index_zmm[5], 
vtype::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); - index_type index_zmm_t7 = zmm_vector::mask_mov( - index_zmm9r, vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); - index_type index_zmm_m7 = zmm_vector::mask_mov( - index_zmm[6], vtype::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); - index_type index_zmm_t8 = zmm_vector::mask_mov( - index_zmm8r, vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); - index_type index_zmm_m8 = zmm_vector::mask_mov( - index_zmm[7], vtype::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); - - zmm_t key_zmm_t9 = vtype::permutexvar(rev_index, key_zmm_m8); - zmm_t key_zmm_t10 = vtype::permutexvar(rev_index, key_zmm_m7); - zmm_t key_zmm_t11 = vtype::permutexvar(rev_index, key_zmm_m6); - zmm_t key_zmm_t12 = vtype::permutexvar(rev_index, key_zmm_m5); - zmm_t key_zmm_t13 = vtype::permutexvar(rev_index, key_zmm_m4); - zmm_t key_zmm_t14 = vtype::permutexvar(rev_index, key_zmm_m3); - zmm_t key_zmm_t15 = vtype::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t16 = vtype::permutexvar(rev_index, key_zmm_m1); - index_type index_zmm_t9 - = zmm_vector::permutexvar(rev_index, index_zmm_m8); - index_type index_zmm_t10 - = zmm_vector::permutexvar(rev_index, index_zmm_m7); - index_type index_zmm_t11 - = zmm_vector::permutexvar(rev_index, index_zmm_m6); - index_type index_zmm_t12 - = zmm_vector::permutexvar(rev_index, index_zmm_m5); - index_type index_zmm_t13 - = zmm_vector::permutexvar(rev_index, index_zmm_m4); - index_type index_zmm_t14 - = zmm_vector::permutexvar(rev_index, index_zmm_m3); - index_type index_zmm_t15 - = zmm_vector::permutexvar(rev_index, index_zmm_m2); - index_type index_zmm_t16 - = zmm_vector::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); - COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); - COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); - COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); - COEX(key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); - COEX(key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); - COEX(key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); - COEX(key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); - COEX(key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); - COEX(key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); - COEX(key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); - COEX(key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); + zmm_t key_zmm8r = vtype1::permutexvar(rev_index, key_zmm[8]); + zmm_t key_zmm9r = vtype1::permutexvar(rev_index, key_zmm[9]); + zmm_t key_zmm10r = vtype1::permutexvar(rev_index, key_zmm[10]); + zmm_t key_zmm11r = vtype1::permutexvar(rev_index, key_zmm[11]); + zmm_t key_zmm12r = vtype1::permutexvar(rev_index, key_zmm[12]); + zmm_t key_zmm13r = vtype1::permutexvar(rev_index, 
key_zmm[13]); + zmm_t key_zmm14r = vtype1::permutexvar(rev_index, key_zmm[14]); + zmm_t key_zmm15r = vtype1::permutexvar(rev_index, key_zmm[15]); + + index_type index_zmm8r = vtype2::permutexvar(rev_index, index_zmm[8]); + index_type index_zmm9r = vtype2::permutexvar(rev_index, index_zmm[9]); + index_type index_zmm10r = vtype2::permutexvar(rev_index, index_zmm[10]); + index_type index_zmm11r = vtype2::permutexvar(rev_index, index_zmm[11]); + index_type index_zmm12r = vtype2::permutexvar(rev_index, index_zmm[12]); + index_type index_zmm13r = vtype2::permutexvar(rev_index, index_zmm[13]); + index_type index_zmm14r = vtype2::permutexvar(rev_index, index_zmm[14]); + index_type index_zmm15r = vtype2::permutexvar(rev_index, index_zmm[15]); + + zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm15r); + zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm14r); + zmm_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm13r); + zmm_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm12r); + zmm_t key_zmm_t5 = vtype1::min(key_zmm[4], key_zmm11r); + zmm_t key_zmm_t6 = vtype1::min(key_zmm[5], key_zmm10r); + zmm_t key_zmm_t7 = vtype1::min(key_zmm[6], key_zmm9r); + zmm_t key_zmm_t8 = vtype1::min(key_zmm[7], key_zmm8r); + + zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm15r); + zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm14r); + zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm13r); + zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm12r); + zmm_t key_zmm_m5 = vtype1::max(key_zmm[4], key_zmm11r); + zmm_t key_zmm_m6 = vtype1::max(key_zmm[5], key_zmm10r); + zmm_t key_zmm_m7 = vtype1::max(key_zmm[6], key_zmm9r); + zmm_t key_zmm_m8 = vtype1::max(key_zmm[7], key_zmm8r); + + index_type index_zmm_t1 = vtype2::mask_mov( + index_zmm15r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = vtype2::mask_mov( + index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); + index_type index_zmm_t2 = vtype2::mask_mov( + index_zmm14r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = vtype2::mask_mov( + index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); + index_type index_zmm_t3 = vtype2::mask_mov( + index_zmm13r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_type index_zmm_m3 = vtype2::mask_mov( + index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); + index_type index_zmm_t4 = vtype2::mask_mov( + index_zmm12r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_type index_zmm_m4 = vtype2::mask_mov( + index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); + + index_type index_zmm_t5 = vtype2::mask_mov( + index_zmm11r, vtype1::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); + index_type index_zmm_m5 = vtype2::mask_mov( + index_zmm[4], vtype1::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); + index_type index_zmm_t6 = vtype2::mask_mov( + index_zmm10r, vtype1::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); + index_type index_zmm_m6 = vtype2::mask_mov( + index_zmm[5], vtype1::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); + index_type index_zmm_t7 = vtype2::mask_mov( + index_zmm9r, vtype1::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); + index_type index_zmm_m7 = vtype2::mask_mov( + index_zmm[6], vtype1::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); + index_type index_zmm_t8 = vtype2::mask_mov( + index_zmm8r, vtype1::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); + index_type index_zmm_m8 = vtype2::mask_mov( + index_zmm[7], vtype1::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); + + zmm_t key_zmm_t9 = vtype1::permutexvar(rev_index, key_zmm_m8); + zmm_t 
key_zmm_t10 = vtype1::permutexvar(rev_index, key_zmm_m7); + zmm_t key_zmm_t11 = vtype1::permutexvar(rev_index, key_zmm_m6); + zmm_t key_zmm_t12 = vtype1::permutexvar(rev_index, key_zmm_m5); + zmm_t key_zmm_t13 = vtype1::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t14 = vtype1::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t15 = vtype1::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t16 = vtype1::permutexvar(rev_index, key_zmm_m1); + index_type index_zmm_t9 = vtype2::permutexvar(rev_index, index_zmm_m8); + index_type index_zmm_t10 = vtype2::permutexvar(rev_index, index_zmm_m7); + index_type index_zmm_t11 = vtype2::permutexvar(rev_index, index_zmm_m6); + index_type index_zmm_t12 = vtype2::permutexvar(rev_index, index_zmm_m5); + index_type index_zmm_t13 = vtype2::permutexvar(rev_index, index_zmm_m4); + index_type index_zmm_t14 = vtype2::permutexvar(rev_index, index_zmm_m3); + index_type index_zmm_t15 = vtype2::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t16 = vtype2::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); + COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); + COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); + COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); + COEX( + key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); + COEX( + key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); + COEX( + key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); + COEX( + key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); + COEX( + key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); + COEX( + key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); + COEX( + key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); + COEX( + key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); + COEX( + key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); // - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); - key_zmm[4] = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); - key_zmm[5] = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); - key_zmm[6] = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); - key_zmm[7] = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); - key_zmm[8] = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, index_zmm_t10); - key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, index_zmm_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, index_zmm_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, index_zmm_t13); - key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, index_zmm_t14); - key_zmm[14] = 
bitonic_merge_zmm_64bit(key_zmm_t15, index_zmm_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, index_zmm_t16); + key_zmm[0] + = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] + = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] + = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] + = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] + = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] + = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] + = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] + = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + key_zmm[8] + = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, + index_zmm_t10); + key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, + index_zmm_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, + index_zmm_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, + index_zmm_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, + index_zmm_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, + index_zmm_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, + index_zmm_t16); index_zmm[0] = index_zmm_t1; index_zmm[1] = index_zmm_t2; @@ -441,135 +446,149 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, index_zmm[14] = index_zmm_t15; index_zmm[15] = index_zmm_t16; } -template + +template X86_SIMD_SORT_INLINE void -sort_8_64bit(type_t *keys, uint64_t *indexes, int32_t N) +sort_8_64bit(type1_t *keys, type2_t *indexes, int32_t N) { - typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; - typename vtype::zmm_t key_zmm - = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys); - - zmm_vector::zmm_t index_zmm = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask, indexes); - vtype::mask_storeu( - keys, load_mask, sort_zmm_64bit(key_zmm, index_zmm)); - zmm_vector::mask_storeu(indexes, load_mask, index_zmm); + typename vtype1::opmask_t load_mask = (0x01 << N) - 0x01; + typename vtype1::zmm_t key_zmm + = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys); + + typename vtype2::zmm_t index_zmm + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes); + vtype1::mask_storeu(keys, + load_mask, + sort_zmm_64bit(key_zmm, index_zmm)); + vtype2::mask_storeu(indexes, load_mask, index_zmm); } -template +template X86_SIMD_SORT_INLINE void -sort_16_64bit(type_t *keys, uint64_t *indexes, int32_t N) +sort_16_64bit(type1_t *keys, type2_t *indexes, int32_t N) { if (N <= 8) { - sort_8_64bit(keys, indexes, N); + sort_8_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype::zmm_t; - using index_type = zmm_vector::zmm_t; + using zmm_t = typename vtype1::zmm_t; + using index_type = typename vtype2::zmm_t; - typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; + typename vtype1::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - zmm_t key_zmm1 = vtype::loadu(keys); - zmm_t key_zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, keys + 8); + zmm_t key_zmm1 = vtype1::loadu(keys); + zmm_t key_zmm2 = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys + 8); - index_type index_zmm1 = zmm_vector::loadu(indexes); - index_type index_zmm2 = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask, indexes + 8); + index_type index_zmm1 = vtype2::loadu(indexes); + index_type index_zmm2 + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes + 8); - key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); - key_zmm2 = 
sort_zmm_64bit(key_zmm2, index_zmm2); - bitonic_merge_two_zmm_64bit( + key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); + key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); + bitonic_merge_two_zmm_64bit( key_zmm1, key_zmm2, index_zmm1, index_zmm2); - zmm_vector::storeu(indexes, index_zmm1); - zmm_vector::mask_storeu(indexes + 8, load_mask, index_zmm2); + vtype2::storeu(indexes, index_zmm1); + vtype2::mask_storeu(indexes + 8, load_mask, index_zmm2); - vtype::storeu(keys, key_zmm1); - vtype::mask_storeu(keys + 8, load_mask, key_zmm2); + vtype1::storeu(keys, key_zmm1); + vtype1::mask_storeu(keys + 8, load_mask, key_zmm2); } -template +template X86_SIMD_SORT_INLINE void -sort_32_64bit(type_t *keys, uint64_t *indexes, int32_t N) +sort_32_64bit(type1_t *keys, type2_t *indexes, int32_t N) { if (N <= 16) { - sort_16_64bit(keys, indexes, N); + sort_16_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - using index_type = zmm_vector::zmm_t; + using zmm_t = typename vtype1::zmm_t; + using opmask_t = typename vtype2::opmask_t; + using index_type = typename vtype2::zmm_t; zmm_t key_zmm[4]; index_type index_zmm[4]; - key_zmm[0] = vtype::loadu(keys); - key_zmm[1] = vtype::loadu(keys + 8); + key_zmm[0] = vtype1::loadu(keys); + key_zmm[1] = vtype1::loadu(keys + 8); - index_zmm[0] = zmm_vector::loadu(indexes); - index_zmm[1] = zmm_vector::loadu(indexes + 8); + index_zmm[0] = vtype2::loadu(indexes); + index_zmm[1] = vtype2::loadu(indexes + 8); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; load_mask1 = (combined_mask)&0xFF; load_mask2 = (combined_mask >> 8) & 0xFF; - key_zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 16); - key_zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 24); + key_zmm[2] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 16); + key_zmm[3] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 24); - index_zmm[2] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask1, indexes + 16); - index_zmm[3] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask2, indexes + 24); + index_zmm[2] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 16); + index_zmm[3] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 24); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - zmm_vector::storeu(indexes, index_zmm[0]); - zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); - zmm_vector::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); + vtype2::storeu(indexes, index_zmm[0]); + vtype2::storeu(indexes + 8, index_zmm[1]); + vtype2::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); + vtype2::mask_storeu(indexes + 
24, load_mask2, index_zmm[3]); - vtype::storeu(keys, key_zmm[0]); - vtype::storeu(keys + 8, key_zmm[1]); - vtype::mask_storeu(keys + 16, load_mask1, key_zmm[2]); - vtype::mask_storeu(keys + 24, load_mask2, key_zmm[3]); + vtype1::storeu(keys, key_zmm[0]); + vtype1::storeu(keys + 8, key_zmm[1]); + vtype1::mask_storeu(keys + 16, load_mask1, key_zmm[2]); + vtype1::mask_storeu(keys + 24, load_mask2, key_zmm[3]); } -template +template X86_SIMD_SORT_INLINE void -sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) +sort_64_64bit(type1_t *keys, type2_t *indexes, int32_t N) { if (N <= 32) { - sort_32_64bit(keys, indexes, N); + sort_32_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype::zmm_t; - using opmask_t = typename vtype::opmask_t; - using index_type = zmm_vector::zmm_t; + using zmm_t = typename vtype1::zmm_t; + using opmask_t = typename vtype1::opmask_t; + using index_type = typename vtype2::zmm_t; zmm_t key_zmm[8]; index_type index_zmm[8]; - key_zmm[0] = vtype::loadu(keys); - key_zmm[1] = vtype::loadu(keys + 8); - key_zmm[2] = vtype::loadu(keys + 16); - key_zmm[3] = vtype::loadu(keys + 24); + key_zmm[0] = vtype1::loadu(keys); + key_zmm[1] = vtype1::loadu(keys + 8); + key_zmm[2] = vtype1::loadu(keys + 16); + key_zmm[3] = vtype1::loadu(keys + 24); - index_zmm[0] = zmm_vector::loadu(indexes); - index_zmm[1] = zmm_vector::loadu(indexes + 8); - index_zmm[2] = zmm_vector::loadu(indexes + 16); - index_zmm[3] = zmm_vector::loadu(indexes + 24); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + index_zmm[0] = vtype2::loadu(indexes); + index_zmm[1] = vtype2::loadu(indexes + 8); + index_zmm[2] = vtype2::loadu(indexes + 16); + index_zmm[3] = vtype2::loadu(indexes + 24); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; @@ -579,94 +598,97 @@ sort_64_64bit(type_t *keys, uint64_t *indexes, int32_t N) load_mask2 = (combined_mask >> 8) & 0xFF; load_mask3 = (combined_mask >> 16) & 0xFF; load_mask4 = (combined_mask >> 24) & 0xFF; - key_zmm[4] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 32); - key_zmm[5] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 40); - key_zmm[6] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 48); - key_zmm[7] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 56); - - index_zmm[4] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask1, indexes + 32); - index_zmm[5] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask2, indexes + 40); - index_zmm[6] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask3, indexes + 48); - index_zmm[7] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask4, indexes + 56); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - - bitonic_merge_two_zmm_64bit( + key_zmm[4] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 32); + key_zmm[5] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 40); + key_zmm[6] = 
vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 48); + key_zmm[7] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 56); + + index_zmm[4] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 32); + index_zmm[5] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 40); + index_zmm[6] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 48); + index_zmm[7] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 56); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + + bitonic_merge_two_zmm_64bit( key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - - zmm_vector::storeu(indexes, index_zmm[0]); - zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::storeu(indexes + 16, index_zmm[2]); - zmm_vector::storeu(indexes + 24, index_zmm[3]); - zmm_vector::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); - zmm_vector::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); - zmm_vector::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); - zmm_vector::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); - - vtype::storeu(keys, key_zmm[0]); - vtype::storeu(keys + 8, key_zmm[1]); - vtype::storeu(keys + 16, key_zmm[2]); - vtype::storeu(keys + 24, key_zmm[3]); - vtype::mask_storeu(keys + 32, load_mask1, key_zmm[4]); - vtype::mask_storeu(keys + 40, load_mask2, key_zmm[5]); - vtype::mask_storeu(keys + 48, load_mask3, key_zmm[6]); - vtype::mask_storeu(keys + 56, load_mask4, key_zmm[7]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + + vtype2::storeu(indexes, index_zmm[0]); + vtype2::storeu(indexes + 8, index_zmm[1]); + vtype2::storeu(indexes + 16, index_zmm[2]); + vtype2::storeu(indexes + 24, index_zmm[3]); + vtype2::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); + vtype2::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); + vtype2::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); + vtype2::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); + + vtype1::storeu(keys, key_zmm[0]); + vtype1::storeu(keys + 8, key_zmm[1]); + vtype1::storeu(keys + 16, key_zmm[2]); + vtype1::storeu(keys + 24, key_zmm[3]); + vtype1::mask_storeu(keys + 32, load_mask1, key_zmm[4]); + vtype1::mask_storeu(keys + 40, load_mask2, key_zmm[5]); + vtype1::mask_storeu(keys + 48, load_mask3, key_zmm[6]); + vtype1::mask_storeu(keys + 56, load_mask4, key_zmm[7]); } -template +template X86_SIMD_SORT_INLINE void -sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) +sort_128_64bit(type1_t *keys, type2_t *indexes, int32_t N) { if (N <= 64) { - sort_64_64bit(keys, indexes, N); + sort_64_64bit(keys, indexes, N); return; } - using zmm_t = typename vtype::zmm_t; - using index_type = zmm_vector::zmm_t; - using opmask_t = typename vtype::opmask_t; + using zmm_t = typename 
vtype1::zmm_t; + using index_type = typename vtype2::zmm_t; + using opmask_t = typename vtype1::opmask_t; zmm_t key_zmm[16]; index_type index_zmm[16]; - key_zmm[0] = vtype::loadu(keys); - key_zmm[1] = vtype::loadu(keys + 8); - key_zmm[2] = vtype::loadu(keys + 16); - key_zmm[3] = vtype::loadu(keys + 24); - key_zmm[4] = vtype::loadu(keys + 32); - key_zmm[5] = vtype::loadu(keys + 40); - key_zmm[6] = vtype::loadu(keys + 48); - key_zmm[7] = vtype::loadu(keys + 56); - - index_zmm[0] = zmm_vector::loadu(indexes); - index_zmm[1] = zmm_vector::loadu(indexes + 8); - index_zmm[2] = zmm_vector::loadu(indexes + 16); - index_zmm[3] = zmm_vector::loadu(indexes + 24); - index_zmm[4] = zmm_vector::loadu(indexes + 32); - index_zmm[5] = zmm_vector::loadu(indexes + 40); - index_zmm[6] = zmm_vector::loadu(indexes + 48); - index_zmm[7] = zmm_vector::loadu(indexes + 56); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + key_zmm[0] = vtype1::loadu(keys); + key_zmm[1] = vtype1::loadu(keys + 8); + key_zmm[2] = vtype1::loadu(keys + 16); + key_zmm[3] = vtype1::loadu(keys + 24); + key_zmm[4] = vtype1::loadu(keys + 32); + key_zmm[5] = vtype1::loadu(keys + 40); + key_zmm[6] = vtype1::loadu(keys + 48); + key_zmm[7] = vtype1::loadu(keys + 56); + + index_zmm[0] = vtype2::loadu(indexes); + index_zmm[1] = vtype2::loadu(indexes + 8); + index_zmm[2] = vtype2::loadu(indexes + 16); + index_zmm[3] = vtype2::loadu(indexes + 24); + index_zmm[4] = vtype2::loadu(indexes + 32); + index_zmm[5] = vtype2::loadu(indexes + 40); + index_zmm[6] = vtype2::loadu(indexes + 48); + index_zmm[7] = vtype2::loadu(indexes + 56); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; @@ -683,100 +705,103 @@ sort_128_64bit(type_t *keys, uint64_t *indexes, int32_t N) load_mask7 = (combined_mask >> 48) & 0xFF; load_mask8 = (combined_mask >> 56) & 0xFF; } - key_zmm[8] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, keys + 64); - key_zmm[9] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, keys + 72); - key_zmm[10] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, keys + 80); - key_zmm[11] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, keys + 88); - key_zmm[12] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, keys + 96); - key_zmm[13] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, keys + 104); - key_zmm[14] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, keys + 112); - key_zmm[15] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, keys + 120); - - index_zmm[8] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask1, indexes + 64); - index_zmm[9] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask2, indexes + 72); - index_zmm[10] = 
zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask3, indexes + 80); - index_zmm[11] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask4, indexes + 88); - index_zmm[12] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask5, indexes + 96); - index_zmm[13] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask6, indexes + 104); - index_zmm[14] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask7, indexes + 112); - index_zmm[15] = zmm_vector::mask_loadu( - zmm_vector::zmm_max(), load_mask8, indexes + 120); - key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); - key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); - key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); - key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); - key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); - key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); - key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); - key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); - - bitonic_merge_two_zmm_64bit( + key_zmm[8] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 64); + key_zmm[9] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 72); + key_zmm[10] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 80); + key_zmm[11] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 88); + key_zmm[12] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask5, keys + 96); + key_zmm[13] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask6, keys + 104); + key_zmm[14] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask7, keys + 112); + key_zmm[15] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask8, keys + 120); + + index_zmm[8] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 64); + index_zmm[9] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 72); + index_zmm[10] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 80); + index_zmm[11] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 88); + index_zmm[12] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask5, indexes + 96); + index_zmm[13] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask6, indexes + 104); + index_zmm[14] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask7, indexes + 112); + index_zmm[15] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask8, indexes + 120); + key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); + key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); + key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); + key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); + key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); + key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); + key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); + key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); + + bitonic_merge_two_zmm_64bit( key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); - bitonic_merge_two_zmm_64bit( + 
bitonic_merge_two_zmm_64bit( key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); - bitonic_merge_two_zmm_64bit( + bitonic_merge_two_zmm_64bit( key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); - zmm_vector::storeu(indexes, index_zmm[0]); - zmm_vector::storeu(indexes + 8, index_zmm[1]); - zmm_vector::storeu(indexes + 16, index_zmm[2]); - zmm_vector::storeu(indexes + 24, index_zmm[3]); - zmm_vector::storeu(indexes + 32, index_zmm[4]); - zmm_vector::storeu(indexes + 40, index_zmm[5]); - zmm_vector::storeu(indexes + 48, index_zmm[6]); - zmm_vector::storeu(indexes + 56, index_zmm[7]); - zmm_vector::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); - zmm_vector::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); - zmm_vector::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); - zmm_vector::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); - zmm_vector::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); - zmm_vector::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); - zmm_vector::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); - zmm_vector::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); - - vtype::storeu(keys, key_zmm[0]); - vtype::storeu(keys + 8, key_zmm[1]); - vtype::storeu(keys + 16, key_zmm[2]); - vtype::storeu(keys + 24, key_zmm[3]); - vtype::storeu(keys + 32, key_zmm[4]); - vtype::storeu(keys + 40, key_zmm[5]); - vtype::storeu(keys + 48, key_zmm[6]); - vtype::storeu(keys + 56, key_zmm[7]); - vtype::mask_storeu(keys + 64, load_mask1, key_zmm[8]); - vtype::mask_storeu(keys + 72, load_mask2, key_zmm[9]); - vtype::mask_storeu(keys + 80, load_mask3, key_zmm[10]); - vtype::mask_storeu(keys + 88, load_mask4, key_zmm[11]); - vtype::mask_storeu(keys + 96, load_mask5, key_zmm[12]); - vtype::mask_storeu(keys + 104, load_mask6, key_zmm[13]); - vtype::mask_storeu(keys + 112, load_mask7, key_zmm[14]); - vtype::mask_storeu(keys + 120, load_mask8, key_zmm[15]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); + vtype2::storeu(indexes, index_zmm[0]); + vtype2::storeu(indexes + 8, index_zmm[1]); + vtype2::storeu(indexes + 16, index_zmm[2]); + vtype2::storeu(indexes + 24, index_zmm[3]); + vtype2::storeu(indexes + 32, index_zmm[4]); + vtype2::storeu(indexes + 40, index_zmm[5]); + vtype2::storeu(indexes + 48, index_zmm[6]); + vtype2::storeu(indexes + 56, index_zmm[7]); + vtype2::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); + vtype2::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); + vtype2::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); + vtype2::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); + vtype2::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); + vtype2::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); + vtype2::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); + 
vtype2::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); + + vtype1::storeu(keys, key_zmm[0]); + vtype1::storeu(keys + 8, key_zmm[1]); + vtype1::storeu(keys + 16, key_zmm[2]); + vtype1::storeu(keys + 24, key_zmm[3]); + vtype1::storeu(keys + 32, key_zmm[4]); + vtype1::storeu(keys + 40, key_zmm[5]); + vtype1::storeu(keys + 48, key_zmm[6]); + vtype1::storeu(keys + 56, key_zmm[7]); + vtype1::mask_storeu(keys + 64, load_mask1, key_zmm[8]); + vtype1::mask_storeu(keys + 72, load_mask2, key_zmm[9]); + vtype1::mask_storeu(keys + 80, load_mask3, key_zmm[10]); + vtype1::mask_storeu(keys + 88, load_mask4, key_zmm[11]); + vtype1::mask_storeu(keys + 96, load_mask5, key_zmm[12]); + vtype1::mask_storeu(keys + 104, load_mask6, key_zmm[13]); + vtype1::mask_storeu(keys + 112, load_mask7, key_zmm[14]); + vtype1::mask_storeu(keys + 120, load_mask8, key_zmm[15]); } -template -void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) +template +void heapify(type1_t *keys, type2_t *indexes, int64_t idx, int64_t size) { int64_t i = idx; while (true) { @@ -790,22 +815,28 @@ void heapify(type_t *keys, uint64_t *indexes, int64_t idx, int64_t size) i = j; } } -template -void heap_sort(type_t *keys, uint64_t *indexes, int64_t size) +template +void heap_sort(type1_t *keys, type2_t *indexes, int64_t size) { for (int64_t i = size / 2 - 1; i >= 0; i--) { - heapify(keys, indexes, i, size); + heapify(keys, indexes, i, size); } for (int64_t i = size - 1; i > 0; i--) { std::swap(keys[0], keys[i]); std::swap(indexes[0], indexes[i]); - heapify(keys, indexes, 0, i); + heapify(keys, indexes, 0, i); } } -template -void qsort_64bit_(type_t *keys, - uint64_t *indexes, +template +void qsort_64bit_(type1_t *keys, + type2_t *indexes, int64_t left, int64_t right, int64_t max_iters) @@ -815,7 +846,8 @@ void qsort_64bit_(type_t *keys, */ if (max_iters <= 0) { //std::sort(keys+left,keys+right+1); - heap_sort(keys + left, indexes + left, right - left + 1); + heap_sort( + keys + left, indexes + left, right - left + 1); return; } /* @@ -823,22 +855,23 @@ void qsort_64bit_(type_t *keys, */ if (right + 1 - left <= 128) { - sort_128_64bit( + sort_128_64bit( keys + left, indexes + left, (int32_t)(right + 1 - left)); return; } - type_t pivot = get_pivot_64bit(keys, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512( + type1_t pivot = get_pivot_64bit(keys, left, right); + type1_t smallest = vtype1::type_max(); + type1_t biggest = vtype1::type_min(); + int64_t pivot_index = partition_avx512( keys, indexes, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) { - qsort_64bit_( + qsort_64bit_( keys, indexes, left, pivot_index - 1, max_iters - 1); } if (pivot != biggest) { - qsort_64bit_(keys, indexes, pivot_index, right, max_iters - 1); + qsort_64bit_( + keys, indexes, pivot_index, right, max_iters - 1); } } @@ -846,7 +879,7 @@ template <> void avx512_qsort_kv(int64_t *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { - qsort_64bit_, int64_t>( + qsort_64bit_, zmm_vector>( keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } @@ -857,7 +890,7 @@ void avx512_qsort_kv(uint64_t *keys, int64_t arrsize) { if (arrsize > 1) { - qsort_64bit_, uint64_t>( + qsort_64bit_, zmm_vector>( keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } @@ -867,7 +900,7 @@ void avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize) { if (arrsize > 1) { int64_t nan_count = replace_nan_with_inf(keys, arrsize); - 
@@ -815,7 +846,8 @@ void qsort_64bit_(type_t *keys,
      */
     if (max_iters <= 0) {
         //std::sort(keys+left,keys+right+1);
-        heap_sort<vtype>(keys + left, indexes + left, right - left + 1);
+        heap_sort<vtype1, vtype2>(
+                keys + left, indexes + left, right - left + 1);
         return;
     }
     /*
@@ -823,22 +855,23 @@ void qsort_64bit_(type_t *keys,
      */
     if (right + 1 - left <= 128) {
-        sort_128_64bit<vtype>(
+        sort_128_64bit<vtype1, vtype2>(
                 keys + left, indexes + left, (int32_t)(right + 1 - left));
         return;
     }
 
-    type_t pivot = get_pivot_64bit<vtype>(keys, left, right);
-    type_t smallest = vtype::type_max();
-    type_t biggest = vtype::type_min();
-    int64_t pivot_index = partition_avx512<vtype>(
+    type1_t pivot = get_pivot_64bit<vtype1>(keys, left, right);
+    type1_t smallest = vtype1::type_max();
+    type1_t biggest = vtype1::type_min();
+    int64_t pivot_index = partition_avx512<vtype1, vtype2>(
             keys, indexes, left, right + 1, pivot, &smallest, &biggest);
     if (pivot != smallest) {
-        qsort_64bit_<vtype>(
+        qsort_64bit_<vtype1, vtype2>(
                 keys, indexes, left, pivot_index - 1, max_iters - 1);
     }
     if (pivot != biggest) {
-        qsort_64bit_<vtype>(keys, indexes, pivot_index, right, max_iters - 1);
+        qsort_64bit_<vtype1, vtype2>(
+                keys, indexes, pivot_index, right, max_iters - 1);
     }
 }
 
@@ -846,7 +879,7 @@ template <>
 void avx512_qsort_kv(int64_t *keys, uint64_t *indexes, int64_t arrsize)
 {
     if (arrsize > 1) {
-        qsort_64bit_<zmm_vector<int64_t>, int64_t>(
+        qsort_64bit_<zmm_vector<int64_t>, zmm_vector<uint64_t>>(
                 keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
     }
 }
@@ -857,7 +890,7 @@ void avx512_qsort_kv(uint64_t *keys,
                     int64_t arrsize)
 {
     if (arrsize > 1) {
-        qsort_64bit_<zmm_vector<uint64_t>, uint64_t>(
+        qsort_64bit_<zmm_vector<uint64_t>, zmm_vector<uint64_t>>(
                 keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
     }
 }
@@ -867,7 +900,7 @@ void avx512_qsort_kv(double *keys, uint64_t *indexes, int64_t arrsize)
 {
     if (arrsize > 1) {
         int64_t nan_count = replace_nan_with_inf(keys, arrsize);
-        qsort_64bit_<zmm_vector<double>, double>(
+        qsort_64bit_<zmm_vector<double>, zmm_vector<uint64_t>>(
                 keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
         replace_inf_with_nan(keys, arrsize, nan_count);
     }
 }
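For orientation, a minimal sketch of how the refactored entry point is called. Illustrative only: it assumes an AVX-512 capable CPU and uses the int64_t/uint64_t specialization above; the data values are made up.

#include <cstdint>
#include "avx512-64bit-keyvaluesort.hpp" // brings in avx512_qsort_kv

int main()
{
    int64_t keys[8] = {42, 7, 19, 3, 88, 1, 56, 23};
    uint64_t values[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    // Sorts keys ascending and applies the identical permutation to values,
    // so values[i] ends up holding the original position of keys[i].
    avx512_qsort_kv(keys, values, 8);
    return 0;
}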
diff --git a/src/avx512-common-keyvaluesort.h b/src/avx512-common-keyvaluesort.h
deleted file mode 100644
index f2821072..00000000
--- a/src/avx512-common-keyvaluesort.h
+++ /dev/null
@@ -1,240 +0,0 @@
-/*******************************************************************
- * Copyright (C) 2022 Intel Corporation
- * Copyright (C) 2021 Serge Sans Paille
- * SPDX-License-Identifier: BSD-3-Clause
- * Authors: Liu Zhuan
- *          Tang Xi
- * ****************************************************************/
-
-#ifndef AVX512_QSORT_COMMON_KV
-#define AVX512_QSORT_COMMON_KV
-
-/*
- * Quicksort using AVX-512. The ideas and code are based on these two research
- * papers [1] and [2]. On a high level, the idea is to vectorize quicksort
- * partitioning using AVX-512 compressstore instructions. If the array size is
- * < 128, then use Bitonic sorting network implemented on 512-bit registers.
- * The precise network definitions depend on the dtype and are defined in
- * separate files: avx512-16bit-qsort.hpp, avx512-32bit-qsort.hpp and
- * avx512-64bit-qsort.hpp. Article [4] is a good resource for bitonic sorting
- * network. The core implementations of the vectorized qsort functions
- * avx512_qsort<T>(T*, int64_t) are modified versions of avx2 quicksort
- * presented in the paper [2] and source code associated with that paper [3].
- *
- * [1] Fast and Robust Vectorized In-Place Sorting of Primitive Types
- *     https://drops.dagstuhl.de/opus/volltexte/2021/13775/
- *
- * [2] A Novel Hybrid Quicksort Algorithm Vectorized using AVX-512 on Intel
- *     Skylake https://arxiv.org/pdf/1704.08579.pdf
- *
- * [3] https://github.com/simd-sorting/fast-and-robust: SPDX-License-Identifier: MIT
- *
- * [4] http://mitp-content-server.mit.edu:18180/books/content/sectbyfn?collid=books_pres_0&fn=Chapter%2027.pdf&id=8030
- *
- */
-
-#include "avx512-64bit-common.h"
-
-template <typename T>
-void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
-
-using index_t = __m512i;
-
-template <typename vtype,
-          typename mm_t = typename vtype::zmm_t,
-          typename index_type = zmm_vector<uint64_t>>
-static void COEX(mm_t &key1, mm_t &key2, index_t &index1, index_t &index2)
-{
-    mm_t key_t1 = vtype::min(key1, key2);
-    mm_t key_t2 = vtype::max(key1, key2);
-
-    index_t index_t1
-            = index_type::mask_mov(index2, vtype::eq(key_t1, key1), index1);
-    index_t index_t2
-            = index_type::mask_mov(index1, vtype::eq(key_t1, key1), index2);
-
-    key1 = key_t1;
-    key2 = key_t2;
-    index1 = index_t1;
-    index2 = index_t2;
-}
-template <typename vtype,
-          typename zmm_t = typename vtype::zmm_t,
-          typename opmask_t = typename vtype::opmask_t,
-          typename index_type = zmm_vector<uint64_t>>
-static inline zmm_t cmp_merge(zmm_t in1,
-                              zmm_t in2,
-                              index_t &indexes1,
-                              index_t indexes2,
-                              opmask_t mask)
-{
-    zmm_t tmp_keys = cmp_merge<vtype>(in1, in2, mask);
-    indexes1 = index_type::mask_mov(
-            indexes2, vtype::eq(tmp_keys, in1), indexes1);
-    return tmp_keys; // 0 -> min, 1 -> max
-}
-/*
- * Parition one ZMM register based on the pivot and returns the index of the
- * last element that is less than equal to the pivot.
- */
-template <typename vtype,
-          typename type_t = typename vtype::type_t,
-          typename zmm_t = typename vtype::zmm_t,
-          typename index_type = zmm_vector<uint64_t>>
-static inline int32_t partition_vec(type_t *keys,
-                                    uint64_t *indexes,
-                                    int64_t left,
-                                    int64_t right,
-                                    const zmm_t keys_vec,
-                                    const index_t indexes_vec,
-                                    const zmm_t pivot_vec,
-                                    zmm_t *smallest_vec,
-                                    zmm_t *biggest_vec)
-{
-    /* which elements are larger than the pivot */
-    typename vtype::opmask_t gt_mask = vtype::ge(keys_vec, pivot_vec);
-    int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask);
-    vtype::mask_compressstoreu(
-            keys + left, vtype::knot_opmask(gt_mask), keys_vec);
-    vtype::mask_compressstoreu(
-            keys + right - amount_gt_pivot, gt_mask, keys_vec);
-    index_type::mask_compressstoreu(
-            indexes + left, index_type::knot_opmask(gt_mask), indexes_vec);
-    index_type::mask_compressstoreu(
-            indexes + right - amount_gt_pivot, gt_mask, indexes_vec);
-    *smallest_vec = vtype::min(keys_vec, *smallest_vec);
-    *biggest_vec = vtype::max(keys_vec, *biggest_vec);
-    return amount_gt_pivot;
-}
-/*
- * Parition an array based on the pivot and returns the index of the
- * last element that is less than equal to the pivot.
- */
-template <typename vtype,
-          typename type_t = typename vtype::type_t,
-          typename index_type = zmm_vector<uint64_t>>
-static inline int64_t partition_avx512(type_t *keys,
-                                       uint64_t *indexes,
-                                       int64_t left,
-                                       int64_t right,
-                                       type_t pivot,
-                                       type_t *smallest,
-                                       type_t *biggest)
-{
-    /* make array length divisible by vtype::numlanes , shortening the array */
-    for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) {
-        *smallest = std::min(*smallest, keys[left]);
-        *biggest = std::max(*biggest, keys[left]);
-        if (keys[left] > pivot) {
-            right--;
-            std::swap(keys[left], keys[right]);
-            std::swap(indexes[left], indexes[right]);
-        }
-        else {
-            ++left;
-        }
-    }
-
-    if (left == right)
-        return left; /* less than vtype::numlanes elements in the array */
-
-    using zmm_t = typename vtype::zmm_t;
-    zmm_t pivot_vec = vtype::set1(pivot);
-    zmm_t min_vec = vtype::set1(*smallest);
-    zmm_t max_vec = vtype::set1(*biggest);
-
-    if (right - left == vtype::numlanes) {
-        zmm_t keys_vec = vtype::loadu(keys + left);
-        int32_t amount_gt_pivot;
-
-        index_t indexes_vec = index_type::loadu(indexes + left);
-        amount_gt_pivot = partition_vec<vtype>(keys,
-                                               indexes,
-                                               left,
-                                               left + vtype::numlanes,
-                                               keys_vec,
-                                               indexes_vec,
-                                               pivot_vec,
-                                               &min_vec,
-                                               &max_vec);
-
-        *smallest = vtype::reducemin(min_vec);
-        *biggest = vtype::reducemax(max_vec);
-        return left + (vtype::numlanes - amount_gt_pivot);
-    }
-
-    // first and last vtype::numlanes values are partitioned at the end
-    zmm_t keys_vec_left = vtype::loadu(keys + left);
-    zmm_t keys_vec_right = vtype::loadu(keys + (right - vtype::numlanes));
-    index_t indexes_vec_left;
-    index_t indexes_vec_right;
-    indexes_vec_left = index_type::loadu(indexes + left);
-    indexes_vec_right = index_type::loadu(indexes + (right - vtype::numlanes));
-
-    // store points of the vectors
-    int64_t r_store = right - vtype::numlanes;
-    int64_t l_store = left;
-    // indices for loading the elements
-    left += vtype::numlanes;
-    right -= vtype::numlanes;
-    while (right - left != 0) {
-        zmm_t keys_vec;
-        index_t indexes_vec;
-        /*
-         * if fewer elements are stored on the right side of the array,
-         * then next elements are loaded from the right side,
-         * otherwise from the left side
-         */
-        if ((r_store + vtype::numlanes) - right < left - l_store) {
-            right -= vtype::numlanes;
-            keys_vec = vtype::loadu(keys + right);
-            indexes_vec = index_type::loadu(indexes + right);
-        }
-        else {
-            keys_vec = vtype::loadu(keys + left);
-            indexes_vec = index_type::loadu(indexes + left);
-            left += vtype::numlanes;
-        }
-        // partition the current vector and save it on both sides of the array
-        int32_t amount_gt_pivot;
-
-        amount_gt_pivot = partition_vec<vtype>(keys,
-                                               indexes,
-                                               l_store,
-                                               r_store + vtype::numlanes,
-                                               keys_vec,
-                                               indexes_vec,
-                                               pivot_vec,
-                                               &min_vec,
-                                               &max_vec);
-        r_store -= amount_gt_pivot;
-        l_store += (vtype::numlanes - amount_gt_pivot);
-    }
-
-    /* partition and save vec_left and vec_right */
-    int32_t amount_gt_pivot;
-    amount_gt_pivot = partition_vec<vtype>(keys,
-                                           indexes,
-                                           l_store,
-                                           r_store + vtype::numlanes,
-                                           keys_vec_left,
-                                           indexes_vec_left,
-                                           pivot_vec,
-                                           &min_vec,
-                                           &max_vec);
-    l_store += (vtype::numlanes - amount_gt_pivot);
-    amount_gt_pivot = partition_vec<vtype>(keys,
-                                           indexes,
-                                           l_store,
-                                           l_store + vtype::numlanes,
-                                           keys_vec_right,
-                                           indexes_vec_right,
-                                           pivot_vec,
-                                           &min_vec,
-                                           &max_vec);
-    l_store += (vtype::numlanes - amount_gt_pivot);
-    *smallest = vtype::reducemin(min_vec);
-    *biggest = vtype::reducemax(max_vec);
-    return l_store;
-}
-#endif // AVX512_QSORT_COMMON_KV
diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h
index 0e0ad818..b07b34d2 100644
--- a/src/avx512-common-qsort.h
+++ b/src/avx512-common-qsort.h
@@ -4,6 +4,8 @@
  * SPDX-License-Identifier: BSD-3-Clause
  * Authors: Raghuveer Devulapalli
  *          Serge Sans Paille
+ *          Liu Zhuan
+ *          Tang Xi
  * ****************************************************************/
 
 #ifndef AVX512_QSORT_COMMON
@@ -86,6 +88,7 @@ template <typename T>
 struct zmm_vector;
 
+// Regular quicksort routines:
 template <typename T>
 void avx512_qsort(T *arr, int64_t arrsize);
 void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
@@ -106,6 +109,10 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
     avx512_qsort_fp16(arr, k - 1);
 }
 
+// key-value sort routines
+template <typename T>
+void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
+
 template <typename T>
 bool comparison_func(const T &a, const T &b)
 {
@@ -379,4 +386,211 @@ static inline int64_t partition_avx512_unrolled(type_t *arr,
     *biggest = vtype::reducemax(max_vec);
     return l_store;
 }
+
+// Key-value sort helper functions
+
+template <typename vtype1,
+          typename vtype2,
+          typename zmm_t1 = typename vtype1::zmm_t,
+          typename zmm_t2 = typename vtype2::zmm_t>
+static void COEX(zmm_t1 &key1, zmm_t1 &key2, zmm_t2 &index1, zmm_t2 &index2)
+{
+    zmm_t1 key_t1 = vtype1::min(key1, key2);
+    zmm_t1 key_t2 = vtype1::max(key1, key2);
+
+    zmm_t2 index_t1
+            = vtype2::mask_mov(index2, vtype1::eq(key_t1, key1), index1);
+    zmm_t2 index_t2
+            = vtype2::mask_mov(index1, vtype1::eq(key_t1, key1), index2);
+
+    key1 = key_t1;
+    key2 = key_t2;
+    index1 = index_t1;
+    index2 = index_t2;
+}
+template <typename vtype1,
+          typename vtype2,
+          typename zmm_t1 = typename vtype1::zmm_t,
+          typename zmm_t2 = typename vtype2::zmm_t,
+          typename opmask_t = typename vtype1::opmask_t>
+static inline zmm_t1 cmp_merge(zmm_t1 in1,
+                               zmm_t1 in2,
+                               zmm_t2 &indexes1,
+                               zmm_t2 indexes2,
+                               opmask_t mask)
+{
+    zmm_t1 tmp_keys = cmp_merge<vtype1>(in1, in2, mask);
+    indexes1 = vtype2::mask_mov(indexes2, vtype1::eq(tmp_keys, in1), indexes1);
+    return tmp_keys; // 0 -> min, 1 -> max
+}
+
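The index vector tracks the key vector through the vtype2::mask_mov on vtype1::eq(key_t1, key1): wherever the per-lane minimum equals key1, index1 is kept, otherwise index2 is taken. A scalar model of the same key-value compare-exchange, for intuition only (the helper name is made up):

#include <utility>

// Scalar model of the key-value COEX above: keys are ordered, and each
// index follows whichever key it was originally paired with. On ties the
// eq(key_t1, key1) mask keeps the original pairing, as here.
template <typename K, typename V>
void coex_scalar(K &key1, K &key2, V &index1, V &index2)
{
    if (key2 < key1) {
        std::swap(key1, key2);
        std::swap(index1, index2);
    }
}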
+/*
+ * Partition one ZMM register based on the pivot and returns the index of the
+ * last element that is less than or equal to the pivot.
+ */
+template <typename vtype1,
+          typename vtype2,
+          typename type_t1 = typename vtype1::type_t,
+          typename type_t2 = typename vtype2::type_t,
+          typename zmm_t1 = typename vtype1::zmm_t,
+          typename zmm_t2 = typename vtype2::zmm_t>
+static inline int32_t partition_vec(type_t1 *keys,
+                                    type_t2 *indexes,
+                                    int64_t left,
+                                    int64_t right,
+                                    const zmm_t1 keys_vec,
+                                    const zmm_t2 indexes_vec,
+                                    const zmm_t1 pivot_vec,
+                                    zmm_t1 *smallest_vec,
+                                    zmm_t1 *biggest_vec)
+{
+    /* which elements are larger than the pivot */
+    typename vtype1::opmask_t gt_mask = vtype1::ge(keys_vec, pivot_vec);
+    int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask);
+    vtype1::mask_compressstoreu(
+            keys + left, vtype1::knot_opmask(gt_mask), keys_vec);
+    vtype1::mask_compressstoreu(
+            keys + right - amount_gt_pivot, gt_mask, keys_vec);
+    vtype2::mask_compressstoreu(
+            indexes + left, vtype2::knot_opmask(gt_mask), indexes_vec);
+    vtype2::mask_compressstoreu(
+            indexes + right - amount_gt_pivot, gt_mask, indexes_vec);
+    *smallest_vec = vtype1::min(keys_vec, *smallest_vec);
+    *biggest_vec = vtype1::max(keys_vec, *biggest_vec);
+    return amount_gt_pivot;
+}
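What one partition_vec call computes, modeled on plain arrays; an illustrative sketch with made-up names, assuming 8 lanes (vtype1::numlanes for 64-bit elements). Lanes with key >= pivot are compress-stored so they end at the right edge, the remaining lanes at the left edge, and the popcount of the mask is returned:

#include <algorithm>
#include <cstdint>

static int32_t partition_vec_scalar(int64_t *keys, uint64_t *indexes,
                                    int64_t left, int64_t right,
                                    const int64_t (&keys_vec)[8],
                                    const uint64_t (&indexes_vec)[8],
                                    int64_t pivot,
                                    int64_t *smallest, int64_t *biggest)
{
    int64_t ge_keys[8];
    uint64_t ge_idx[8];
    int32_t amount_ge_pivot = 0; // mirrors popcount of gt_mask
    int64_t lo = left;
    for (int i = 0; i < 8; ++i) {
        *smallest = std::min(*smallest, keys_vec[i]);
        *biggest = std::max(*biggest, keys_vec[i]);
        if (keys_vec[i] >= pivot) { // lane set in gt_mask
            ge_keys[amount_ge_pivot] = keys_vec[i];
            ge_idx[amount_ge_pivot] = indexes_vec[i];
            ++amount_ge_pivot;
        }
        else { // lane set in knot_opmask(gt_mask)
            keys[lo] = keys_vec[i];
            indexes[lo] = indexes_vec[i];
            ++lo;
        }
    }
    // mirrors the mask_compressstoreu at (right - amount_gt_pivot)
    for (int32_t i = 0; i < amount_ge_pivot; ++i) {
        keys[right - amount_ge_pivot + i] = ge_keys[i];
        indexes[right - amount_ge_pivot + i] = ge_idx[i];
    }
    return amount_ge_pivot;
}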
+/*
+ * Partition an array based on the pivot and returns the index of the
+ * last element that is less than or equal to the pivot.
+ */
+template <typename vtype1,
+          typename vtype2,
+          typename type_t1 = typename vtype1::type_t,
+          typename type_t2 = typename vtype2::type_t,
+          typename zmm_t1 = typename vtype1::zmm_t,
+          typename zmm_t2 = typename vtype2::zmm_t>
+static inline int64_t partition_avx512(type_t1 *keys,
+                                       type_t2 *indexes,
+                                       int64_t left,
+                                       int64_t right,
+                                       type_t1 pivot,
+                                       type_t1 *smallest,
+                                       type_t1 *biggest)
+{
+    /* make array length divisible by vtype1::numlanes, shortening the array */
+    for (int32_t i = (right - left) % vtype1::numlanes; i > 0; --i) {
+        *smallest = std::min(*smallest, keys[left]);
+        *biggest = std::max(*biggest, keys[left]);
+        if (keys[left] > pivot) {
+            right--;
+            std::swap(keys[left], keys[right]);
+            std::swap(indexes[left], indexes[right]);
+        }
+        else {
+            ++left;
+        }
+    }
+
+    if (left == right)
+        return left; /* less than vtype1::numlanes elements in the array */
+
+    zmm_t1 pivot_vec = vtype1::set1(pivot);
+    zmm_t1 min_vec = vtype1::set1(*smallest);
+    zmm_t1 max_vec = vtype1::set1(*biggest);
+
+    if (right - left == vtype1::numlanes) {
+        zmm_t1 keys_vec = vtype1::loadu(keys + left);
+        int32_t amount_gt_pivot;
+
+        zmm_t2 indexes_vec = vtype2::loadu(indexes + left);
+        amount_gt_pivot = partition_vec<vtype1, vtype2>(keys,
+                                                        indexes,
+                                                        left,
+                                                        left + vtype1::numlanes,
+                                                        keys_vec,
+                                                        indexes_vec,
+                                                        pivot_vec,
+                                                        &min_vec,
+                                                        &max_vec);
+
+        *smallest = vtype1::reducemin(min_vec);
+        *biggest = vtype1::reducemax(max_vec);
+        return left + (vtype1::numlanes - amount_gt_pivot);
+    }
+
+    // first and last vtype1::numlanes values are partitioned at the end
+    zmm_t1 keys_vec_left = vtype1::loadu(keys + left);
+    zmm_t1 keys_vec_right = vtype1::loadu(keys + (right - vtype1::numlanes));
+    zmm_t2 indexes_vec_left;
+    zmm_t2 indexes_vec_right;
+    indexes_vec_left = vtype2::loadu(indexes + left);
+    indexes_vec_right = vtype2::loadu(indexes + (right - vtype1::numlanes));
+
+    // store pointers of the vectors
+    int64_t r_store = right - vtype1::numlanes;
+    int64_t l_store = left;
+    // indices for loading the elements
+    left += vtype1::numlanes;
+    right -= vtype1::numlanes;
+    while (right - left != 0) {
+        zmm_t1 keys_vec;
+        zmm_t2 indexes_vec;
+        /*
+         * if fewer elements are stored on the right side of the array,
+         * then next elements are loaded from the right side,
+         * otherwise from the left side
+         */
+        if ((r_store + vtype1::numlanes) - right < left - l_store) {
+            right -= vtype1::numlanes;
+            keys_vec = vtype1::loadu(keys + right);
+            indexes_vec = vtype2::loadu(indexes + right);
+        }
+        else {
+            keys_vec = vtype1::loadu(keys + left);
+            indexes_vec = vtype2::loadu(indexes + left);
+            left += vtype1::numlanes;
+        }
+        // partition the current vector and save it on both sides of the array
+        int32_t amount_gt_pivot;
+
+        amount_gt_pivot
+                = partition_vec<vtype1, vtype2>(keys,
+                                                indexes,
+                                                l_store,
+                                                r_store + vtype1::numlanes,
+                                                keys_vec,
+                                                indexes_vec,
+                                                pivot_vec,
+                                                &min_vec,
+                                                &max_vec);
+        r_store -= amount_gt_pivot;
+        l_store += (vtype1::numlanes - amount_gt_pivot);
+    }
+
+    /* partition and save vec_left and vec_right */
+    int32_t amount_gt_pivot;
+    amount_gt_pivot = partition_vec<vtype1, vtype2>(keys,
+                                                    indexes,
+                                                    l_store,
+                                                    r_store + vtype1::numlanes,
+                                                    keys_vec_left,
+                                                    indexes_vec_left,
+                                                    pivot_vec,
+                                                    &min_vec,
+                                                    &max_vec);
+    l_store += (vtype1::numlanes - amount_gt_pivot);
+    amount_gt_pivot = partition_vec<vtype1, vtype2>(keys,
+                                                    indexes,
+                                                    l_store,
+                                                    l_store + vtype1::numlanes,
+                                                    keys_vec_right,
+                                                    indexes_vec_right,
+                                                    pivot_vec,
+                                                    &min_vec,
+                                                    &max_vec);
+    l_store += (vtype1::numlanes - amount_gt_pivot);
+    *smallest = vtype1::reducemin(min_vec);
+    *biggest = vtype1::reducemax(max_vec);
+    return l_store;
+}
 #endif // AVX512_QSORT_COMMON
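The double-buffered loop in partition_avx512 is the subtle part: reading the first and last block up front frees 2 * numlanes slots, and always loading the next block from the side with less free room guarantees the two compress-stores never overwrite data that has not been read yet. A scalar rendering of that scheme, keys only (illustrative sketch; N stands in for vtype1::numlanes and the function name is made up):

#include <cstdint>
#include <vector>

static void two_sided_partition_model(std::vector<int64_t> &a, int64_t pivot)
{
    const int64_t N = 8; // stand-in for vtype1::numlanes
    int64_t left = 0, right = (int64_t)a.size(); // assumes size % N == 0 and size >= 2 * N
    // pre-read the first and last block; they are partitioned last
    std::vector<int64_t> first(a.begin(), a.begin() + N);
    std::vector<int64_t> last(a.end() - N, a.end());
    int64_t l_store = left, r_store = right - N;
    left += N;
    right -= N;

    auto emit = [&](const std::vector<int64_t> &block) {
        // scalar stand-in for partition_vec, writing into the free window
        int64_t lo = l_store, hi = r_store + N;
        for (int64_t v : block) {
            if (v >= pivot) { a[--hi] = v; }
            else { a[lo++] = v; }
        }
        int64_t n_ge = (r_store + N) - hi;
        r_store -= n_ge;
        l_store += N - n_ge;
    };

    while (right - left != 0) {
        std::vector<int64_t> block;
        // load from whichever side has fewer free slots, matching the
        // (r_store + numlanes) - right < left - l_store test above
        if ((r_store + N) - right < left - l_store) {
            right -= N;
            block.assign(a.begin() + right, a.begin() + right + N);
        }
        else {
            block.assign(a.begin() + left, a.begin() + left + N);
            left += N;
        }
        emit(block);
    }
    emit(first); // matches "partition and save vec_left and vec_right"
    emit(last);
    // l_store is now the partition boundary, the value returned above
}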