From 7d9c5766aa6774887166dc7e5b2bd622cf81d654 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 17 Apr 2023 14:04:21 -0700 Subject: [PATCH 01/11] Add bitonic sorting network of size 256 for 64-bit dtype --- src/avx512-64bit-qsort.hpp | 350 ++++++++++++++++++++++++++++++++++++- 1 file changed, 348 insertions(+), 2 deletions(-) diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index dfb5376f..eb0ccab8 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -172,6 +172,161 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); } +template +X86_SIMD_SORT_INLINE void bitonic_merge_32_zmm_64bit(zmm_t *zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t zmm16r = vtype::permutexvar(rev_index, zmm[16]); + zmm_t zmm17r = vtype::permutexvar(rev_index, zmm[17]); + zmm_t zmm18r = vtype::permutexvar(rev_index, zmm[18]); + zmm_t zmm19r = vtype::permutexvar(rev_index, zmm[19]); + zmm_t zmm20r = vtype::permutexvar(rev_index, zmm[20]); + zmm_t zmm21r = vtype::permutexvar(rev_index, zmm[21]); + zmm_t zmm22r = vtype::permutexvar(rev_index, zmm[22]); + zmm_t zmm23r = vtype::permutexvar(rev_index, zmm[23]); + zmm_t zmm24r = vtype::permutexvar(rev_index, zmm[24]); + zmm_t zmm25r = vtype::permutexvar(rev_index, zmm[25]); + zmm_t zmm26r = vtype::permutexvar(rev_index, zmm[26]); + zmm_t zmm27r = vtype::permutexvar(rev_index, zmm[27]); + zmm_t zmm28r = vtype::permutexvar(rev_index, zmm[28]); + zmm_t zmm29r = vtype::permutexvar(rev_index, zmm[29]); + zmm_t zmm30r = vtype::permutexvar(rev_index, zmm[30]); + zmm_t zmm31r = vtype::permutexvar(rev_index, zmm[31]); + zmm_t zmm_t1 = vtype::min(zmm[0], zmm31r); + zmm_t zmm_t2 = vtype::min(zmm[1], zmm30r); + zmm_t zmm_t3 = vtype::min(zmm[2], zmm29r); + zmm_t zmm_t4 = vtype::min(zmm[3], zmm28r); + zmm_t zmm_t5 = vtype::min(zmm[4], zmm27r); + zmm_t zmm_t6 = vtype::min(zmm[5], zmm26r); + zmm_t zmm_t7 = vtype::min(zmm[6], zmm25r); + zmm_t zmm_t8 = vtype::min(zmm[7], zmm24r); + zmm_t zmm_t9 = vtype::min(zmm[8], zmm23r); + zmm_t zmm_t10 = vtype::min(zmm[9], zmm22r); + zmm_t zmm_t11 = vtype::min(zmm[10], zmm21r); + zmm_t zmm_t12 = vtype::min(zmm[11], zmm20r); + zmm_t zmm_t13 = vtype::min(zmm[12], zmm19r); + zmm_t zmm_t14 = vtype::min(zmm[13], zmm18r); + zmm_t zmm_t15 = vtype::min(zmm[14], zmm17r); + zmm_t zmm_t16 = vtype::min(zmm[15], zmm16r); + zmm_t zmm_t17 = vtype::permutexvar(rev_index, vtype::max(zmm[15], zmm16r)); + zmm_t zmm_t18 = vtype::permutexvar(rev_index, vtype::max(zmm[14], zmm17r)); + zmm_t zmm_t19 = vtype::permutexvar(rev_index, vtype::max(zmm[13], zmm18r)); + zmm_t zmm_t20 = vtype::permutexvar(rev_index, vtype::max(zmm[12], zmm19r)); + zmm_t zmm_t21 = vtype::permutexvar(rev_index, vtype::max(zmm[11], zmm20r)); + zmm_t zmm_t22 = vtype::permutexvar(rev_index, vtype::max(zmm[10], zmm21r)); + zmm_t zmm_t23 = vtype::permutexvar(rev_index, vtype::max(zmm[9], zmm22r)); + zmm_t zmm_t24 = vtype::permutexvar(rev_index, vtype::max(zmm[8], zmm23r)); + zmm_t zmm_t25 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm24r)); + zmm_t zmm_t26 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm25r)); + zmm_t zmm_t27 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm26r)); + zmm_t zmm_t28 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm27r)); + zmm_t zmm_t29 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm28r)); + zmm_t zmm_t30 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm29r)); + zmm_t zmm_t31 = 
vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm30r));
+ zmm_t zmm_t32 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm31r));
+ // Recursive half clear 16 zmm regs
+ COEX(zmm_t1, zmm_t9);
+ COEX(zmm_t2, zmm_t10);
+ COEX(zmm_t3, zmm_t11);
+ COEX(zmm_t4, zmm_t12);
+ COEX(zmm_t5, zmm_t13);
+ COEX(zmm_t6, zmm_t14);
+ COEX(zmm_t7, zmm_t15);
+ COEX(zmm_t8, zmm_t16);
+ COEX(zmm_t17, zmm_t25);
+ COEX(zmm_t18, zmm_t26);
+ COEX(zmm_t19, zmm_t27);
+ COEX(zmm_t20, zmm_t28);
+ COEX(zmm_t21, zmm_t29);
+ COEX(zmm_t22, zmm_t30);
+ COEX(zmm_t23, zmm_t31);
+ COEX(zmm_t24, zmm_t32);
+ //
+ COEX(zmm_t1, zmm_t5);
+ COEX(zmm_t2, zmm_t6);
+ COEX(zmm_t3, zmm_t7);
+ COEX(zmm_t4, zmm_t8);
+ COEX(zmm_t9, zmm_t13);
+ COEX(zmm_t10, zmm_t14);
+ COEX(zmm_t11, zmm_t15);
+ COEX(zmm_t12, zmm_t16);
+ COEX(zmm_t17, zmm_t21);
+ COEX(zmm_t18, zmm_t22);
+ COEX(zmm_t19, zmm_t23);
+ COEX(zmm_t20, zmm_t24);
+ COEX(zmm_t25, zmm_t29);
+ COEX(zmm_t26, zmm_t30);
+ COEX(zmm_t27, zmm_t31);
+ COEX(zmm_t28, zmm_t32);
+ //
+ COEX(zmm_t1, zmm_t3);
+ COEX(zmm_t2, zmm_t4);
+ COEX(zmm_t5, zmm_t7);
+ COEX(zmm_t6, zmm_t8);
+ COEX(zmm_t9, zmm_t11);
+ COEX(zmm_t10, zmm_t12);
+ COEX(zmm_t13, zmm_t15);
+ COEX(zmm_t14, zmm_t16);
+ COEX(zmm_t17, zmm_t19);
+ COEX(zmm_t18, zmm_t20);
+ COEX(zmm_t21, zmm_t23);
+ COEX(zmm_t22, zmm_t24);
+ COEX(zmm_t25, zmm_t27);
+ COEX(zmm_t26, zmm_t28);
+ COEX(zmm_t29, zmm_t31);
+ COEX(zmm_t30, zmm_t32);
+ //
+ COEX(zmm_t1, zmm_t2);
+ COEX(zmm_t3, zmm_t4);
+ COEX(zmm_t5, zmm_t6);
+ COEX(zmm_t7, zmm_t8);
+ COEX(zmm_t9, zmm_t10);
+ COEX(zmm_t11, zmm_t12);
+ COEX(zmm_t13, zmm_t14);
+ COEX(zmm_t15, zmm_t16);
+ COEX(zmm_t17, zmm_t18);
+ COEX(zmm_t19, zmm_t20);
+ COEX(zmm_t21, zmm_t22);
+ COEX(zmm_t23, zmm_t24);
+ COEX(zmm_t25, zmm_t26);
+ COEX(zmm_t27, zmm_t28);
+ COEX(zmm_t29, zmm_t30);
+ COEX(zmm_t31, zmm_t32);
+ //
+ zmm[0] = bitonic_merge_zmm_64bit(zmm_t1);
+ zmm[1] = bitonic_merge_zmm_64bit(zmm_t2);
+ zmm[2] = bitonic_merge_zmm_64bit(zmm_t3);
+ zmm[3] = bitonic_merge_zmm_64bit(zmm_t4);
+ zmm[4] = bitonic_merge_zmm_64bit(zmm_t5);
+ zmm[5] = bitonic_merge_zmm_64bit(zmm_t6);
+ zmm[6] = bitonic_merge_zmm_64bit(zmm_t7);
+ zmm[7] = bitonic_merge_zmm_64bit(zmm_t8);
+ zmm[8] = bitonic_merge_zmm_64bit(zmm_t9);
+ zmm[9] = bitonic_merge_zmm_64bit(zmm_t10);
+ zmm[10] = bitonic_merge_zmm_64bit(zmm_t11);
+ zmm[11] = bitonic_merge_zmm_64bit(zmm_t12);
+ zmm[12] = bitonic_merge_zmm_64bit(zmm_t13);
+ zmm[13] = bitonic_merge_zmm_64bit(zmm_t14);
+ zmm[14] = bitonic_merge_zmm_64bit(zmm_t15);
+ zmm[15] = bitonic_merge_zmm_64bit(zmm_t16);
+ zmm[16] = bitonic_merge_zmm_64bit(zmm_t17);
+ zmm[17] = bitonic_merge_zmm_64bit(zmm_t18);
+ zmm[18] = bitonic_merge_zmm_64bit(zmm_t19);
+ zmm[19] = bitonic_merge_zmm_64bit(zmm_t20);
+ zmm[20] = bitonic_merge_zmm_64bit(zmm_t21);
+ zmm[21] = bitonic_merge_zmm_64bit(zmm_t22);
+ zmm[22] = bitonic_merge_zmm_64bit(zmm_t23);
+ zmm[23] = bitonic_merge_zmm_64bit(zmm_t24);
+ zmm[24] = bitonic_merge_zmm_64bit(zmm_t25);
+ zmm[25] = bitonic_merge_zmm_64bit(zmm_t26);
+ zmm[26] = bitonic_merge_zmm_64bit(zmm_t27);
+ zmm[27] = bitonic_merge_zmm_64bit(zmm_t28);
+ zmm[28] = bitonic_merge_zmm_64bit(zmm_t29);
+ zmm[29] = bitonic_merge_zmm_64bit(zmm_t30);
+ zmm[30] = bitonic_merge_zmm_64bit(zmm_t31);
+ zmm[31] = bitonic_merge_zmm_64bit(zmm_t32);
+}
+
 template
 X86_SIMD_SORT_INLINE void sort_8_64bit(type_t *arr, int32_t N)
 {
@@ -371,6 +526,197 @@ X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N)
 vtype::mask_storeu(arr + 120, load_mask8, zmm[15]);
 }

+template
+X86_SIMD_SORT_INLINE void sort_256_64bit(type_t
*arr, int32_t N) +{ + if (N <= 128) { + sort_128_64bit(arr, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t zmm[32]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 8); + zmm[2] = vtype::loadu(arr + 16); + zmm[3] = vtype::loadu(arr + 24); + zmm[4] = vtype::loadu(arr + 32); + zmm[5] = vtype::loadu(arr + 40); + zmm[6] = vtype::loadu(arr + 48); + zmm[7] = vtype::loadu(arr + 56); + zmm[8] = vtype::loadu(arr + 64); + zmm[9] = vtype::loadu(arr + 72); + zmm[10] = vtype::loadu(arr + 80); + zmm[11] = vtype::loadu(arr + 88); + zmm[12] = vtype::loadu(arr + 96); + zmm[13] = vtype::loadu(arr + 104); + zmm[14] = vtype::loadu(arr + 112); + zmm[15] = vtype::loadu(arr + 120); + zmm[0] = sort_zmm_64bit(zmm[0]); + zmm[1] = sort_zmm_64bit(zmm[1]); + zmm[2] = sort_zmm_64bit(zmm[2]); + zmm[3] = sort_zmm_64bit(zmm[3]); + zmm[4] = sort_zmm_64bit(zmm[4]); + zmm[5] = sort_zmm_64bit(zmm[5]); + zmm[6] = sort_zmm_64bit(zmm[6]); + zmm[7] = sort_zmm_64bit(zmm[7]); + zmm[8] = sort_zmm_64bit(zmm[8]); + zmm[9] = sort_zmm_64bit(zmm[9]); + zmm[10] = sort_zmm_64bit(zmm[10]); + zmm[11] = sort_zmm_64bit(zmm[11]); + zmm[12] = sort_zmm_64bit(zmm[12]); + zmm[13] = sort_zmm_64bit(zmm[13]); + zmm[14] = sort_zmm_64bit(zmm[14]); + zmm[15] = sort_zmm_64bit(zmm[15]); + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; + opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; + opmask_t load_mask9 = 0xFF, load_mask10 = 0xFF; + opmask_t load_mask11 = 0xFF, load_mask12 = 0xFF; + opmask_t load_mask13 = 0xFF, load_mask14 = 0xFF; + opmask_t load_mask15 = 0xFF, load_mask16 = 0xFF; + if (N != 256) { + uint64_t combined_mask; + if (N < 192) { + combined_mask = (0x1ull << (N - 128)) - 0x1ull; + load_mask1 = (combined_mask) & 0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + load_mask5 = (combined_mask >> 32) & 0xFF; + load_mask6 = (combined_mask >> 40) & 0xFF; + load_mask7 = (combined_mask >> 48) & 0xFF; + load_mask8 = (combined_mask >> 56) & 0xFF; + load_mask9 = 0x00; load_mask10 = 0x0; + load_mask11 = 0x00; load_mask12 = 0x00; + load_mask13 = 0x00; load_mask14 = 0x00; + load_mask15 = 0x00; load_mask16 = 0x00; + } + else { + combined_mask = (0x1ull << (N - 192)) - 0x1ull; + load_mask9 = (combined_mask) & 0xFF; + load_mask10 = (combined_mask >> 8) & 0xFF; + load_mask11 = (combined_mask >> 16) & 0xFF; + load_mask12 = (combined_mask >> 24) & 0xFF; + load_mask13 = (combined_mask >> 32) & 0xFF; + load_mask14 = (combined_mask >> 40) & 0xFF; + load_mask15 = (combined_mask >> 48) & 0xFF; + load_mask16 = (combined_mask >> 56) & 0xFF; + } + } + zmm[16] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 128); + zmm[17] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 136); + zmm[18] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 144); + zmm[19] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 152); + zmm[20] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 160); + zmm[21] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 168); + zmm[22] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 176); + zmm[23] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 184); + if (N < 192) { + zmm[24] = vtype::zmm_max(); + zmm[25] = vtype::zmm_max(); + zmm[26] = vtype::zmm_max(); + zmm[27] = vtype::zmm_max(); + zmm[28] = vtype::zmm_max(); + zmm[29] = vtype::zmm_max(); + 
zmm[30] = vtype::zmm_max(); + zmm[31] = vtype::zmm_max(); + } + else { + zmm[24] = vtype::mask_loadu(vtype::zmm_max(), load_mask9, arr + 192); + zmm[25] = vtype::mask_loadu(vtype::zmm_max(), load_mask10, arr + 200); + zmm[26] = vtype::mask_loadu(vtype::zmm_max(), load_mask11, arr + 208); + zmm[27] = vtype::mask_loadu(vtype::zmm_max(), load_mask12, arr + 216); + zmm[28] = vtype::mask_loadu(vtype::zmm_max(), load_mask13, arr + 224); + zmm[29] = vtype::mask_loadu(vtype::zmm_max(), load_mask14, arr + 232); + zmm[30] = vtype::mask_loadu(vtype::zmm_max(), load_mask15, arr + 240); + zmm[31] = vtype::mask_loadu(vtype::zmm_max(), load_mask16, arr + 248); + } + zmm[16] = sort_zmm_64bit(zmm[16]); + zmm[17] = sort_zmm_64bit(zmm[17]); + zmm[18] = sort_zmm_64bit(zmm[18]); + zmm[19] = sort_zmm_64bit(zmm[19]); + zmm[20] = sort_zmm_64bit(zmm[20]); + zmm[21] = sort_zmm_64bit(zmm[21]); + zmm[22] = sort_zmm_64bit(zmm[22]); + zmm[23] = sort_zmm_64bit(zmm[23]); + zmm[24] = sort_zmm_64bit(zmm[24]); + zmm[25] = sort_zmm_64bit(zmm[25]); + zmm[26] = sort_zmm_64bit(zmm[26]); + zmm[27] = sort_zmm_64bit(zmm[27]); + zmm[28] = sort_zmm_64bit(zmm[28]); + zmm[29] = sort_zmm_64bit(zmm[29]); + zmm[30] = sort_zmm_64bit(zmm[30]); + zmm[31] = sort_zmm_64bit(zmm[31]); + bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); + bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); + bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); + bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); + bitonic_merge_two_zmm_64bit(zmm[8], zmm[9]); + bitonic_merge_two_zmm_64bit(zmm[10], zmm[11]); + bitonic_merge_two_zmm_64bit(zmm[12], zmm[13]); + bitonic_merge_two_zmm_64bit(zmm[14], zmm[15]); + bitonic_merge_two_zmm_64bit(zmm[16], zmm[17]); + bitonic_merge_two_zmm_64bit(zmm[18], zmm[19]); + bitonic_merge_two_zmm_64bit(zmm[20], zmm[21]); + bitonic_merge_two_zmm_64bit(zmm[22], zmm[23]); + bitonic_merge_two_zmm_64bit(zmm[24], zmm[25]); + bitonic_merge_two_zmm_64bit(zmm[26], zmm[27]); + bitonic_merge_two_zmm_64bit(zmm[28], zmm[29]); + bitonic_merge_two_zmm_64bit(zmm[30], zmm[31]); + bitonic_merge_four_zmm_64bit(zmm); + bitonic_merge_four_zmm_64bit(zmm + 4); + bitonic_merge_four_zmm_64bit(zmm + 8); + bitonic_merge_four_zmm_64bit(zmm + 12); + bitonic_merge_four_zmm_64bit(zmm + 16); + bitonic_merge_four_zmm_64bit(zmm + 20); + bitonic_merge_four_zmm_64bit(zmm + 24); + bitonic_merge_four_zmm_64bit(zmm + 28); + bitonic_merge_eight_zmm_64bit(zmm); + bitonic_merge_eight_zmm_64bit(zmm + 8); + bitonic_merge_eight_zmm_64bit(zmm + 16); + bitonic_merge_eight_zmm_64bit(zmm + 24); + bitonic_merge_sixteen_zmm_64bit(zmm); + bitonic_merge_sixteen_zmm_64bit(zmm + 16); + bitonic_merge_32_zmm_64bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 8, zmm[1]); + vtype::storeu(arr + 16, zmm[2]); + vtype::storeu(arr + 24, zmm[3]); + vtype::storeu(arr + 32, zmm[4]); + vtype::storeu(arr + 40, zmm[5]); + vtype::storeu(arr + 48, zmm[6]); + vtype::storeu(arr + 56, zmm[7]); + vtype::storeu(arr + 64, zmm[8]); + vtype::storeu(arr + 72, zmm[9]); + vtype::storeu(arr + 80, zmm[10]); + vtype::storeu(arr + 88, zmm[11]); + vtype::storeu(arr + 96, zmm[12]); + vtype::storeu(arr + 104, zmm[13]); + vtype::storeu(arr + 112, zmm[14]); + vtype::storeu(arr + 120, zmm[15]); + vtype::mask_storeu(arr + 128, load_mask1, zmm[16]); + vtype::mask_storeu(arr + 136, load_mask2, zmm[17]); + vtype::mask_storeu(arr + 144, load_mask3, zmm[18]); + vtype::mask_storeu(arr + 152, load_mask4, zmm[19]); + vtype::mask_storeu(arr + 160, load_mask5, zmm[20]); + vtype::mask_storeu(arr + 168, load_mask6, zmm[21]); + vtype::mask_storeu(arr + 
176, load_mask7, zmm[22]); + vtype::mask_storeu(arr + 184, load_mask8, zmm[23]); + if (N > 192) { + vtype::mask_storeu(arr + 192, load_mask9, zmm[24]); + vtype::mask_storeu(arr + 200, load_mask10, zmm[25]); + vtype::mask_storeu(arr + 208, load_mask11, zmm[26]); + vtype::mask_storeu(arr + 216, load_mask12, zmm[27]); + vtype::mask_storeu(arr + 224, load_mask13, zmm[28]); + vtype::mask_storeu(arr + 232, load_mask14, zmm[29]); + vtype::mask_storeu(arr + 240, load_mask15, zmm[30]); + vtype::mask_storeu(arr + 248, load_mask16, zmm[31]); + } + +} + template static void qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) @@ -385,8 +731,8 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) /* * Base case: use bitonic networks to sort arrays <= 128 */ - if (right + 1 - left <= 128) { - sort_128_64bit(arr + left, (int32_t)(right + 1 - left)); + if (right + 1 - left <= 256) { + sort_256_64bit(arr + left, (int32_t)(right + 1 - left)); return; } From 1e0608b31e82a9e3724d240683bb5d193026e8f0 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 17 Apr 2023 15:21:42 -0700 Subject: [PATCH 02/11] Median: use median of 8*8 elements --- src/avx512-64bit-common.h | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 7fc8acf3..f720ddaa 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -390,21 +390,25 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, const int64_t left, const int64_t right) { - // median of 8 + // median of 8x8 elements int64_t size = (right - left) / 8; using zmm_t = typename vtype::zmm_t; - __m512i rand_index = _mm512_set_epi64(left + size, - left + 2 * size, - left + 3 * size, - left + 4 * size, - left + 5 * size, - left + 6 * size, - left + 7 * size, - left + 8 * size); - zmm_t rand_vec = vtype::template i64gather(rand_index, arr); + zmm_t v[8]; + for (int64_t ii = 0; ii < 8; ++ii) { + v[ii] = vtype::loadu(arr + left + ii*size); + } + COEX(v[0], v[1]); COEX(v[2], v[3]); /* step 1 */ + COEX(v[4], v[5]); COEX(v[6], v[7]); + COEX(v[0], v[2]); COEX(v[1], v[3]); /* step 2 */ + COEX(v[4], v[6]); COEX(v[5], v[7]); + COEX(v[0], v[4]); COEX(v[1], v[2]); /* step 3 */ + COEX(v[5], v[6]); COEX(v[3], v[7]); + COEX(v[1], v[5]); COEX(v[2], v[6]); /* step 4 */ + COEX(v[3], v[5]); COEX(v[2], v[4]); /* step 5 */ + COEX(v[3], v[4]); /* step 6 */ // pivot will never be a nan, since there are no nan's! 
- zmm_t sort = sort_zmm_64bit(rand_vec); + zmm_t sort = sort_zmm_64bit(v[3]); return ((type_t *)&sort)[4]; } -#endif \ No newline at end of file +#endif From e1b7c82681f75bd1af9b240ae8a7b56eae424f0d Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 18 Apr 2023 11:49:38 -0700 Subject: [PATCH 03/11] Unroll the partition algorithm --- src/avx512-common-qsort.h | 113 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 5b6591f0..d47d08f6 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -259,4 +259,117 @@ static inline int64_t partition_avx512(type_t *arr, *biggest = vtype::reducemax(max_vec); return l_store; } + +template +static inline int64_t partition_avx512_unrolled(type_t *arr, + int64_t left, + int64_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + const int num_unroll = 8; + if (right - left <= 2*num_unroll*vtype::numlanes) { + return partition_avx512(arr, left, right, pivot, smallest, biggest); + } + /* make array length divisible by 8*vtype::numlanes , shortening the array */ + for (int32_t i = ((right - left) % (num_unroll*vtype::numlanes)); i > 0; --i) { + *smallest = std::min(*smallest, arr[left], comparison_func); + *biggest = std::max(*biggest, arr[left], comparison_func); + if (!comparison_func(arr[left], pivot)) { + std::swap(arr[left], arr[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using zmm_t = typename vtype::zmm_t; + zmm_t pivot_vec = vtype::set1(pivot); + zmm_t min_vec = vtype::set1(*smallest); + zmm_t max_vec = vtype::set1(*biggest); + + // We will now have atleast 16 registers worth of data to process: + // left and right vtype::numlanes values are partitioned at the end + zmm_t vec_left[num_unroll], vec_right[num_unroll]; + #pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + vec_left[ii] = vtype::loadu(arr + left + vtype::numlanes*ii); + vec_right[ii] = vtype::loadu(arr + (right - vtype::numlanes*(num_unroll-ii))); + } + // store points of the vectors + int64_t r_store = right - vtype::numlanes; + int64_t l_store = left; + // indices for loading the elements + left += num_unroll*vtype::numlanes; + right -= num_unroll*vtype::numlanes; + while (right - left != 0) { + zmm_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= num_unroll*vtype::numlanes; + #pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + curr_vec[ii] = vtype::loadu(arr + right + ii*vtype::numlanes); + } + } + else { + #pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + curr_vec[ii] = vtype::loadu(arr + left + ii*vtype::numlanes); + } + left += num_unroll*vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + #pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_ge_pivot + = partition_vec(arr, + l_store, + r_store + vtype::numlanes, + curr_vec[ii], + pivot_vec, + &min_vec,pick + &max_vec); + l_store += (vtype::numlanes - amount_ge_pivot); + r_store -= amount_ge_pivot; + } + } + + /* partition and save vec_left[8] and vec_right[8] */ + #pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_ge_pivot = partition_vec(arr, + 
l_store,
+ r_store + vtype::numlanes,
+ vec_left[ii],
+ pivot_vec,
+ &min_vec,
+ &max_vec);
+ l_store += (vtype::numlanes - amount_ge_pivot);
+ r_store -= amount_ge_pivot;
+ }
+ #pragma GCC unroll 8
+ for (int ii = 0; ii < num_unroll; ++ii) {
+ int32_t amount_ge_pivot = partition_vec(arr,
+ l_store,
+ r_store + vtype::numlanes,
+ vec_right[ii],
+ pivot_vec,
+ &min_vec,
+ &max_vec);
+ l_store += (vtype::numlanes - amount_ge_pivot);
+ r_store -= amount_ge_pivot;
+ }
+ *smallest = vtype::reducemin(min_vec);
+ *biggest = vtype::reducemax(max_vec);
+ return l_store;
+}
 #endif // AVX512_QSORT_COMMON

From 814736baff11529d7f403739d6ae99f89bd83588 Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Tue, 18 Apr 2023 12:30:26 -0700
Subject: [PATCH 04/11] Use unrolled partition

---
 src/avx512-32bit-qsort.hpp | 2 +-
 src/avx512-64bit-qsort.hpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp
index e9e97aa1..ea1f7130 100644
--- a/src/avx512-32bit-qsort.hpp
+++ b/src/avx512-32bit-qsort.hpp
@@ -648,7 +648,7 @@ qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
 type_t pivot = get_pivot_32bit(arr, left, right);
 type_t smallest = vtype::type_max();
 type_t biggest = vtype::type_min();
- int64_t pivot_index = partition_avx512(
+ int64_t pivot_index = partition_avx512_unrolled(
 arr, left, right + 1, pivot, &smallest, &biggest);
 if (pivot != smallest)
 qsort_32bit_(arr, left, pivot_index - 1, max_iters - 1);
diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp
index eb0ccab8..7a16849f 100644
--- a/src/avx512-64bit-qsort.hpp
+++ b/src/avx512-64bit-qsort.hpp
@@ -739,7 +739,7 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
 type_t pivot = get_pivot_64bit(arr, left, right);
 type_t smallest = vtype::type_max();
 type_t biggest = vtype::type_min();
- int64_t pivot_index = partition_avx512(
+ int64_t pivot_index = partition_avx512_unrolled(
 arr, left, right + 1, pivot, &smallest, &biggest);
 if (pivot != smallest)
 qsort_64bit_(arr, left, pivot_index - 1, max_iters - 1);

From 2a4f94953cc1f94cd16fc6a6654eaa4f79d62524 Mon Sep 17 00:00:00 2001
From: Raghuveer Devulapalli
Date: Tue, 18 Apr 2023 13:01:29 -0700
Subject: [PATCH 05/11] Revert "Median: use median of 8*8 elements"

This reverts commit 6e9a93ccde73425da85bed99b45c2a4a33787730.
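This restores the gather-based pivot in get_pivot_64bit: eight candidates
spaced (right - left) / 8 apart are gathered into a single ZMM register,
sorted in-register, and the middle element is taken as the pivot.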
--- src/avx512-64bit-common.h | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index f720ddaa..7fc8acf3 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -390,25 +390,21 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, const int64_t left, const int64_t right) { - // median of 8x8 elements + // median of 8 int64_t size = (right - left) / 8; using zmm_t = typename vtype::zmm_t; - zmm_t v[8]; - for (int64_t ii = 0; ii < 8; ++ii) { - v[ii] = vtype::loadu(arr + left + ii*size); - } - COEX(v[0], v[1]); COEX(v[2], v[3]); /* step 1 */ - COEX(v[4], v[5]); COEX(v[6], v[7]); - COEX(v[0], v[2]); COEX(v[1], v[3]); /* step 2 */ - COEX(v[4], v[6]); COEX(v[5], v[7]); - COEX(v[0], v[4]); COEX(v[1], v[2]); /* step 3 */ - COEX(v[5], v[6]); COEX(v[3], v[7]); - COEX(v[1], v[5]); COEX(v[2], v[6]); /* step 4 */ - COEX(v[3], v[5]); COEX(v[2], v[4]); /* step 5 */ - COEX(v[3], v[4]); /* step 6 */ + __m512i rand_index = _mm512_set_epi64(left + size, + left + 2 * size, + left + 3 * size, + left + 4 * size, + left + 5 * size, + left + 6 * size, + left + 7 * size, + left + 8 * size); + zmm_t rand_vec = vtype::template i64gather(rand_index, arr); // pivot will never be a nan, since there are no nan's! - zmm_t sort = sort_zmm_64bit(v[3]); + zmm_t sort = sort_zmm_64bit(rand_vec); return ((type_t *)&sort)[4]; } -#endif +#endif \ No newline at end of file From 250dbe5b58237f2e91b30e5651fb8a181e0e45f8 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 18 Apr 2023 13:15:24 -0700 Subject: [PATCH 06/11] Fix formatting --- src/avx512-64bit-qsort.hpp | 17 +++++---- src/avx512-common-qsort.h | 73 ++++++++++++++++++++------------------ 2 files changed, 49 insertions(+), 41 deletions(-) diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 7a16849f..ee241178 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -580,7 +580,7 @@ X86_SIMD_SORT_INLINE void sort_256_64bit(type_t *arr, int32_t N) uint64_t combined_mask; if (N < 192) { combined_mask = (0x1ull << (N - 128)) - 0x1ull; - load_mask1 = (combined_mask) & 0xFF; + load_mask1 = (combined_mask)&0xFF; load_mask2 = (combined_mask >> 8) & 0xFF; load_mask3 = (combined_mask >> 16) & 0xFF; load_mask4 = (combined_mask >> 24) & 0xFF; @@ -588,14 +588,18 @@ X86_SIMD_SORT_INLINE void sort_256_64bit(type_t *arr, int32_t N) load_mask6 = (combined_mask >> 40) & 0xFF; load_mask7 = (combined_mask >> 48) & 0xFF; load_mask8 = (combined_mask >> 56) & 0xFF; - load_mask9 = 0x00; load_mask10 = 0x0; - load_mask11 = 0x00; load_mask12 = 0x00; - load_mask13 = 0x00; load_mask14 = 0x00; - load_mask15 = 0x00; load_mask16 = 0x00; + load_mask9 = 0x00; + load_mask10 = 0x0; + load_mask11 = 0x00; + load_mask12 = 0x00; + load_mask13 = 0x00; + load_mask14 = 0x00; + load_mask15 = 0x00; + load_mask16 = 0x00; } else { combined_mask = (0x1ull << (N - 192)) - 0x1ull; - load_mask9 = (combined_mask) & 0xFF; + load_mask9 = (combined_mask)&0xFF; load_mask10 = (combined_mask >> 8) & 0xFF; load_mask11 = (combined_mask >> 16) & 0xFF; load_mask12 = (combined_mask >> 24) & 0xFF; @@ -714,7 +718,6 @@ X86_SIMD_SORT_INLINE void sort_256_64bit(type_t *arr, int32_t N) vtype::mask_storeu(arr + 240, load_mask15, zmm[30]); vtype::mask_storeu(arr + 248, load_mask16, zmm[31]); } - } template diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index d47d08f6..cb2ed566 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h 
@@ -269,11 +269,13 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, type_t *biggest) { const int num_unroll = 8; - if (right - left <= 2*num_unroll*vtype::numlanes) { - return partition_avx512(arr, left, right, pivot, smallest, biggest); + if (right - left <= 2 * num_unroll * vtype::numlanes) { + return partition_avx512( + arr, left, right, pivot, smallest, biggest); } /* make array length divisible by 8*vtype::numlanes , shortening the array */ - for (int32_t i = ((right - left) % (num_unroll*vtype::numlanes)); i > 0; --i) { + for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; + --i) { *smallest = std::min(*smallest, arr[left], comparison_func); *biggest = std::max(*biggest, arr[left], comparison_func); if (!comparison_func(arr[left], pivot)) { @@ -295,17 +297,18 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, // We will now have atleast 16 registers worth of data to process: // left and right vtype::numlanes values are partitioned at the end zmm_t vec_left[num_unroll], vec_right[num_unroll]; - #pragma GCC unroll 8 +#pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - vec_left[ii] = vtype::loadu(arr + left + vtype::numlanes*ii); - vec_right[ii] = vtype::loadu(arr + (right - vtype::numlanes*(num_unroll-ii))); + vec_left[ii] = vtype::loadu(arr + left + vtype::numlanes * ii); + vec_right[ii] = vtype::loadu( + arr + (right - vtype::numlanes * (num_unroll - ii))); } // store points of the vectors int64_t r_store = right - vtype::numlanes; int64_t l_store = left; // indices for loading the elements - left += num_unroll*vtype::numlanes; - right -= num_unroll*vtype::numlanes; + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; while (right - left != 0) { zmm_t curr_vec[num_unroll]; /* @@ -314,21 +317,21 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, * otherwise from the left side */ if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= num_unroll*vtype::numlanes; - #pragma GCC unroll 8 + right -= num_unroll * vtype::numlanes; +#pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - curr_vec[ii] = vtype::loadu(arr + right + ii*vtype::numlanes); + curr_vec[ii] = vtype::loadu(arr + right + ii * vtype::numlanes); } } else { - #pragma GCC unroll 8 +#pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - curr_vec[ii] = vtype::loadu(arr + left + ii*vtype::numlanes); + curr_vec[ii] = vtype::loadu(arr + left + ii * vtype::numlanes); } - left += num_unroll*vtype::numlanes; + left += num_unroll * vtype::numlanes; } - // partition the current vector and save it on both sides of the array - #pragma GCC unroll 8 +// partition the current vector and save it on both sides of the array +#pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { int32_t amount_ge_pivot = partition_vec(arr, @@ -336,35 +339,37 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, r_store + vtype::numlanes, curr_vec[ii], pivot_vec, - &min_vec,pick + &min_vec, &max_vec); l_store += (vtype::numlanes - amount_ge_pivot); r_store -= amount_ge_pivot; } } - /* partition and save vec_left[8] and vec_right[8] */ - #pragma GCC unroll 8 +/* partition and save vec_left[8] and vec_right[8] */ +#pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_ge_pivot = partition_vec(arr, - l_store, - r_store + vtype::numlanes, - vec_left[ii], - pivot_vec, - &min_vec, - &max_vec); + int32_t amount_ge_pivot + = partition_vec(arr, + l_store, + r_store + vtype::numlanes, + 
vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_ge_pivot); r_store -= amount_ge_pivot; } - #pragma GCC unroll 8 +#pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_ge_pivot = partition_vec(arr, - l_store, - r_store + vtype::numlanes, - vec_right[ii], - pivot_vec, - &min_vec, - &max_vec); + int32_t amount_ge_pivot + = partition_vec(arr, + l_store, + r_store + vtype::numlanes, + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_ge_pivot); r_store -= amount_ge_pivot; } From 2af3820eac2959394c572ba8b39e26d1ee7833e4 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 20 Apr 2023 21:13:18 -0700 Subject: [PATCH 07/11] Make num_unroll a compile time constant --- src/avx512-32bit-qsort.hpp | 4 ++-- src/avx512-64bit-qsort.hpp | 4 ++-- src/avx512-common-qsort.h | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index ea1f7130..79794566 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -648,7 +648,7 @@ qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) type_t pivot = get_pivot_32bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( + int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) qsort_32bit_(arr, left, pivot_index - 1, max_iters - 1); @@ -680,7 +680,7 @@ qselect_32bit_(type_t *arr, int64_t pos, type_t pivot = get_pivot_32bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512( + int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if ((pivot != smallest) && (pos < pivot_index)) qselect_32bit_(arr, pos, left, pivot_index - 1, max_iters - 1); diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index ee241178..6d59f2f2 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -742,7 +742,7 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) type_t pivot = get_pivot_64bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( + int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) qsort_64bit_(arr, left, pivot_index - 1, max_iters - 1); @@ -774,7 +774,7 @@ qselect_64bit_(type_t *arr, int64_t pos, type_t pivot = get_pivot_64bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512( + int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if ((pivot != smallest) && (pos < pivot_index)) qselect_64bit_(arr, pos, left, pivot_index - 1, max_iters - 1); diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index cb2ed566..5d105daa 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -260,7 +260,9 @@ static inline int64_t partition_avx512(type_t *arr, return l_store; } -template +template static inline int64_t partition_avx512_unrolled(type_t *arr, int64_t left, int64_t right, @@ -268,7 +270,6 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, type_t *smallest, type_t *biggest) { - const int 
num_unroll = 8; if (right - left <= 2 * num_unroll * vtype::numlanes) { return partition_avx512( arr, left, right, pivot, smallest, biggest); From 2ea83341bb250afff5f6ab641804ccb657f0d562 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 25 Apr 2023 09:38:21 -0700 Subject: [PATCH 08/11] Fix formatting --- src/avx512-16bit-common.h | 9 +++++---- src/avx512-32bit-qsort.hpp | 13 +++++++------ src/avx512-64bit-qsort.hpp | 9 +++++---- src/avx512-common-qsort.h | 3 ++- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index 6e0743d6..0c819946 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -290,10 +290,11 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) } template -static void -qselect_16bit_(type_t *arr, int64_t pos, - int64_t left, int64_t right, - int64_t max_iters) +static void qselect_16bit_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 79794566..c4061ddf 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -648,7 +648,7 @@ qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) type_t pivot = get_pivot_32bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( + int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) qsort_32bit_(arr, left, pivot_index - 1, max_iters - 1); @@ -657,10 +657,11 @@ qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) } template -static void -qselect_32bit_(type_t *arr, int64_t pos, - int64_t left, int64_t right, - int64_t max_iters) +static void qselect_32bit_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress @@ -680,7 +681,7 @@ qselect_32bit_(type_t *arr, int64_t pos, type_t pivot = get_pivot_32bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( + int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if ((pivot != smallest) && (pos < pivot_index)) qselect_32bit_(arr, pos, left, pivot_index - 1, max_iters - 1); diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index 6d59f2f2..1cbcd388 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -751,10 +751,11 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) } template -static void -qselect_64bit_(type_t *arr, int64_t pos, - int64_t left, int64_t right, - int64_t max_iters) +static void qselect_64bit_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 5d105daa..0e0ad818 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -95,7 +95,8 @@ void avx512_qselect(T *arr, int64_t k, int64_t arrsize); void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize); template -inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize) { +inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize) 
+{ avx512_qselect(arr, k - 1, arrsize); avx512_qsort(arr, k - 1); } From dbed567764e4c7353f9316cb8fea2ef1910c9793 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 25 Apr 2023 10:16:21 -0700 Subject: [PATCH 09/11] Add more tests for qsort --- tests/test_qsort.hpp | 96 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 3 deletions(-) diff --git a/tests/test_qsort.hpp b/tests/test_qsort.hpp index 65a8eaf6..9663ffbb 100644 --- a/tests/test_qsort.hpp +++ b/tests/test_qsort.hpp @@ -10,7 +10,7 @@ class avx512_sort : public ::testing::Test { }; TYPED_TEST_SUITE_P(avx512_sort); -TYPED_TEST_P(avx512_sort, test_arrsizes) +TYPED_TEST_P(avx512_sort, test_random) { if (cpu_has_avx512bw()) { if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { @@ -29,7 +29,7 @@ TYPED_TEST_P(avx512_sort, test_arrsizes) /* Sort with std::sort for comparison */ std::sort(sortedarr.begin(), sortedarr.end()); avx512_qsort(arr.data(), arr.size()); - ASSERT_EQ(sortedarr, arr); + ASSERT_EQ(sortedarr, arr) << "Array size = " << arrsizes[ii]; arr.clear(); sortedarr.clear(); } @@ -39,4 +39,94 @@ TYPED_TEST_P(avx512_sort, test_arrsizes) } } -REGISTER_TYPED_TEST_SUITE_P(avx512_sort, test_arrsizes); +TYPED_TEST_P(avx512_sort, test_reverse) +{ + if (cpu_has_avx512bw()) { + if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { + GTEST_SKIP() << "Skipping this test, it requires avx512_vbmi2"; + } + std::vector arrsizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + arrsizes.push_back((TypeParam) (ii + 1)); + } + std::vector arr; + std::vector sortedarr; + for (size_t ii = 0; ii < arrsizes.size(); ++ii) { + /* reverse array */ + for (int jj = 0; jj < arrsizes[ii]; ++jj) { + arr.push_back((TypeParam) (arrsizes[ii] - jj)); + } + sortedarr = arr; + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), sortedarr.end()); + avx512_qsort(arr.data(), arr.size()); + ASSERT_EQ(sortedarr, arr) << "Array size = " << arrsizes[ii]; + arr.clear(); + sortedarr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw"; + } +} + +TYPED_TEST_P(avx512_sort, test_constant) +{ + if (cpu_has_avx512bw()) { + if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { + GTEST_SKIP() << "Skipping this test, it requires avx512_vbmi2"; + } + std::vector arrsizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + arrsizes.push_back((TypeParam) (ii + 1)); + } + std::vector arr; + std::vector sortedarr; + for (size_t ii = 0; ii < arrsizes.size(); ++ii) { + /* constant array */ + for (int jj = 0; jj < arrsizes[ii]; ++jj) { + arr.push_back(ii); + } + sortedarr = arr; + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), sortedarr.end()); + avx512_qsort(arr.data(), arr.size()); + ASSERT_EQ(sortedarr, arr) << "Array size = " << arrsizes[ii]; + arr.clear(); + sortedarr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw"; + } +} + +TYPED_TEST_P(avx512_sort, test_small_range) +{ + if (cpu_has_avx512bw()) { + if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { + GTEST_SKIP() << "Skipping this test, it requires avx512_vbmi2"; + } + std::vector arrsizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + arrsizes.push_back((TypeParam) (ii + 1)); + } + std::vector arr; + std::vector sortedarr; + for (size_t ii = 0; ii < arrsizes.size(); ++ii) { + arr = get_uniform_rand_array(arrsizes[ii], 20, 1); + sortedarr = arr; + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), sortedarr.end()); + 
avx512_qsort(arr.data(), arr.size()); + ASSERT_EQ(sortedarr, arr) << "Array size = " << arrsizes[ii]; + arr.clear(); + sortedarr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw"; + } +} +REGISTER_TYPED_TEST_SUITE_P(avx512_sort, + test_random, test_reverse, test_constant, test_small_range); From 7a8699519cba8270045de527891cc18565115eac Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 25 Apr 2023 10:20:16 -0700 Subject: [PATCH 10/11] Add more tests for qselect --- tests/test_qselect.hpp | 47 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/tests/test_qselect.hpp b/tests/test_qselect.hpp index cad017bb..809affe5 100644 --- a/tests/test_qselect.hpp +++ b/tests/test_qselect.hpp @@ -5,7 +5,7 @@ class avx512_select : public ::testing::Test { }; TYPED_TEST_SUITE_P(avx512_select); -TYPED_TEST_P(avx512_select, test_arrsizes) +TYPED_TEST_P(avx512_select, test_random) { if (cpu_has_avx512bw()) { if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { @@ -48,4 +48,47 @@ TYPED_TEST_P(avx512_select, test_arrsizes) } } -REGISTER_TYPED_TEST_SUITE_P(avx512_select, test_arrsizes); +TYPED_TEST_P(avx512_select, test_small_range) +{ + if (cpu_has_avx512bw()) { + if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { + GTEST_SKIP() << "Skipping this test, it requires avx512_vbmi2"; + } + std::vector arrsizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + std::vector sortedarr; + std::vector psortedarr; + for (size_t ii = 0; ii < arrsizes.size(); ++ii) { + /* Random array */ + arr = get_uniform_rand_array(arrsizes[ii], 20, 1); + sortedarr = arr; + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), sortedarr.end()); + for (size_t k = 0; k < arr.size(); ++k) { + psortedarr = arr; + avx512_qselect(psortedarr.data(), k, psortedarr.size()); + /* index k is correct */ + ASSERT_EQ(sortedarr[k], psortedarr[k]); + /* Check left partition */ + for (size_t jj = 0; jj < k; jj++) { + ASSERT_LE(psortedarr[jj], psortedarr[k]); + } + /* Check right partition */ + for (size_t jj = k+1; jj < arr.size(); jj++) { + ASSERT_GE(psortedarr[jj], psortedarr[k]); + } + psortedarr.clear(); + } + arr.clear(); + sortedarr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw"; + } +} + +REGISTER_TYPED_TEST_SUITE_P(avx512_select, test_random, test_small_range); From 74986fad20d71321ef0f2bc6e234070d93e5d7b9 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 25 Apr 2023 10:21:09 -0700 Subject: [PATCH 11/11] Formatting fix --- tests/test_keyvalue.cpp | 2 +- tests/test_partial_qsort.hpp | 3 ++- tests/test_qselect.hpp | 10 ++++++---- tests/test_qsort.hpp | 13 ++++++++----- tests/test_qsortfp16.cpp | 8 +++++--- tests/test_sort.cpp | 8 +++++--- 6 files changed, 27 insertions(+), 17 deletions(-) diff --git a/tests/test_keyvalue.cpp b/tests/test_keyvalue.cpp index 0cb1ca6f..b9cca554 100644 --- a/tests/test_keyvalue.cpp +++ b/tests/test_keyvalue.cpp @@ -4,8 +4,8 @@ * *******************************************/ #include "avx512-64bit-keyvaluesort.hpp" -#include "rand_array.h" #include "cpuinfo.h" +#include "rand_array.h" #include #include diff --git a/tests/test_partial_qsort.hpp b/tests/test_partial_qsort.hpp index 5c08064e..4ba5caa8 100644 --- a/tests/test_partial_qsort.hpp +++ b/tests/test_partial_qsort.hpp @@ -30,7 +30,8 @@ TYPED_TEST_P(avx512_partial_sort, test_ranges) int k = get_uniform_rand_array(1, arrsize, 
1).front(); /* Sort the range and verify all the required elements match the presorted set */ - avx512_partial_qsort(psortedarr.data(), k, psortedarr.size()); + avx512_partial_qsort( + psortedarr.data(), k, psortedarr.size()); for (size_t jj = 0; jj < k; jj++) { ASSERT_EQ(sortedarr[jj], psortedarr[jj]); } diff --git a/tests/test_qselect.hpp b/tests/test_qselect.hpp index 809affe5..f0c0c242 100644 --- a/tests/test_qselect.hpp +++ b/tests/test_qselect.hpp @@ -26,7 +26,8 @@ TYPED_TEST_P(avx512_select, test_random) std::sort(sortedarr.begin(), sortedarr.end()); for (size_t k = 0; k < arr.size(); ++k) { psortedarr = arr; - avx512_qselect(psortedarr.data(), k, psortedarr.size()); + avx512_qselect( + psortedarr.data(), k, psortedarr.size()); /* index k is correct */ ASSERT_EQ(sortedarr[k], psortedarr[k]); /* Check left partition */ @@ -34,7 +35,7 @@ TYPED_TEST_P(avx512_select, test_random) ASSERT_LE(psortedarr[jj], psortedarr[k]); } /* Check right partition */ - for (size_t jj = k+1; jj < arr.size(); jj++) { + for (size_t jj = k + 1; jj < arr.size(); jj++) { ASSERT_GE(psortedarr[jj], psortedarr[k]); } psortedarr.clear(); @@ -69,7 +70,8 @@ TYPED_TEST_P(avx512_select, test_small_range) std::sort(sortedarr.begin(), sortedarr.end()); for (size_t k = 0; k < arr.size(); ++k) { psortedarr = arr; - avx512_qselect(psortedarr.data(), k, psortedarr.size()); + avx512_qselect( + psortedarr.data(), k, psortedarr.size()); /* index k is correct */ ASSERT_EQ(sortedarr[k], psortedarr[k]); /* Check left partition */ @@ -77,7 +79,7 @@ TYPED_TEST_P(avx512_select, test_small_range) ASSERT_LE(psortedarr[jj], psortedarr[k]); } /* Check right partition */ - for (size_t jj = k+1; jj < arr.size(); jj++) { + for (size_t jj = k + 1; jj < arr.size(); jj++) { ASSERT_GE(psortedarr[jj], psortedarr[k]); } psortedarr.clear(); diff --git a/tests/test_qsort.hpp b/tests/test_qsort.hpp index 9663ffbb..4dc8a773 100644 --- a/tests/test_qsort.hpp +++ b/tests/test_qsort.hpp @@ -47,14 +47,14 @@ TYPED_TEST_P(avx512_sort, test_reverse) } std::vector arrsizes; for (int64_t ii = 0; ii < 1024; ++ii) { - arrsizes.push_back((TypeParam) (ii + 1)); + arrsizes.push_back((TypeParam)(ii + 1)); } std::vector arr; std::vector sortedarr; for (size_t ii = 0; ii < arrsizes.size(); ++ii) { /* reverse array */ for (int jj = 0; jj < arrsizes[ii]; ++jj) { - arr.push_back((TypeParam) (arrsizes[ii] - jj)); + arr.push_back((TypeParam)(arrsizes[ii] - jj)); } sortedarr = arr; /* Sort with std::sort for comparison */ @@ -78,7 +78,7 @@ TYPED_TEST_P(avx512_sort, test_constant) } std::vector arrsizes; for (int64_t ii = 0; ii < 1024; ++ii) { - arrsizes.push_back((TypeParam) (ii + 1)); + arrsizes.push_back((TypeParam)(ii + 1)); } std::vector arr; std::vector sortedarr; @@ -109,7 +109,7 @@ TYPED_TEST_P(avx512_sort, test_small_range) } std::vector arrsizes; for (int64_t ii = 0; ii < 1024; ++ii) { - arrsizes.push_back((TypeParam) (ii + 1)); + arrsizes.push_back((TypeParam)(ii + 1)); } std::vector arr; std::vector sortedarr; @@ -129,4 +129,7 @@ TYPED_TEST_P(avx512_sort, test_small_range) } } REGISTER_TYPED_TEST_SUITE_P(avx512_sort, - test_random, test_reverse, test_constant, test_small_range); + test_random, + test_reverse, + test_constant, + test_small_range); diff --git a/tests/test_qsortfp16.cpp b/tests/test_qsortfp16.cpp index f86d77df..d6a45f7b 100644 --- a/tests/test_qsortfp16.cpp +++ b/tests/test_qsortfp16.cpp @@ -95,7 +95,8 @@ TEST(avx512_qselect_float16, test_arrsizes) std::sort(sortedarr.begin(), sortedarr.end()); for (size_t k = 0; k < arr.size(); ++k) { 
psortedarr = arr; - avx512_qselect<_Float16>(psortedarr.data(), k, psortedarr.size()); + avx512_qselect<_Float16>( + psortedarr.data(), k, psortedarr.size()); /* index k is correct */ ASSERT_EQ(sortedarr[k], psortedarr[k]); /* Check left partition */ @@ -103,7 +104,7 @@ TEST(avx512_qselect_float16, test_arrsizes) ASSERT_LE(psortedarr[jj], psortedarr[k]); } /* Check right partition */ - for (size_t jj = k+1; jj < arr.size(); jj++) { + for (size_t jj = k + 1; jj < arr.size(); jj++) { ASSERT_GE(psortedarr[jj], psortedarr[k]); } psortedarr.clear(); @@ -142,7 +143,8 @@ TEST(avx512_partial_qsort_float16, test_ranges) int k = get_uniform_rand_array(1, arrsize, 1).front(); /* Sort the range and verify all the required elements match the presorted set */ - avx512_partial_qsort<_Float16>(psortedarr.data(), k, psortedarr.size()); + avx512_partial_qsort<_Float16>( + psortedarr.data(), k, psortedarr.size()); for (size_t jj = 0; jj < k; jj++) { ASSERT_EQ(sortedarr[jj], psortedarr[jj]); } diff --git a/tests/test_sort.cpp b/tests/test_sort.cpp index 85a6bd8d..92ffbc35 100644 --- a/tests/test_sort.cpp +++ b/tests/test_sort.cpp @@ -1,6 +1,6 @@ -#include "test_qsort.hpp" -#include "test_qselect.hpp" #include "test_partial_qsort.hpp" +#include "test_qselect.hpp" +#include "test_qsort.hpp" using QuickSortTestTypes = testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512_sort, QuickSortTestTypes); INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512_select, QuickSortTestTypes); -INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512_partial_sort, QuickSortTestTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, + avx512_partial_sort, + QuickSortTestTypes);
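Reviewer note: a minimal standalone harness for trying this series locally.
It is illustrative only and not part of the patches; it assumes an AVX-512
capable CPU, that the src/ headers are on the include path, and a suitable
compiler flag (e.g. -march=skylake-avx512). The three entry points and their
signatures are the ones exercised by the tests above.

    // harness.cpp: exercise avx512_qsort (full sort), avx512_qselect
    // (partition around the k-th smallest element) and avx512_partial_qsort
    // (smallest k elements in sorted order), as the tests above do.
    #include "avx512-64bit-qsort.hpp"

    #include <cstdint>
    #include <cstdio>
    #include <random>
    #include <vector>

    int main()
    {
        std::mt19937_64 rng(42);
        std::vector<int64_t> ref(1000);
        for (auto &x : ref)
            x = static_cast<int64_t>(rng());

        std::vector<int64_t> a = ref;
        avx512_qsort(a.data(), a.size()); // whole array sorted

        std::vector<int64_t> b = ref;
        avx512_qselect(b.data(), 99, b.size()); // b[99] == a[99]

        std::vector<int64_t> c = ref;
        avx512_partial_qsort(c.data(), 100, c.size()); // c[0..99] == a[0..99]

        std::printf("100th smallest: %lld\n", (long long)b[99]);
        return 0;
    }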