From 34b6cd75d39c7095d5c05a843c82e8e1aaa6bc4a Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Sat, 29 Apr 2023 14:28:32 -0700 Subject: [PATCH 01/14] Add AVX512 argsort for int64_t --- src/avx512-64bit-argsort.hpp | 281 ++++++++++++++++ src/avx512-64bit-common.h | 19 +- src/avx512-64bit-keyvalue-networks.hpp | 437 ++++++++++++++++++++++++ src/avx512-64bit-keyvaluesort.hpp | 438 +------------------------ src/avx512-common-argsort.h | 314 ++++++++++++++++++ 5 files changed, 1051 insertions(+), 438 deletions(-) create mode 100644 src/avx512-64bit-argsort.hpp create mode 100644 src/avx512-64bit-keyvalue-networks.hpp create mode 100644 src/avx512-common-argsort.h diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp new file mode 100644 index 00000000..cfea9fbe --- /dev/null +++ b/src/avx512-64bit-argsort.hpp @@ -0,0 +1,281 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * ****************************************************************/ + +#ifndef AVX512_ARGSORT_64BIT +#define AVX512_ARGSORT_64BIT + +#include "avx512-common-argsort.h" +#include "avx512-64bit-keyvalue-networks.hpp" + +/* argsort using std::sort */ +template +void std_argsort(T* arr, int64_t* arg, int64_t left, int64_t right) { + std::sort(arg + left, arg + right, + [arr](int64_t left, int64_t right) -> bool { + // sort indices according to corresponding array element + return arr[left] < arr[right]; + }); +} + +template +X86_SIMD_SORT_INLINE void +argsort_8_64bit(type_t *arr, int64_t* arg, int32_t N) +{ + using zmm_t = typename vtype::zmm_t; + typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; + argzmm_t argzmm = argtype::maskz_loadu(load_mask, arg); + zmm_t arrzmm + = vtype::template mask_i64gather(vtype::zmm_max(), load_mask, argzmm, arr); + arrzmm = sort_zmm_64bit(arrzmm, argzmm); + vtype::mask_storeu(arg, load_mask, argzmm); +} + +template +X86_SIMD_SORT_INLINE void +argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N) +{ + if (N <= 8) { + argsort_8_64bit(arr, arg, N); + return; + } + using zmm_t = typename vtype::zmm_t; + typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; + argzmm_t argzmm1 = argtype::loadu(arg); + argzmm_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8); + zmm_t arrzmm1 = vtype::template i64gather(argzmm1, arr); + zmm_t arrzmm2 = vtype::template mask_i64gather(vtype::zmm_max(), load_mask, argzmm2, arr); + arrzmm1 = sort_zmm_64bit(arrzmm1, argzmm1); + arrzmm2 = sort_zmm_64bit(arrzmm2, argzmm2); + bitonic_merge_two_zmm_64bit(arrzmm1, arrzmm2, argzmm1, argzmm2); + argtype::storeu(arg, argzmm1); + argtype::mask_storeu(arg + 8, load_mask, argzmm2); +} + +template +X86_SIMD_SORT_INLINE void +argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N) +{ + if (N <= 16) { + argsort_16_64bit(arr, arg, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t arrzmm[4]; + argzmm_t argzmm[4]; + +#pragma GCC unroll 2 + for (int ii = 0; ii < 2; ++ii) { + argzmm[ii] = argtype::loadu(arg + 8*ii); + arrzmm[ii] = vtype::template i64gather(argzmm[ii], arr); + arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); + } + + uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; + opmask_t load_mask[2] = {0xFF, 0xFF}; +#pragma GCC unroll 2 + for (int ii = 0; ii < 2; ++ii) { + load_mask[ii] = (combined_mask >> (ii*8)) & 0xFF; + argzmm[ii+2] = argtype::maskz_loadu(load_mask[ii], arg + 16 + 
8*ii); + arrzmm[ii+2] = vtype::template mask_i64gather(vtype::zmm_max(), load_mask[ii], argzmm[ii+2], arr); + arrzmm[ii+2] = sort_zmm_64bit(arrzmm[ii+2], argzmm[ii+2]); + } + + bitonic_merge_two_zmm_64bit( + arrzmm[0], arrzmm[1], argzmm[0], argzmm[1]); + bitonic_merge_two_zmm_64bit( + arrzmm[2], arrzmm[3], argzmm[2], argzmm[3]); + bitonic_merge_four_zmm_64bit(arrzmm, argzmm); + + argtype::storeu(arg, argzmm[0]); + argtype::storeu(arg + 8, argzmm[1]); + argtype::mask_storeu(arg + 16, load_mask[0], argzmm[2]); + argtype::mask_storeu(arg + 24, load_mask[1], argzmm[3]); +} + +template +X86_SIMD_SORT_INLINE void +argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N) +{ + if (N <= 32) { + argsort_32_64bit(arr, arg, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t arrzmm[8]; + argzmm_t argzmm[8]; + +#pragma GCC unroll 4 + for (int ii = 0; ii < 4; ++ii) { + argzmm[ii] = argtype::loadu(arg + 8*ii); + arrzmm[ii] = vtype::template i64gather(argzmm[ii], arr); + arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); + } + + opmask_t load_mask[4] = {0xFF, 0xFF, 0xFF, 0xFF}; + uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; +#pragma GCC unroll 4 + for (int ii = 0; ii < 4; ++ii) { + load_mask[ii] = (combined_mask >> (ii*8)) & 0xFF; + argzmm[ii+4] = argtype::maskz_loadu(load_mask[ii], arg + 32 + 8*ii); + arrzmm[ii+4] = vtype::template mask_i64gather(vtype::zmm_max(), load_mask[ii], argzmm[ii+4], arr); + arrzmm[ii+4] = sort_zmm_64bit(arrzmm[ii+4], argzmm[ii+4]); + } + +#pragma GCC unroll 4 + for (int ii = 0; ii < 8; ii = ii + 2) { + bitonic_merge_two_zmm_64bit(arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); + } + bitonic_merge_four_zmm_64bit(arrzmm, argzmm); + bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); + bitonic_merge_eight_zmm_64bit(arrzmm, argzmm); + +#pragma GCC unroll 4 + for (int ii = 0; ii < 4; ++ii) { + argtype::storeu(arg + 8*ii, argzmm[ii]); + } +#pragma GCC unroll 4 + for (int ii = 0; ii < 4; ++ii) { + argtype::mask_storeu(arg + 32 + 8*ii, load_mask[ii], argzmm[ii + 4]); + } +} + +/* arsort 128 doesn't seem to make much of a difference to perf*/ +//template +//X86_SIMD_SORT_INLINE void +//argsort_128_64bit(type_t *arr, int64_t *arg, int32_t N) +//{ +// if (N <= 64) { +// argsort_64_64bit(arr, arg, N); +// return; +// } +// using zmm_t = typename vtype::zmm_t; +// using opmask_t = typename vtype::opmask_t; +// zmm_t arrzmm[16]; +// argzmm_t argzmm[16]; +// +//#pragma GCC unroll 8 +// for (int ii = 0; ii < 8; ++ii) { +// argzmm[ii] = argtype::loadu(arg + 8*ii); +// arrzmm[ii] = vtype::template i64gather(argzmm[ii], arr); +// arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); +// } +// +// opmask_t load_mask[8] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; +// if (N != 128) { +// uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; +//#pragma GCC unroll 8 +// for (int ii = 0; ii < 8; ++ii) { +// load_mask[ii] = (combined_mask >> (ii*8)) & 0xFF; +// } +// } +//#pragma GCC unroll 8 +// for (int ii = 0; ii < 8; ++ii) { +// argzmm[ii+8] = argtype::maskz_loadu(load_mask[ii], arg + 64 + 8*ii); +// arrzmm[ii+8] = vtype::template mask_i64gather(vtype::zmm_max(), load_mask[ii], argzmm[ii+8], arr); +// arrzmm[ii+8] = sort_zmm_64bit(arrzmm[ii+8], argzmm[ii+8]); +// } +// +//#pragma GCC unroll 8 +// for (int ii = 0; ii < 16; ii = ii + 2) { +// bitonic_merge_two_zmm_64bit(arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); +// } +// bitonic_merge_four_zmm_64bit(arrzmm, argzmm); +// 
bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); +// bitonic_merge_four_zmm_64bit(arrzmm + 8, argzmm + 8); +// bitonic_merge_four_zmm_64bit(arrzmm + 12, argzmm + 12); +// bitonic_merge_eight_zmm_64bit(arrzmm, argzmm); +// bitonic_merge_eight_zmm_64bit(arrzmm+8, argzmm+8); +// bitonic_merge_sixteen_zmm_64bit(arrzmm, argzmm); +// +//#pragma GCC unroll 8 +// for (int ii = 0; ii < 8; ++ii) { +// argtype::storeu(arg + 8*ii, argzmm[ii]); +// } +//#pragma GCC unroll 8 +// for (int ii = 0; ii < 8; ++ii) { +// argtype::mask_storeu(arg + 64 + 8*ii, load_mask[ii], argzmm[ii + 8]); +// } +//} + +template +type_t get_pivot_64bit(type_t *arr, + int64_t* arg, + const int64_t left, + const int64_t right) +{ + if (right - left >= vtype::numlanes) { + // median of 8 + int64_t size = (right - left) / 8; + using zmm_t = typename vtype::zmm_t; + // TODO: Use gather here too: + __m512i rand_index = _mm512_set_epi64(arg[left + size], + arg[left + 2 * size], + arg[left + 3 * size], + arg[left + 4 * size], + arg[left + 5 * size], + arg[left + 6 * size], + arg[left + 7 * size], + arg[left + 8 * size]); + zmm_t rand_vec = vtype::template i64gather(rand_index, arr); + // pivot will never be a nan, since there are no nan's! + zmm_t sort = sort_zmm_64bit(rand_vec); + return ((type_t *)&sort)[4]; + } + else { + return arr[arg[right]]; + } +} + +template +inline void +argsort_64bit_(type_t *arr, int64_t* arg, int64_t left, int64_t right, int64_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 64) { + argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); + if (pivot != smallest) + argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) + argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); +} + +template <> +void avx512_argsort(int64_t *arr, int64_t* arg, int64_t arrsize) +{ + if (arrsize > 1) { + argsort_64bit_, int64_t>( + arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +std::vector avx512_argsort(int64_t *arr, int64_t arrsize) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx512_argsort(arr, indices.data(), arrsize); + return indices; +} + +#endif // AVX512_ARGSORT_64BIT diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 75ae7fb1..1291043b 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -23,6 +23,7 @@ template <> struct zmm_vector { using type_t = int64_t; using zmm_t = __m512i; + using argzmm_t = __m512i; using ymm_t = __m512i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -51,11 +52,18 @@ struct zmm_vector { { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } - + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _kxor_mask8(x, y); + } static opmask_t knot_opmask(opmask_t x) { return _knot_mask8(x); } + static opmask_t le(zmm_t x, zmm_t y) + { + return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_LE); + } static opmask_t ge(zmm_t x, zmm_t y) { return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); @@ -65,6 +73,11 @@ struct zmm_vector { return 
_mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); } template + static zmm_t mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + { + return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); + } + template static zmm_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_epi64(index, base, scale); @@ -81,6 +94,10 @@ struct zmm_vector { { return _mm512_mask_compressstoreu_epi64(mem, mask, x); } + static zmm_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm512_maskz_loadu_epi64(mask, mem); + } static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) { return _mm512_mask_loadu_epi64(x, mask, mem); diff --git a/src/avx512-64bit-keyvalue-networks.hpp b/src/avx512-64bit-keyvalue-networks.hpp new file mode 100644 index 00000000..ff2f6da6 --- /dev/null +++ b/src/avx512-64bit-keyvalue-networks.hpp @@ -0,0 +1,437 @@ + +template +X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + key_zmm = cmp_merge( + key_zmm, + vtype1::template shuffle(key_zmm), + index_zmm, + vtype2::template shuffle(index_zmm), + 0xAA); + key_zmm = cmp_merge( + key_zmm, + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), + index_zmm, + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), index_zmm), + 0xCC); + key_zmm = cmp_merge( + key_zmm, + vtype1::template shuffle(key_zmm), + index_zmm, + vtype2::template shuffle(index_zmm), + 0xAA); + key_zmm = cmp_merge( + key_zmm, + vtype1::permutexvar(rev_index, key_zmm), + index_zmm, + vtype2::permutexvar(rev_index, index_zmm), + 0xF0); + key_zmm = cmp_merge( + key_zmm, + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + index_zmm, + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + 0xCC); + key_zmm = cmp_merge( + key_zmm, + vtype1::template shuffle(key_zmm), + index_zmm, + vtype2::template shuffle(index_zmm), + 0xAA); + return key_zmm; +} +// Assumes zmm is bitonic and performs a recursive half cleaner +template +X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm, + index_type &index_zmm) +{ + + // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 + key_zmm = cmp_merge( + key_zmm, + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), + index_zmm, + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), index_zmm), + 0xF0); + // 2) half_cleaner[4] + key_zmm = cmp_merge( + key_zmm, + vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + index_zmm, + vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + 0xCC); + // 3) half_cleaner[1] + key_zmm = cmp_merge( + key_zmm, + vtype1::template shuffle(key_zmm), + index_zmm, + vtype2::template shuffle(index_zmm), + 0xAA); + return key_zmm; +} +// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner +template +X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, + zmm_t &key_zmm2, + index_type &index_zmm1, + index_type &index_zmm2) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + // 1) First step of a merging network: coex of zmm1 and zmm2 reversed + key_zmm2 = vtype1::permutexvar(rev_index, key_zmm2); + index_zmm2 = vtype2::permutexvar(rev_index, index_zmm2); + + zmm_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); + zmm_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); + + index_type index_zmm3 = vtype2::mask_mov( + index_zmm2, vtype1::eq(key_zmm3, key_zmm1), index_zmm1); + index_type index_zmm4 = vtype2::mask_mov( + index_zmm1, 
vtype1::eq(key_zmm3, key_zmm1), index_zmm2); + + // 2) Recursive half cleaner for each + key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); + index_zmm1 = index_zmm3; + index_zmm2 = index_zmm4; +} +// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive +// half cleaner +template +X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, + index_type *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + // 1) First step of a merging network + zmm_t key_zmm2r = vtype1::permutexvar(rev_index, key_zmm[2]); + zmm_t key_zmm3r = vtype1::permutexvar(rev_index, key_zmm[3]); + index_type index_zmm2r = vtype2::permutexvar(rev_index, index_zmm[2]); + index_type index_zmm3r = vtype2::permutexvar(rev_index, index_zmm[3]); + + zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm3r); + zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm2r); + zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); + zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); + + index_type index_zmm_t1 = vtype2::mask_mov( + index_zmm3r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = vtype2::mask_mov( + index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); + index_type index_zmm_t2 = vtype2::mask_mov( + index_zmm2r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = vtype2::mask_mov( + index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); + + // 2) Recursive half clearer: 16 + zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t4 = vtype1::permutexvar(rev_index, key_zmm_m1); + index_type index_zmm_t3 = vtype2::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t4 = vtype2::permutexvar(rev_index, index_zmm_m1); + + zmm_t key_zmm0 = vtype1::min(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm1 = vtype1::max(key_zmm_t1, key_zmm_t2); + zmm_t key_zmm2 = vtype1::min(key_zmm_t3, key_zmm_t4); + zmm_t key_zmm3 = vtype1::max(key_zmm_t3, key_zmm_t4); + + index_type index_zmm0 = vtype2::mask_mov( + index_zmm_t2, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t1); + index_type index_zmm1 = vtype2::mask_mov( + index_zmm_t1, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t2); + index_type index_zmm2 = vtype2::mask_mov( + index_zmm_t4, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t3); + index_type index_zmm3 = vtype2::mask_mov( + index_zmm_t3, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t4); + + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); + key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + + index_zmm[0] = index_zmm0; + index_zmm[1] = index_zmm1; + index_zmm[2] = index_zmm2; + index_zmm[3] = index_zmm3; +} + +template +X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, + index_type *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t key_zmm4r = vtype1::permutexvar(rev_index, key_zmm[4]); + zmm_t key_zmm5r = vtype1::permutexvar(rev_index, key_zmm[5]); + zmm_t key_zmm6r = vtype1::permutexvar(rev_index, key_zmm[6]); + zmm_t key_zmm7r = vtype1::permutexvar(rev_index, key_zmm[7]); + index_type index_zmm4r = vtype2::permutexvar(rev_index, index_zmm[4]); + index_type index_zmm5r = vtype2::permutexvar(rev_index, index_zmm[5]); + index_type index_zmm6r = vtype2::permutexvar(rev_index, index_zmm[6]); + index_type index_zmm7r = 
vtype2::permutexvar(rev_index, index_zmm[7]); + + zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm7r); + zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm6r); + zmm_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm5r); + zmm_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm4r); + + zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); + zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); + zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); + zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); + + index_type index_zmm_t1 = vtype2::mask_mov( + index_zmm7r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = vtype2::mask_mov( + index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); + index_type index_zmm_t2 = vtype2::mask_mov( + index_zmm6r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = vtype2::mask_mov( + index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); + index_type index_zmm_t3 = vtype2::mask_mov( + index_zmm5r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_type index_zmm_m3 = vtype2::mask_mov( + index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); + index_type index_zmm_t4 = vtype2::mask_mov( + index_zmm4r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_type index_zmm_m4 = vtype2::mask_mov( + index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); + + zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t7 = vtype1::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t8 = vtype1::permutexvar(rev_index, key_zmm_m1); + index_type index_zmm_t5 = vtype2::permutexvar(rev_index, index_zmm_m4); + index_type index_zmm_t6 = vtype2::permutexvar(rev_index, index_zmm_m3); + index_type index_zmm_t7 = vtype2::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t8 = vtype2::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + key_zmm[0] + = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] + = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] + = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] + = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] + = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] + = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] + = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] + = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + + index_zmm[0] = index_zmm_t1; + index_zmm[1] = index_zmm_t2; + index_zmm[2] = index_zmm_t3; + index_zmm[3] = index_zmm_t4; + index_zmm[4] = index_zmm_t5; + index_zmm[5] = index_zmm_t6; + index_zmm[6] = index_zmm_t7; + index_zmm[7] = index_zmm_t8; +} + +template +X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, + index_type *index_zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t key_zmm8r = vtype1::permutexvar(rev_index, key_zmm[8]); + zmm_t key_zmm9r = vtype1::permutexvar(rev_index, key_zmm[9]); + zmm_t 
key_zmm10r = vtype1::permutexvar(rev_index, key_zmm[10]); + zmm_t key_zmm11r = vtype1::permutexvar(rev_index, key_zmm[11]); + zmm_t key_zmm12r = vtype1::permutexvar(rev_index, key_zmm[12]); + zmm_t key_zmm13r = vtype1::permutexvar(rev_index, key_zmm[13]); + zmm_t key_zmm14r = vtype1::permutexvar(rev_index, key_zmm[14]); + zmm_t key_zmm15r = vtype1::permutexvar(rev_index, key_zmm[15]); + + index_type index_zmm8r = vtype2::permutexvar(rev_index, index_zmm[8]); + index_type index_zmm9r = vtype2::permutexvar(rev_index, index_zmm[9]); + index_type index_zmm10r = vtype2::permutexvar(rev_index, index_zmm[10]); + index_type index_zmm11r = vtype2::permutexvar(rev_index, index_zmm[11]); + index_type index_zmm12r = vtype2::permutexvar(rev_index, index_zmm[12]); + index_type index_zmm13r = vtype2::permutexvar(rev_index, index_zmm[13]); + index_type index_zmm14r = vtype2::permutexvar(rev_index, index_zmm[14]); + index_type index_zmm15r = vtype2::permutexvar(rev_index, index_zmm[15]); + + zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm15r); + zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm14r); + zmm_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm13r); + zmm_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm12r); + zmm_t key_zmm_t5 = vtype1::min(key_zmm[4], key_zmm11r); + zmm_t key_zmm_t6 = vtype1::min(key_zmm[5], key_zmm10r); + zmm_t key_zmm_t7 = vtype1::min(key_zmm[6], key_zmm9r); + zmm_t key_zmm_t8 = vtype1::min(key_zmm[7], key_zmm8r); + + zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm15r); + zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm14r); + zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm13r); + zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm12r); + zmm_t key_zmm_m5 = vtype1::max(key_zmm[4], key_zmm11r); + zmm_t key_zmm_m6 = vtype1::max(key_zmm[5], key_zmm10r); + zmm_t key_zmm_m7 = vtype1::max(key_zmm[6], key_zmm9r); + zmm_t key_zmm_m8 = vtype1::max(key_zmm[7], key_zmm8r); + + index_type index_zmm_t1 = vtype2::mask_mov( + index_zmm15r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = vtype2::mask_mov( + index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); + index_type index_zmm_t2 = vtype2::mask_mov( + index_zmm14r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = vtype2::mask_mov( + index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); + index_type index_zmm_t3 = vtype2::mask_mov( + index_zmm13r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); + index_type index_zmm_m3 = vtype2::mask_mov( + index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); + index_type index_zmm_t4 = vtype2::mask_mov( + index_zmm12r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); + index_type index_zmm_m4 = vtype2::mask_mov( + index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); + + index_type index_zmm_t5 = vtype2::mask_mov( + index_zmm11r, vtype1::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); + index_type index_zmm_m5 = vtype2::mask_mov( + index_zmm[4], vtype1::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); + index_type index_zmm_t6 = vtype2::mask_mov( + index_zmm10r, vtype1::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); + index_type index_zmm_m6 = vtype2::mask_mov( + index_zmm[5], vtype1::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); + index_type index_zmm_t7 = vtype2::mask_mov( + index_zmm9r, vtype1::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); + index_type index_zmm_m7 = vtype2::mask_mov( + index_zmm[6], vtype1::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); + index_type index_zmm_t8 = vtype2::mask_mov( + 
index_zmm8r, vtype1::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); + index_type index_zmm_m8 = vtype2::mask_mov( + index_zmm[7], vtype1::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); + + zmm_t key_zmm_t9 = vtype1::permutexvar(rev_index, key_zmm_m8); + zmm_t key_zmm_t10 = vtype1::permutexvar(rev_index, key_zmm_m7); + zmm_t key_zmm_t11 = vtype1::permutexvar(rev_index, key_zmm_m6); + zmm_t key_zmm_t12 = vtype1::permutexvar(rev_index, key_zmm_m5); + zmm_t key_zmm_t13 = vtype1::permutexvar(rev_index, key_zmm_m4); + zmm_t key_zmm_t14 = vtype1::permutexvar(rev_index, key_zmm_m3); + zmm_t key_zmm_t15 = vtype1::permutexvar(rev_index, key_zmm_m2); + zmm_t key_zmm_t16 = vtype1::permutexvar(rev_index, key_zmm_m1); + index_type index_zmm_t9 = vtype2::permutexvar(rev_index, index_zmm_m8); + index_type index_zmm_t10 = vtype2::permutexvar(rev_index, index_zmm_m7); + index_type index_zmm_t11 = vtype2::permutexvar(rev_index, index_zmm_m6); + index_type index_zmm_t12 = vtype2::permutexvar(rev_index, index_zmm_m5); + index_type index_zmm_t13 = vtype2::permutexvar(rev_index, index_zmm_m4); + index_type index_zmm_t14 = vtype2::permutexvar(rev_index, index_zmm_m3); + index_type index_zmm_t15 = vtype2::permutexvar(rev_index, index_zmm_m2); + index_type index_zmm_t16 = vtype2::permutexvar(rev_index, index_zmm_m1); + + COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); + COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); + COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); + COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); + COEX( + key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); + COEX( + key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); + COEX( + key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); + COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); + COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); + COEX( + key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); + COEX( + key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); + COEX( + key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); + + COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); + COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); + COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); + COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); + COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); + COEX( + key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); + COEX( + key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); + COEX( + key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); + // + key_zmm[0] + = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); + key_zmm[1] + = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); + key_zmm[2] + = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); + key_zmm[3] + = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); + key_zmm[4] + = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); + key_zmm[5] + = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); + key_zmm[6] + = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); + key_zmm[7] + = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); + key_zmm[8] + = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, + index_zmm_t10); + key_zmm[10] = 
bitonic_merge_zmm_64bit(key_zmm_t11, + index_zmm_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, + index_zmm_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, + index_zmm_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, + index_zmm_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, + index_zmm_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, + index_zmm_t16); + + index_zmm[0] = index_zmm_t1; + index_zmm[1] = index_zmm_t2; + index_zmm[2] = index_zmm_t3; + index_zmm[3] = index_zmm_t4; + index_zmm[4] = index_zmm_t5; + index_zmm[5] = index_zmm_t6; + index_zmm[6] = index_zmm_t7; + index_zmm[7] = index_zmm_t8; + index_zmm[8] = index_zmm_t9; + index_zmm[9] = index_zmm_t10; + index_zmm[10] = index_zmm_t11; + index_zmm[11] = index_zmm_t12; + index_zmm[12] = index_zmm_t13; + index_zmm[13] = index_zmm_t14; + index_zmm[14] = index_zmm_t15; + index_zmm[15] = index_zmm_t16; +} diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 4c75c481..f721f5c8 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -9,443 +9,7 @@ #define AVX512_QSORT_64BIT_KV #include "avx512-64bit-common.h" - -template -X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - key_zmm = cmp_merge( - key_zmm, - vtype1::template shuffle(key_zmm), - index_zmm, - vtype2::template shuffle(index_zmm), - 0xAA); - key_zmm = cmp_merge( - key_zmm, - vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), - index_zmm, - vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), index_zmm), - 0xCC); - key_zmm = cmp_merge( - key_zmm, - vtype1::template shuffle(key_zmm), - index_zmm, - vtype2::template shuffle(index_zmm), - 0xAA); - key_zmm = cmp_merge( - key_zmm, - vtype1::permutexvar(rev_index, key_zmm), - index_zmm, - vtype2::permutexvar(rev_index, index_zmm), - 0xF0); - key_zmm = cmp_merge( - key_zmm, - vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), - index_zmm, - vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), - 0xCC); - key_zmm = cmp_merge( - key_zmm, - vtype1::template shuffle(key_zmm), - index_zmm, - vtype2::template shuffle(index_zmm), - 0xAA); - return key_zmm; -} -// Assumes zmm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm, - index_type &index_zmm) -{ - - // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 - key_zmm = cmp_merge( - key_zmm, - vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), - index_zmm, - vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), index_zmm), - 0xF0); - // 2) half_cleaner[4] - key_zmm = cmp_merge( - key_zmm, - vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), - index_zmm, - vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), - 0xCC); - // 3) half_cleaner[1] - key_zmm = cmp_merge( - key_zmm, - vtype1::template shuffle(key_zmm), - index_zmm, - vtype2::template shuffle(index_zmm), - 0xAA); - return key_zmm; -} -// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, - zmm_t &key_zmm2, - index_type &index_zmm1, - index_type &index_zmm2) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - key_zmm2 = vtype1::permutexvar(rev_index, key_zmm2); - index_zmm2 = 
vtype2::permutexvar(rev_index, index_zmm2); - - zmm_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); - zmm_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); - - index_type index_zmm3 = vtype2::mask_mov( - index_zmm2, vtype1::eq(key_zmm3, key_zmm1), index_zmm1); - index_type index_zmm4 = vtype2::mask_mov( - index_zmm1, vtype1::eq(key_zmm3, key_zmm1), index_zmm2); - - // 2) Recursive half cleaner for each - key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); - index_zmm1 = index_zmm3; - index_zmm2 = index_zmm4; -} -// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive -// half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, - index_type *index_zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - // 1) First step of a merging network - zmm_t key_zmm2r = vtype1::permutexvar(rev_index, key_zmm[2]); - zmm_t key_zmm3r = vtype1::permutexvar(rev_index, key_zmm[3]); - index_type index_zmm2r = vtype2::permutexvar(rev_index, index_zmm[2]); - index_type index_zmm3r = vtype2::permutexvar(rev_index, index_zmm[3]); - - zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm3r); - zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm2r); - zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); - zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); - - index_type index_zmm_t1 = vtype2::mask_mov( - index_zmm3r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = vtype2::mask_mov( - index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm3r); - index_type index_zmm_t2 = vtype2::mask_mov( - index_zmm2r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = vtype2::mask_mov( - index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm2r); - - // 2) Recursive half clearer: 16 - zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t4 = vtype1::permutexvar(rev_index, key_zmm_m1); - index_type index_zmm_t3 = vtype2::permutexvar(rev_index, index_zmm_m2); - index_type index_zmm_t4 = vtype2::permutexvar(rev_index, index_zmm_m1); - - zmm_t key_zmm0 = vtype1::min(key_zmm_t1, key_zmm_t2); - zmm_t key_zmm1 = vtype1::max(key_zmm_t1, key_zmm_t2); - zmm_t key_zmm2 = vtype1::min(key_zmm_t3, key_zmm_t4); - zmm_t key_zmm3 = vtype1::max(key_zmm_t3, key_zmm_t4); - - index_type index_zmm0 = vtype2::mask_mov( - index_zmm_t2, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t1); - index_type index_zmm1 = vtype2::mask_mov( - index_zmm_t1, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t2); - index_type index_zmm2 = vtype2::mask_mov( - index_zmm_t4, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t3); - index_type index_zmm3 = vtype2::mask_mov( - index_zmm_t3, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t4); - - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - - index_zmm[0] = index_zmm0; - index_zmm[1] = index_zmm1; - index_zmm[2] = index_zmm2; - index_zmm[3] = index_zmm3; -} - -template -X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, - index_type *index_zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t key_zmm4r = vtype1::permutexvar(rev_index, key_zmm[4]); - zmm_t key_zmm5r = vtype1::permutexvar(rev_index, key_zmm[5]); - zmm_t key_zmm6r = vtype1::permutexvar(rev_index, 
key_zmm[6]); - zmm_t key_zmm7r = vtype1::permutexvar(rev_index, key_zmm[7]); - index_type index_zmm4r = vtype2::permutexvar(rev_index, index_zmm[4]); - index_type index_zmm5r = vtype2::permutexvar(rev_index, index_zmm[5]); - index_type index_zmm6r = vtype2::permutexvar(rev_index, index_zmm[6]); - index_type index_zmm7r = vtype2::permutexvar(rev_index, index_zmm[7]); - - zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm7r); - zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm6r); - zmm_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm5r); - zmm_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm4r); - - zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); - zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); - zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); - zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); - - index_type index_zmm_t1 = vtype2::mask_mov( - index_zmm7r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = vtype2::mask_mov( - index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm7r); - index_type index_zmm_t2 = vtype2::mask_mov( - index_zmm6r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = vtype2::mask_mov( - index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm6r); - index_type index_zmm_t3 = vtype2::mask_mov( - index_zmm5r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); - index_type index_zmm_m3 = vtype2::mask_mov( - index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm5r); - index_type index_zmm_t4 = vtype2::mask_mov( - index_zmm4r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); - index_type index_zmm_m4 = vtype2::mask_mov( - index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm4r); - - zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index, key_zmm_m4); - zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index, key_zmm_m3); - zmm_t key_zmm_t7 = vtype1::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t8 = vtype1::permutexvar(rev_index, key_zmm_m1); - index_type index_zmm_t5 = vtype2::permutexvar(rev_index, index_zmm_m4); - index_type index_zmm_t6 = vtype2::permutexvar(rev_index, index_zmm_m3); - index_type index_zmm_t7 = vtype2::permutexvar(rev_index, index_zmm_m2); - index_type index_zmm_t8 = vtype2::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - key_zmm[0] - = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); - key_zmm[1] - = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); - key_zmm[2] - = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); - key_zmm[3] - = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); - key_zmm[4] - = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); - key_zmm[5] - = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); - key_zmm[6] - = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); - key_zmm[7] - = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); - - index_zmm[0] = index_zmm_t1; - index_zmm[1] = index_zmm_t2; - index_zmm[2] = index_zmm_t3; - index_zmm[3] = index_zmm_t4; - index_zmm[4] = index_zmm_t5; - index_zmm[5] = index_zmm_t6; - index_zmm[6] = index_zmm_t7; - index_zmm[7] = 
index_zmm_t8; -} - -template -X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *key_zmm, - index_type *index_zmm) -{ - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); - zmm_t key_zmm8r = vtype1::permutexvar(rev_index, key_zmm[8]); - zmm_t key_zmm9r = vtype1::permutexvar(rev_index, key_zmm[9]); - zmm_t key_zmm10r = vtype1::permutexvar(rev_index, key_zmm[10]); - zmm_t key_zmm11r = vtype1::permutexvar(rev_index, key_zmm[11]); - zmm_t key_zmm12r = vtype1::permutexvar(rev_index, key_zmm[12]); - zmm_t key_zmm13r = vtype1::permutexvar(rev_index, key_zmm[13]); - zmm_t key_zmm14r = vtype1::permutexvar(rev_index, key_zmm[14]); - zmm_t key_zmm15r = vtype1::permutexvar(rev_index, key_zmm[15]); - - index_type index_zmm8r = vtype2::permutexvar(rev_index, index_zmm[8]); - index_type index_zmm9r = vtype2::permutexvar(rev_index, index_zmm[9]); - index_type index_zmm10r = vtype2::permutexvar(rev_index, index_zmm[10]); - index_type index_zmm11r = vtype2::permutexvar(rev_index, index_zmm[11]); - index_type index_zmm12r = vtype2::permutexvar(rev_index, index_zmm[12]); - index_type index_zmm13r = vtype2::permutexvar(rev_index, index_zmm[13]); - index_type index_zmm14r = vtype2::permutexvar(rev_index, index_zmm[14]); - index_type index_zmm15r = vtype2::permutexvar(rev_index, index_zmm[15]); - - zmm_t key_zmm_t1 = vtype1::min(key_zmm[0], key_zmm15r); - zmm_t key_zmm_t2 = vtype1::min(key_zmm[1], key_zmm14r); - zmm_t key_zmm_t3 = vtype1::min(key_zmm[2], key_zmm13r); - zmm_t key_zmm_t4 = vtype1::min(key_zmm[3], key_zmm12r); - zmm_t key_zmm_t5 = vtype1::min(key_zmm[4], key_zmm11r); - zmm_t key_zmm_t6 = vtype1::min(key_zmm[5], key_zmm10r); - zmm_t key_zmm_t7 = vtype1::min(key_zmm[6], key_zmm9r); - zmm_t key_zmm_t8 = vtype1::min(key_zmm[7], key_zmm8r); - - zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm15r); - zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm14r); - zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm13r); - zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm12r); - zmm_t key_zmm_m5 = vtype1::max(key_zmm[4], key_zmm11r); - zmm_t key_zmm_m6 = vtype1::max(key_zmm[5], key_zmm10r); - zmm_t key_zmm_m7 = vtype1::max(key_zmm[6], key_zmm9r); - zmm_t key_zmm_m8 = vtype1::max(key_zmm[7], key_zmm8r); - - index_type index_zmm_t1 = vtype2::mask_mov( - index_zmm15r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = vtype2::mask_mov( - index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm15r); - index_type index_zmm_t2 = vtype2::mask_mov( - index_zmm14r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = vtype2::mask_mov( - index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm14r); - index_type index_zmm_t3 = vtype2::mask_mov( - index_zmm13r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]); - index_type index_zmm_m3 = vtype2::mask_mov( - index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm13r); - index_type index_zmm_t4 = vtype2::mask_mov( - index_zmm12r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]); - index_type index_zmm_m4 = vtype2::mask_mov( - index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm12r); - - index_type index_zmm_t5 = vtype2::mask_mov( - index_zmm11r, vtype1::eq(key_zmm_t5, key_zmm[4]), index_zmm[4]); - index_type index_zmm_m5 = vtype2::mask_mov( - index_zmm[4], vtype1::eq(key_zmm_t5, key_zmm[4]), index_zmm11r); - index_type index_zmm_t6 = vtype2::mask_mov( - index_zmm10r, vtype1::eq(key_zmm_t6, key_zmm[5]), index_zmm[5]); - index_type index_zmm_m6 = vtype2::mask_mov( - index_zmm[5], 
vtype1::eq(key_zmm_t6, key_zmm[5]), index_zmm10r); - index_type index_zmm_t7 = vtype2::mask_mov( - index_zmm9r, vtype1::eq(key_zmm_t7, key_zmm[6]), index_zmm[6]); - index_type index_zmm_m7 = vtype2::mask_mov( - index_zmm[6], vtype1::eq(key_zmm_t7, key_zmm[6]), index_zmm9r); - index_type index_zmm_t8 = vtype2::mask_mov( - index_zmm8r, vtype1::eq(key_zmm_t8, key_zmm[7]), index_zmm[7]); - index_type index_zmm_m8 = vtype2::mask_mov( - index_zmm[7], vtype1::eq(key_zmm_t8, key_zmm[7]), index_zmm8r); - - zmm_t key_zmm_t9 = vtype1::permutexvar(rev_index, key_zmm_m8); - zmm_t key_zmm_t10 = vtype1::permutexvar(rev_index, key_zmm_m7); - zmm_t key_zmm_t11 = vtype1::permutexvar(rev_index, key_zmm_m6); - zmm_t key_zmm_t12 = vtype1::permutexvar(rev_index, key_zmm_m5); - zmm_t key_zmm_t13 = vtype1::permutexvar(rev_index, key_zmm_m4); - zmm_t key_zmm_t14 = vtype1::permutexvar(rev_index, key_zmm_m3); - zmm_t key_zmm_t15 = vtype1::permutexvar(rev_index, key_zmm_m2); - zmm_t key_zmm_t16 = vtype1::permutexvar(rev_index, key_zmm_m1); - index_type index_zmm_t9 = vtype2::permutexvar(rev_index, index_zmm_m8); - index_type index_zmm_t10 = vtype2::permutexvar(rev_index, index_zmm_m7); - index_type index_zmm_t11 = vtype2::permutexvar(rev_index, index_zmm_m6); - index_type index_zmm_t12 = vtype2::permutexvar(rev_index, index_zmm_m5); - index_type index_zmm_t13 = vtype2::permutexvar(rev_index, index_zmm_m4); - index_type index_zmm_t14 = vtype2::permutexvar(rev_index, index_zmm_m3); - index_type index_zmm_t15 = vtype2::permutexvar(rev_index, index_zmm_m2); - index_type index_zmm_t16 = vtype2::permutexvar(rev_index, index_zmm_m1); - - COEX(key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); - COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); - COEX(key_zmm_t3, key_zmm_t7, index_zmm_t3, index_zmm_t7); - COEX(key_zmm_t4, key_zmm_t8, index_zmm_t4, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t13, index_zmm_t9, index_zmm_t13); - COEX( - key_zmm_t10, key_zmm_t14, index_zmm_t10, index_zmm_t14); - COEX( - key_zmm_t11, key_zmm_t15, index_zmm_t11, index_zmm_t15); - COEX( - key_zmm_t12, key_zmm_t16, index_zmm_t12, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); - COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t7, index_zmm_t5, index_zmm_t7); - COEX(key_zmm_t6, key_zmm_t8, index_zmm_t6, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t11, index_zmm_t9, index_zmm_t11); - COEX( - key_zmm_t10, key_zmm_t12, index_zmm_t10, index_zmm_t12); - COEX( - key_zmm_t13, key_zmm_t15, index_zmm_t13, index_zmm_t15); - COEX( - key_zmm_t14, key_zmm_t16, index_zmm_t14, index_zmm_t16); - - COEX(key_zmm_t1, key_zmm_t2, index_zmm_t1, index_zmm_t2); - COEX(key_zmm_t3, key_zmm_t4, index_zmm_t3, index_zmm_t4); - COEX(key_zmm_t5, key_zmm_t6, index_zmm_t5, index_zmm_t6); - COEX(key_zmm_t7, key_zmm_t8, index_zmm_t7, index_zmm_t8); - COEX(key_zmm_t9, key_zmm_t10, index_zmm_t9, index_zmm_t10); - COEX( - key_zmm_t11, key_zmm_t12, index_zmm_t11, index_zmm_t12); - COEX( - key_zmm_t13, key_zmm_t14, index_zmm_t13, index_zmm_t14); - COEX( - key_zmm_t15, key_zmm_t16, index_zmm_t15, index_zmm_t16); - // - key_zmm[0] - = bitonic_merge_zmm_64bit(key_zmm_t1, index_zmm_t1); - key_zmm[1] - = bitonic_merge_zmm_64bit(key_zmm_t2, index_zmm_t2); - key_zmm[2] - = bitonic_merge_zmm_64bit(key_zmm_t3, index_zmm_t3); - key_zmm[3] - = bitonic_merge_zmm_64bit(key_zmm_t4, index_zmm_t4); - key_zmm[4] - = bitonic_merge_zmm_64bit(key_zmm_t5, index_zmm_t5); - key_zmm[5] - = bitonic_merge_zmm_64bit(key_zmm_t6, index_zmm_t6); 
- key_zmm[6] - = bitonic_merge_zmm_64bit(key_zmm_t7, index_zmm_t7); - key_zmm[7] - = bitonic_merge_zmm_64bit(key_zmm_t8, index_zmm_t8); - key_zmm[8] - = bitonic_merge_zmm_64bit(key_zmm_t9, index_zmm_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_zmm_t10, - index_zmm_t10); - key_zmm[10] = bitonic_merge_zmm_64bit(key_zmm_t11, - index_zmm_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_zmm_t12, - index_zmm_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_zmm_t13, - index_zmm_t13); - key_zmm[13] = bitonic_merge_zmm_64bit(key_zmm_t14, - index_zmm_t14); - key_zmm[14] = bitonic_merge_zmm_64bit(key_zmm_t15, - index_zmm_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_zmm_t16, - index_zmm_t16); - - index_zmm[0] = index_zmm_t1; - index_zmm[1] = index_zmm_t2; - index_zmm[2] = index_zmm_t3; - index_zmm[3] = index_zmm_t4; - index_zmm[4] = index_zmm_t5; - index_zmm[5] = index_zmm_t6; - index_zmm[6] = index_zmm_t7; - index_zmm[7] = index_zmm_t8; - index_zmm[8] = index_zmm_t9; - index_zmm[9] = index_zmm_t10; - index_zmm[10] = index_zmm_t11; - index_zmm[11] = index_zmm_t12; - index_zmm[12] = index_zmm_t13; - index_zmm[13] = index_zmm_t14; - index_zmm[14] = index_zmm_t15; - index_zmm[15] = index_zmm_t16; -} +#include "avx512-64bit-keyvalue-networks.hpp" template + * ****************************************************************/ + +#ifndef AVX512_ARGSORT_COMMON +#define AVX512_ARGSORT_COMMON + +#include "avx512-64bit-common.h" +#include +#include +#include + +using argtype = zmm_vector; +using argzmm_t = typename argtype::zmm_t; + +template +void avx512_argsort(T *arr, int64_t* arg, int64_t arrsize); + +template +std::vector avx512_argsort(T *arr, int64_t arrsize); + +/* + * COEX == Compare and Exchange two registers by swapping min and max values + */ +//template +//static void COEX(mm_t &a, mm_t &b) +//{ +// mm_t temp = a; +// a = vtype::min(a, b); +// b = vtype::max(temp, b); +//} +// +template +static inline zmm_t cmp_merge(zmm_t in1, + zmm_t in2, + argzmm_t& arg1, + argzmm_t arg2, + opmask_t mask) +{ + typename vtype::opmask_t le_mask = vtype::le(in1, in2); + opmask_t temp = vtype::kxor_opmask(le_mask, mask); + arg1 = vtype::mask_mov(arg2, temp, arg1); // 0 -> min, 1 -> max + return vtype::mask_mov(in2, temp, in1); // 0 -> min, 1 -> max + +} +/* + * Parition one ZMM register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +static inline int32_t partition_vec(type_t *arg, + int64_t left, + int64_t right, + const argzmm_t arg_vec, + const zmm_t curr_vec, + const zmm_t pivot_vec, + zmm_t *smallest_vec, + zmm_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); + int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); + vtype::mask_compressstoreu( + arg + left, vtype::knot_opmask(gt_mask), arg_vec); + vtype::mask_compressstoreu( + arg + right - amount_gt_pivot, gt_mask, arg_vec); + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + return amount_gt_pivot; +} +/* + * Parition an array based on the pivot and returns the index of the + * last element that is less than equal to the pivot. 
+ */ +template +static inline int64_t partition_avx512(type_t *arr, + int64_t* arg, + int64_t left, + int64_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using zmm_t = typename vtype::zmm_t; + zmm_t pivot_vec = vtype::set1(pivot); + zmm_t min_vec = vtype::set1(*smallest); + zmm_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + argzmm_t argvec = argtype::loadu(arg + left); + zmm_t vec = vtype::template i64gather(argvec, arr); + int32_t amount_gt_pivot = partition_vec(arg, + left, + left + vtype::numlanes, + argvec, + vec, + pivot_vec, + &min_vec, + &max_vec); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return left + (vtype::numlanes - amount_gt_pivot); + } + + // first and last vtype::numlanes values are partitioned at the end + argzmm_t argvec_left = argtype::loadu(arg + left); + zmm_t vec_left = vtype::template i64gather(argvec_left, arr); + argzmm_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); + zmm_t vec_right = vtype::template i64gather(argvec_right, arr); + // store points of the vectors + int64_t r_store = right - vtype::numlanes; + int64_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + zmm_t arg_vec, curr_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= vtype::numlanes; + arg_vec = argtype::loadu(arg + right); + curr_vec = vtype::template i64gather(arg_vec, arr); + } + else { + arg_vec = argtype::loadu(arg + left); + curr_vec = vtype::template i64gather(arg_vec, arr); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec, + curr_vec, + pivot_vec, + &min_vec, + &max_vec); + ; + r_store -= amount_gt_pivot; + l_store += (vtype::numlanes - amount_gt_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_gt_pivot = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left, + vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec(arg, + l_store, + l_store + vtype::numlanes, + argvec_right, + vec_right, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +static inline int64_t partition_avx512_unrolled(type_t *arr, + int64_t* arg, + int64_t left, + int64_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + if (right - left <= 8 * num_unroll * vtype::numlanes) { + return partition_avx512( + arr, arg, left, right, pivot, smallest, biggest); + } + /* make array length divisible by vtype::numlanes 
, shortening the array */ + for (int32_t i = ((right - left) % (num_unroll*vtype::numlanes)); i > 0; --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using zmm_t = typename vtype::zmm_t; + zmm_t pivot_vec = vtype::set1(pivot); + zmm_t min_vec = vtype::set1(*smallest); + zmm_t max_vec = vtype::set1(*biggest); + + // first and last vtype::numlanes values are partitioned at the end + zmm_t vec_left[num_unroll], vec_right[num_unroll]; + argzmm_t argvec_left[num_unroll], argvec_right[num_unroll]; +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes*ii); + vec_left[ii] = vtype::template i64gather(argvec_left[ii], arr); + argvec_right[ii] = argtype::loadu(arg + (right - vtype::numlanes*(num_unroll-ii))); + vec_right[ii] = vtype::template i64gather(argvec_right[ii], arr); + } + // store points of the vectors + int64_t r_store = right - vtype::numlanes; + int64_t l_store = left; + // indices for loading the elements + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; + while (right - left != 0) { + argzmm_t arg_vec[num_unroll]; + zmm_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= num_unroll * vtype::numlanes; +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] = vtype::loadu(arg + right + ii*vtype::numlanes); + curr_vec[ii] = vtype::template i64gather(arg_vec[ii], arr); + } + } + else { +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] = vtype::loadu(arg + left + ii*vtype::numlanes); + curr_vec[ii] = vtype::template i64gather(arg_vec[ii], arr); + } + left += num_unroll * vtype::numlanes; + } + // partition the current vector and save it on both sides of the array +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec[ii], + curr_vec[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + } + + /* partition and save vec_left and vec_right */ +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left[ii], + vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_right[ii], + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} +#endif // AVX512_ARGSORT_COMMON From 7e857597fb9ed472be6509571e808a4cd35661b4 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Sat, 29 Apr 2023 14:29:02 -0700 Subject: [PATCH 02/14] Add 
benchmarks for AVX-512 argsort --- benchmarks/bench-qsort-common.h | 1 + benchmarks/bench_argsort.hpp | 102 ++++++++++++++++++++++++++++++++ benchmarks/bench_qsort.cpp | 1 + 3 files changed, 104 insertions(+) create mode 100644 benchmarks/bench_argsort.hpp diff --git a/benchmarks/bench-qsort-common.h b/benchmarks/bench-qsort-common.h index 87fba479..bab70b85 100644 --- a/benchmarks/bench-qsort-common.h +++ b/benchmarks/bench-qsort-common.h @@ -7,6 +7,7 @@ #include "avx512-16bit-qsort.hpp" #include "avx512-32bit-qsort.hpp" #include "avx512-64bit-qsort.hpp" +#include "avx512-64bit-argsort.hpp" #define MY_BENCHMARK_CAPTURE(func, T, test_case_name, ...) \ BENCHMARK_PRIVATE_DECLARE(func) \ diff --git a/benchmarks/bench_argsort.hpp b/benchmarks/bench_argsort.hpp new file mode 100644 index 00000000..ede795a3 --- /dev/null +++ b/benchmarks/bench_argsort.hpp @@ -0,0 +1,102 @@ +#include "bench-qsort-common.h" + +template +std::vector stdargsort(const std::vector &array) { + std::vector indices(array.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), + [&array](int64_t left, int64_t right) -> bool { + // sort indices according to corresponding array element + return array[left] < array[right]; + }); + + return indices; +} + +#define MY_BENCHMARK_CAPTURE(func, T, test_case_name, ...) \ + BENCHMARK_PRIVATE_DECLARE(func) = \ + (::benchmark::internal::RegisterBenchmarkInternal( \ + new ::benchmark::internal::FunctionBenchmark( \ + #func "/" #test_case_name "/" #T, \ + [](::benchmark::State& st) { func(st, __VA_ARGS__); }))) + +template +static void stdargsort(benchmark::State& state, Args&&... args) { + auto args_tuple = std::make_tuple(std::move(args)...); + // Perform setup here + size_t ARRSIZE = std::get<0>(args_tuple); + std::vector arr; + std::vector inx; + + std::string arrtype = std::get<1>(args_tuple); + if (arrtype == "random") { + arr = get_uniform_rand_array(ARRSIZE); + } + else if (arrtype == "sorted") { + arr = get_uniform_rand_array(ARRSIZE); + std::sort(arr.begin(), arr.end()); + } + else if (arrtype == "constant") { + T temp = get_uniform_rand_array(1)[0]; + for (size_t ii = 0; ii < ARRSIZE; ++ii) { + arr.push_back(temp); + } + } + else if (arrtype == "reverse") { + arr = get_uniform_rand_array(ARRSIZE); + std::sort(arr.begin(), arr.end()); + std::reverse(arr.begin(), arr.end()); + } + + /* call avx512 quicksort */ + for (auto _ : state) { + inx = stdargsort(arr); + } +} + +template +static void avx512argsort(benchmark::State& state, Args&&... 
args) { + auto args_tuple = std::make_tuple(std::move(args)...); + if (!cpu_has_avx512bw()) { + state.SkipWithMessage("Requires AVX512 BW ISA"); + } + // Perform setup here + size_t ARRSIZE = std::get<0>(args_tuple); + std::vector arr; + std::vector inx; + + std::string arrtype = std::get<1>(args_tuple); + if (arrtype == "random") { + arr = get_uniform_rand_array(ARRSIZE); + } + else if (arrtype == "sorted") { + arr = get_uniform_rand_array(ARRSIZE); + std::sort(arr.begin(), arr.end()); + } + else if (arrtype == "constant") { + T temp = get_uniform_rand_array(1)[0]; + for (size_t ii = 0; ii < ARRSIZE; ++ii) { + arr.push_back(temp); + } + } + else if (arrtype == "reverse") { + arr = get_uniform_rand_array(ARRSIZE); + std::sort(arr.begin(), arr.end()); + std::reverse(arr.begin(), arr.end()); + } + + /* call avx512 quicksort */ + for (auto _ : state) { + inx = avx512_argsort(arr.data(), ARRSIZE); + } +} + +#define BENCH(func, type)\ +MY_BENCHMARK_CAPTURE(func, type, random_10000, 10000, std::string("random")); \ +MY_BENCHMARK_CAPTURE(func, type, random_100000, 100000, std::string("random")); \ +MY_BENCHMARK_CAPTURE(func, type, sorted_10000, 100000, std::string("sorted")); \ +MY_BENCHMARK_CAPTURE(func, type, constant_10000, 100000, std::string("constant")); \ +MY_BENCHMARK_CAPTURE(func, type, reverse_10000, 100000, std::string("reverse")); \ + +BENCH(avx512argsort, int64_t) +BENCH(stdargsort, int64_t) diff --git a/benchmarks/bench_qsort.cpp b/benchmarks/bench_qsort.cpp index d14bf218..f6710ee9 100644 --- a/benchmarks/bench_qsort.cpp +++ b/benchmarks/bench_qsort.cpp @@ -1,3 +1,4 @@ +#include "bench_argsort.hpp" #include "bench_qsort.hpp" #include "bench_qselect.hpp" #include "bench_partial_qsort.hpp" From 63142474b936222f97415d8553114330efc20c7c Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Sat, 29 Apr 2023 14:29:18 -0700 Subject: [PATCH 03/14] Add unit tests for argsort --- tests/meson.build | 5 +- tests/test_argsort.cpp | 163 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 tests/test_argsort.cpp diff --git a/tests/meson.build b/tests/meson.build index a69222ac..2c229654 100644 --- a/tests/meson.build +++ b/tests/meson.build @@ -2,7 +2,10 @@ libtests = [] if cpp.has_argument('-march=skylake-avx512') libtests += static_library('tests_kv', - files('test_keyvalue.cpp', ), + files( + 'test_keyvalue.cpp', + 'test_argsort.cpp', + ), dependencies: gtest_dep, include_directories : [src, utils], cpp_args : ['-O3', '-march=skylake-avx512'], diff --git a/tests/test_argsort.cpp b/tests/test_argsort.cpp new file mode 100644 index 00000000..a757fbbd --- /dev/null +++ b/tests/test_argsort.cpp @@ -0,0 +1,163 @@ +/******************************************* + * * Copyright (C) 2023 Intel Corporation + * * SPDX-License-Identifier: BSD-3-Clause + * *******************************************/ + +#include "avx512-64bit-argsort.hpp" +#include "cpuinfo.h" +#include "rand_array.h" +#include +#include +#include + +template +std::vector std_argsort(const std::vector &array) { + std::vector indices(array.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), + [&array](int left, int right) -> bool { + // sort indices according to corresponding array sizeent + return array[left] < array[right]; + }); + + return indices; +} + +TEST(avx512_argsort_64bit, test_random) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + 
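+        // The assertions below compare the values gathered through the two
+        // index arrays (sort1 vs sort2) rather than the indices themselves:
+        // argsort need not be stable, so std::sort and the AVX512 sorter may
+        // order equal keys differently while both orderings remain valid.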
std::vector arr; + for (auto & size : arrsizes) { + /* Random array */ + arr = get_uniform_rand_array(size); + std::vector inx1 = std_argsort(arr); + std::vector inx2 = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + ASSERT_EQ(sort1, sort2) << "Array size =" << size; + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TEST(avx512_argsort_64bit, test_constant) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto & size : arrsizes) { + /* constant array */ + auto elem = get_uniform_rand_array(1)[0]; + for (int64_t jj = 0; jj < size; ++jj) { + arr.push_back(elem); + } + std::vector inx1 = std_argsort(arr); + std::vector inx2 = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + ASSERT_EQ(sort1, sort2) << "Array size =" << size; + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TEST(avx512_argsort_64bit, test_small_range) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto & size : arrsizes) { + /* array with a smaller range of values */ + arr = get_uniform_rand_array(size, 20, 1); + std::vector inx1 = std_argsort(arr); + std::vector inx2 = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + ASSERT_EQ(sort1, sort2) << "Array size = " << size; + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TEST(avx512_argsort_64bit, test_sorted) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto & size : arrsizes) { + arr = get_uniform_rand_array(size); + std::sort(arr.begin(), arr.end()); + std::vector inx1 = std_argsort(arr); + std::vector inx2 = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + ASSERT_EQ(sort1, sort2) << "Array size =" << size; + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TEST(avx512_argsort_64bit, test_reverse) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto & size : arrsizes) { + arr = get_uniform_rand_array(size); + std::sort(arr.begin(), arr.end()); + std::reverse(arr.begin(), arr.end()); + std::vector inx1 = std_argsort(arr); + std::vector inx2 = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + ASSERT_EQ(sort1, sort2) << "Array size =" << size; + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} From 4550a081a9f31be83563f4bf06393feb0dfce251 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Sat, 29 Apr 2023 14:34:49 
-0700 Subject: [PATCH 04/14] Fix formatting --- benchmarks/bench-qsort-common.h | 8 +-- benchmarks/bench_argsort.hpp | 57 +++++++++++--------- benchmarks/bench_partial_qsort.hpp | 48 +++++++++++++---- benchmarks/bench_qselect.hpp | 6 ++- benchmarks/bench_qsort.cpp | 4 +- benchmarks/bench_qsortfp16.cpp | 42 +++++++++------ src/avx512-64bit-argsort.hpp | 79 +++++++++++++++------------- src/avx512-64bit-common.h | 3 +- src/avx512-common-argsort.h | 83 ++++++++++++++++-------------- tests/test_argsort.cpp | 35 ++++++++----- 10 files changed, 218 insertions(+), 147 deletions(-) diff --git a/benchmarks/bench-qsort-common.h b/benchmarks/bench-qsort-common.h index bab70b85..f7841772 100644 --- a/benchmarks/bench-qsort-common.h +++ b/benchmarks/bench-qsort-common.h @@ -1,13 +1,13 @@ #ifndef AVX512_BENCH_COMMON #define AVX512_BENCH_COMMON -#include -#include "rand_array.h" -#include "cpuinfo.h" #include "avx512-16bit-qsort.hpp" #include "avx512-32bit-qsort.hpp" -#include "avx512-64bit-qsort.hpp" #include "avx512-64bit-argsort.hpp" +#include "avx512-64bit-qsort.hpp" +#include "cpuinfo.h" +#include "rand_array.h" +#include #define MY_BENCHMARK_CAPTURE(func, T, test_case_name, ...) \ BENCHMARK_PRIVATE_DECLARE(func) \ diff --git a/benchmarks/bench_argsort.hpp b/benchmarks/bench_argsort.hpp index ede795a3..bb6c6da4 100644 --- a/benchmarks/bench_argsort.hpp +++ b/benchmarks/bench_argsort.hpp @@ -1,10 +1,12 @@ #include "bench-qsort-common.h" -template -std::vector stdargsort(const std::vector &array) { +template +std::vector stdargsort(const std::vector &array) +{ std::vector indices(array.size()); std::iota(indices.begin(), indices.end(), 0); - std::sort(indices.begin(), indices.end(), + std::sort(indices.begin(), + indices.end(), [&array](int64_t left, int64_t right) -> bool { // sort indices according to corresponding array element return array[left] < array[right]; @@ -13,15 +15,18 @@ std::vector stdargsort(const std::vector &array) { return indices; } -#define MY_BENCHMARK_CAPTURE(func, T, test_case_name, ...) \ - BENCHMARK_PRIVATE_DECLARE(func) = \ - (::benchmark::internal::RegisterBenchmarkInternal( \ - new ::benchmark::internal::FunctionBenchmark( \ - #func "/" #test_case_name "/" #T, \ - [](::benchmark::State& st) { func(st, __VA_ARGS__); }))) +#define MY_BENCHMARK_CAPTURE(func, T, test_case_name, ...) \ + BENCHMARK_PRIVATE_DECLARE(func) \ + = (::benchmark::internal::RegisterBenchmarkInternal( \ + new ::benchmark::internal::FunctionBenchmark( \ + #func "/" #test_case_name "/" #T, \ + [](::benchmark::State &st) { \ + func(st, __VA_ARGS__); \ + }))) -template -static void stdargsort(benchmark::State& state, Args&&... args) { +template +static void stdargsort(benchmark::State &state, Args &&...args) +{ auto args_tuple = std::make_tuple(std::move(args)...); // Perform setup here size_t ARRSIZE = std::get<0>(args_tuple); @@ -29,9 +34,7 @@ static void stdargsort(benchmark::State& state, Args&&... args) { std::vector inx; std::string arrtype = std::get<1>(args_tuple); - if (arrtype == "random") { - arr = get_uniform_rand_array(ARRSIZE); - } + if (arrtype == "random") { arr = get_uniform_rand_array(ARRSIZE); } else if (arrtype == "sorted") { arr = get_uniform_rand_array(ARRSIZE); std::sort(arr.begin(), arr.end()); @@ -54,8 +57,9 @@ static void stdargsort(benchmark::State& state, Args&&... args) { } } -template -static void avx512argsort(benchmark::State& state, Args&&... 
args) { +template +static void avx512argsort(benchmark::State &state, Args &&...args) +{ auto args_tuple = std::make_tuple(std::move(args)...); if (!cpu_has_avx512bw()) { state.SkipWithMessage("Requires AVX512 BW ISA"); @@ -66,9 +70,7 @@ static void avx512argsort(benchmark::State& state, Args&&... args) { std::vector inx; std::string arrtype = std::get<1>(args_tuple); - if (arrtype == "random") { - arr = get_uniform_rand_array(ARRSIZE); - } + if (arrtype == "random") { arr = get_uniform_rand_array(ARRSIZE); } else if (arrtype == "sorted") { arr = get_uniform_rand_array(ARRSIZE); std::sort(arr.begin(), arr.end()); @@ -91,12 +93,17 @@ static void avx512argsort(benchmark::State& state, Args&&... args) { } } -#define BENCH(func, type)\ -MY_BENCHMARK_CAPTURE(func, type, random_10000, 10000, std::string("random")); \ -MY_BENCHMARK_CAPTURE(func, type, random_100000, 100000, std::string("random")); \ -MY_BENCHMARK_CAPTURE(func, type, sorted_10000, 100000, std::string("sorted")); \ -MY_BENCHMARK_CAPTURE(func, type, constant_10000, 100000, std::string("constant")); \ -MY_BENCHMARK_CAPTURE(func, type, reverse_10000, 100000, std::string("reverse")); \ +#define BENCH(func, type) \ + MY_BENCHMARK_CAPTURE( \ + func, type, random_10000, 10000, std::string("random")); \ + MY_BENCHMARK_CAPTURE( \ + func, type, random_100000, 100000, std::string("random")); \ + MY_BENCHMARK_CAPTURE( \ + func, type, sorted_10000, 100000, std::string("sorted")); \ + MY_BENCHMARK_CAPTURE( \ + func, type, constant_10000, 100000, std::string("constant")); \ + MY_BENCHMARK_CAPTURE( \ + func, type, reverse_10000, 100000, std::string("reverse")); BENCH(avx512argsort, int64_t) BENCH(stdargsort, int64_t) diff --git a/benchmarks/bench_partial_qsort.hpp b/benchmarks/bench_partial_qsort.hpp index d54ceb31..2560107d 100644 --- a/benchmarks/bench_partial_qsort.hpp +++ b/benchmarks/bench_partial_qsort.hpp @@ -1,7 +1,8 @@ #include "bench-qsort-common.h" template -static void avx512_partial_qsort(benchmark::State& state) { +static void avx512_partial_qsort(benchmark::State &state) +{ if (!cpu_has_avx512bw()) { state.SkipWithMessage("Requires AVX512 BW ISA"); } @@ -29,7 +30,8 @@ static void avx512_partial_qsort(benchmark::State& state) { } template -static void stdpartialsort(benchmark::State& state) { +static void stdpartialsort(benchmark::State &state) +{ // Perform setup here int64_t K = state.range(0); size_t ARRSIZE = 10000; @@ -53,20 +55,48 @@ static void stdpartialsort(benchmark::State& state) { // Register the function as a benchmark BENCHMARK(avx512_partial_qsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); BENCHMARK(stdpartialsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); -BENCHMARK(avx512_partial_qsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); +BENCHMARK(avx512_partial_qsort) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(5000); BENCHMARK(stdpartialsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); -BENCHMARK(avx512_partial_qsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); +BENCHMARK(avx512_partial_qsort) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(5000); BENCHMARK(stdpartialsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); -BENCHMARK(avx512_partial_qsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); +BENCHMARK(avx512_partial_qsort) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(5000); BENCHMARK(stdpartialsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); -BENCHMARK(avx512_partial_qsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); +BENCHMARK(avx512_partial_qsort) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(5000); 
BENCHMARK(stdpartialsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); -BENCHMARK(avx512_partial_qsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); +BENCHMARK(avx512_partial_qsort) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(5000); BENCHMARK(stdpartialsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); //BENCHMARK(avx512_partial_qsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); -BENCHMARK(avx512_partial_qsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); +BENCHMARK(avx512_partial_qsort) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(5000); BENCHMARK(stdpartialsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); -BENCHMARK(avx512_partial_qsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); +BENCHMARK(avx512_partial_qsort) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(5000); BENCHMARK(stdpartialsort)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); diff --git a/benchmarks/bench_qselect.hpp b/benchmarks/bench_qselect.hpp index fea5bea4..0f6ad8c7 100644 --- a/benchmarks/bench_qselect.hpp +++ b/benchmarks/bench_qselect.hpp @@ -1,7 +1,8 @@ #include "bench-qsort-common.h" template -static void avx512_qselect(benchmark::State& state) { +static void avx512_qselect(benchmark::State &state) +{ if (!cpu_has_avx512bw()) { state.SkipWithMessage("Requires AVX512 BW ISA"); } @@ -29,7 +30,8 @@ static void avx512_qselect(benchmark::State& state) { } template -static void stdnthelement(benchmark::State& state) { +static void stdnthelement(benchmark::State &state) +{ // Perform setup here int64_t K = state.range(0); size_t ARRSIZE = 10000; diff --git a/benchmarks/bench_qsort.cpp b/benchmarks/bench_qsort.cpp index f6710ee9..f3a8a033 100644 --- a/benchmarks/bench_qsort.cpp +++ b/benchmarks/bench_qsort.cpp @@ -1,4 +1,4 @@ -#include "bench_argsort.hpp" #include "bench_qsort.hpp" -#include "bench_qselect.hpp" +#include "bench_argsort.hpp" #include "bench_partial_qsort.hpp" +#include "bench_qselect.hpp" diff --git a/benchmarks/bench_qsortfp16.cpp b/benchmarks/bench_qsortfp16.cpp index eddd876d..9a90d9d6 100644 --- a/benchmarks/bench_qsortfp16.cpp +++ b/benchmarks/bench_qsortfp16.cpp @@ -1,10 +1,11 @@ -#include -#include "rand_array.h" -#include "cpuinfo.h" #include "avx512fp16-16bit-qsort.hpp" +#include "cpuinfo.h" +#include "rand_array.h" +#include template -static void avx512_qsort(benchmark::State& state) { +static void avx512_qsort(benchmark::State &state) +{ if (cpu_has_avx512fp16()) { // Perform setup here size_t ARRSIZE = state.range(0); @@ -13,7 +14,7 @@ static void avx512_qsort(benchmark::State& state) { /* Initialize elements */ for (size_t jj = 0; jj < ARRSIZE; ++jj) { - _Float16 temp = (float) rand() / (float)(RAND_MAX); + _Float16 temp = (float)rand() / (float)(RAND_MAX); arr.push_back(temp); } arr_bkp = arr; @@ -32,7 +33,8 @@ static void avx512_qsort(benchmark::State& state) { } template -static void stdsort(benchmark::State& state) { +static void stdsort(benchmark::State &state) +{ if (cpu_has_avx512fp16()) { // Perform setup here size_t ARRSIZE = state.range(0); @@ -40,7 +42,7 @@ static void stdsort(benchmark::State& state) { std::vector arr_bkp; for (size_t jj = 0; jj < ARRSIZE; ++jj) { - _Float16 temp = (float) rand() / (float)(RAND_MAX); + _Float16 temp = (float)rand() / (float)(RAND_MAX); arr.push_back(temp); } arr_bkp = arr; @@ -63,7 +65,8 @@ BENCHMARK(avx512_qsort<_Float16>)->Arg(10000)->Arg(1000000); BENCHMARK(stdsort<_Float16>)->Arg(10000)->Arg(1000000); template -static void avx512_qselect(benchmark::State& state) { +static void avx512_qselect(benchmark::State &state) +{ if (cpu_has_avx512fp16()) { // Perform setup 
here int64_t K = state.range(0); @@ -73,7 +76,7 @@ static void avx512_qselect(benchmark::State& state) { /* Initialize elements */ for (size_t jj = 0; jj < ARRSIZE; ++jj) { - _Float16 temp = (float) rand() / (float)(RAND_MAX); + _Float16 temp = (float)rand() / (float)(RAND_MAX); arr.push_back(temp); } arr_bkp = arr; @@ -93,7 +96,8 @@ static void avx512_qselect(benchmark::State& state) { } template -static void stdnthelement(benchmark::State& state) { +static void stdnthelement(benchmark::State &state) +{ if (cpu_has_avx512fp16()) { // Perform setup here int64_t K = state.range(0); @@ -103,7 +107,7 @@ static void stdnthelement(benchmark::State& state) { /* Initialize elements */ for (size_t jj = 0; jj < ARRSIZE; ++jj) { - _Float16 temp = (float) rand() / (float)(RAND_MAX); + _Float16 temp = (float)rand() / (float)(RAND_MAX); arr.push_back(temp); } arr_bkp = arr; @@ -127,7 +131,8 @@ BENCHMARK(avx512_qselect<_Float16>)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); BENCHMARK(stdnthelement<_Float16>)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); template -static void avx512_partial_qsort(benchmark::State& state) { +static void avx512_partial_qsort(benchmark::State &state) +{ if (cpu_has_avx512fp16()) { // Perform setup here int64_t K = state.range(0); @@ -137,7 +142,7 @@ static void avx512_partial_qsort(benchmark::State& state) { /* Initialize elements */ for (size_t jj = 0; jj < ARRSIZE; ++jj) { - _Float16 temp = (float) rand() / (float)(RAND_MAX); + _Float16 temp = (float)rand() / (float)(RAND_MAX); arr.push_back(temp); } arr_bkp = arr; @@ -157,7 +162,8 @@ static void avx512_partial_qsort(benchmark::State& state) { } template -static void stdpartialsort(benchmark::State& state) { +static void stdpartialsort(benchmark::State &state) +{ if (cpu_has_avx512fp16()) { // Perform setup here int64_t K = state.range(0); @@ -167,7 +173,7 @@ static void stdpartialsort(benchmark::State& state) { /* Initialize elements */ for (size_t jj = 0; jj < ARRSIZE; ++jj) { - _Float16 temp = (float) rand() / (float)(RAND_MAX); + _Float16 temp = (float)rand() / (float)(RAND_MAX); arr.push_back(temp); } arr_bkp = arr; @@ -187,5 +193,9 @@ static void stdpartialsort(benchmark::State& state) { } // Register the function as a benchmark -BENCHMARK(avx512_partial_qsort<_Float16>)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); +BENCHMARK(avx512_partial_qsort<_Float16>) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(5000); BENCHMARK(stdpartialsort<_Float16>)->Arg(10)->Arg(100)->Arg(1000)->Arg(5000); diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index cfea9fbe..6a3e458c 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -7,13 +7,15 @@ #ifndef AVX512_ARGSORT_64BIT #define AVX512_ARGSORT_64BIT -#include "avx512-common-argsort.h" #include "avx512-64bit-keyvalue-networks.hpp" +#include "avx512-common-argsort.h" /* argsort using std::sort */ -template -void std_argsort(T* arr, int64_t* arg, int64_t left, int64_t right) { - std::sort(arg + left, arg + right, +template +void std_argsort(T *arr, int64_t *arg, int64_t left, int64_t right) +{ + std::sort(arg + left, + arg + right, [arr](int64_t left, int64_t right) -> bool { // sort indices according to corresponding array element return arr[left] < arr[right]; @@ -21,21 +23,19 @@ void std_argsort(T* arr, int64_t* arg, int64_t left, int64_t right) { } template -X86_SIMD_SORT_INLINE void -argsort_8_64bit(type_t *arr, int64_t* arg, int32_t N) +X86_SIMD_SORT_INLINE void argsort_8_64bit(type_t *arr, int64_t *arg, int32_t N) { using zmm_t = typename 
vtype::zmm_t; typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; argzmm_t argzmm = argtype::maskz_loadu(load_mask, arg); - zmm_t arrzmm - = vtype::template mask_i64gather(vtype::zmm_max(), load_mask, argzmm, arr); + zmm_t arrzmm = vtype::template mask_i64gather( + vtype::zmm_max(), load_mask, argzmm, arr); arrzmm = sort_zmm_64bit(arrzmm, argzmm); vtype::mask_storeu(arg, load_mask, argzmm); } template -X86_SIMD_SORT_INLINE void -argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N) +X86_SIMD_SORT_INLINE void argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N) { if (N <= 8) { argsort_8_64bit(arr, arg, N); @@ -46,17 +46,18 @@ argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N) argzmm_t argzmm1 = argtype::loadu(arg); argzmm_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8); zmm_t arrzmm1 = vtype::template i64gather(argzmm1, arr); - zmm_t arrzmm2 = vtype::template mask_i64gather(vtype::zmm_max(), load_mask, argzmm2, arr); + zmm_t arrzmm2 = vtype::template mask_i64gather( + vtype::zmm_max(), load_mask, argzmm2, arr); arrzmm1 = sort_zmm_64bit(arrzmm1, argzmm1); arrzmm2 = sort_zmm_64bit(arrzmm2, argzmm2); - bitonic_merge_two_zmm_64bit(arrzmm1, arrzmm2, argzmm1, argzmm2); + bitonic_merge_two_zmm_64bit( + arrzmm1, arrzmm2, argzmm1, argzmm2); argtype::storeu(arg, argzmm1); argtype::mask_storeu(arg + 8, load_mask, argzmm2); } template -X86_SIMD_SORT_INLINE void -argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N) +X86_SIMD_SORT_INLINE void argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N) { if (N <= 16) { argsort_16_64bit(arr, arg, N); @@ -69,7 +70,7 @@ argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N) #pragma GCC unroll 2 for (int ii = 0; ii < 2; ++ii) { - argzmm[ii] = argtype::loadu(arg + 8*ii); + argzmm[ii] = argtype::loadu(arg + 8 * ii); arrzmm[ii] = vtype::template i64gather(argzmm[ii], arr); arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); } @@ -78,10 +79,12 @@ argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N) opmask_t load_mask[2] = {0xFF, 0xFF}; #pragma GCC unroll 2 for (int ii = 0; ii < 2; ++ii) { - load_mask[ii] = (combined_mask >> (ii*8)) & 0xFF; - argzmm[ii+2] = argtype::maskz_loadu(load_mask[ii], arg + 16 + 8*ii); - arrzmm[ii+2] = vtype::template mask_i64gather(vtype::zmm_max(), load_mask[ii], argzmm[ii+2], arr); - arrzmm[ii+2] = sort_zmm_64bit(arrzmm[ii+2], argzmm[ii+2]); + load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF; + argzmm[ii + 2] = argtype::maskz_loadu(load_mask[ii], arg + 16 + 8 * ii); + arrzmm[ii + 2] = vtype::template mask_i64gather( + vtype::zmm_max(), load_mask[ii], argzmm[ii + 2], arr); + arrzmm[ii + 2] = sort_zmm_64bit(arrzmm[ii + 2], + argzmm[ii + 2]); } bitonic_merge_two_zmm_64bit( @@ -97,8 +100,7 @@ argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N) } template -X86_SIMD_SORT_INLINE void -argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N) +X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N) { if (N <= 32) { argsort_32_64bit(arr, arg, N); @@ -111,7 +113,7 @@ argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N) #pragma GCC unroll 4 for (int ii = 0; ii < 4; ++ii) { - argzmm[ii] = argtype::loadu(arg + 8*ii); + argzmm[ii] = argtype::loadu(arg + 8 * ii); arrzmm[ii] = vtype::template i64gather(argzmm[ii], arr); arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); } @@ -120,15 +122,18 @@ argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N) uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; #pragma GCC unroll 4 for (int ii = 0; ii < 4; ++ii) { - load_mask[ii] = (combined_mask >> (ii*8)) & 
0xFF; - argzmm[ii+4] = argtype::maskz_loadu(load_mask[ii], arg + 32 + 8*ii); - arrzmm[ii+4] = vtype::template mask_i64gather(vtype::zmm_max(), load_mask[ii], argzmm[ii+4], arr); - arrzmm[ii+4] = sort_zmm_64bit(arrzmm[ii+4], argzmm[ii+4]); + load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF; + argzmm[ii + 4] = argtype::maskz_loadu(load_mask[ii], arg + 32 + 8 * ii); + arrzmm[ii + 4] = vtype::template mask_i64gather( + vtype::zmm_max(), load_mask[ii], argzmm[ii + 4], arr); + arrzmm[ii + 4] = sort_zmm_64bit(arrzmm[ii + 4], + argzmm[ii + 4]); } #pragma GCC unroll 4 for (int ii = 0; ii < 8; ii = ii + 2) { - bitonic_merge_two_zmm_64bit(arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); + bitonic_merge_two_zmm_64bit( + arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); } bitonic_merge_four_zmm_64bit(arrzmm, argzmm); bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); @@ -136,11 +141,11 @@ argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N) #pragma GCC unroll 4 for (int ii = 0; ii < 4; ++ii) { - argtype::storeu(arg + 8*ii, argzmm[ii]); + argtype::storeu(arg + 8 * ii, argzmm[ii]); } #pragma GCC unroll 4 for (int ii = 0; ii < 4; ++ii) { - argtype::mask_storeu(arg + 32 + 8*ii, load_mask[ii], argzmm[ii + 4]); + argtype::mask_storeu(arg + 32 + 8 * ii, load_mask[ii], argzmm[ii + 4]); } } @@ -204,7 +209,7 @@ argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N) template type_t get_pivot_64bit(type_t *arr, - int64_t* arg, + int64_t *arg, const int64_t left, const int64_t right) { @@ -221,7 +226,8 @@ type_t get_pivot_64bit(type_t *arr, arg[left + 6 * size], arg[left + 7 * size], arg[left + 8 * size]); - zmm_t rand_vec = vtype::template i64gather(rand_index, arr); + zmm_t rand_vec + = vtype::template i64gather(rand_index, arr); // pivot will never be a nan, since there are no nan's! 
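+        // Median-of-8 pivot: the eight equally spaced samples gathered above
+        // are sorted inside a single ZMM register and the middle element is
+        // taken as the pivot, so no scalar comparisons are needed here.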
zmm_t sort = sort_zmm_64bit(rand_vec); return ((type_t *)&sort)[4]; @@ -232,8 +238,11 @@ type_t get_pivot_64bit(type_t *arr, } template -inline void -argsort_64bit_(type_t *arr, int64_t* arg, int64_t left, int64_t right, int64_t max_iters) +inline void argsort_64bit_(type_t *arr, + int64_t *arg, + int64_t left, + int64_t right, + int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress @@ -252,7 +261,7 @@ argsort_64bit_(type_t *arr, int64_t* arg, int64_t left, int64_t right, int64_t m type_t pivot = get_pivot_64bit(arr, arg, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512_unrolled( + int64_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); @@ -261,7 +270,7 @@ argsort_64bit_(type_t *arr, int64_t* arg, int64_t left, int64_t right, int64_t m } template <> -void avx512_argsort(int64_t *arr, int64_t* arg, int64_t arrsize) +void avx512_argsort(int64_t *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { argsort_64bit_, int64_t>( diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 1291043b..7fa3bb03 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -73,7 +73,8 @@ struct zmm_vector { return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); } template - static zmm_t mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + static zmm_t + mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) { return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); } diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h index 704b0a5a..ceabf5ac 100644 --- a/src/avx512-common-argsort.h +++ b/src/avx512-common-argsort.h @@ -8,15 +8,15 @@ #define AVX512_ARGSORT_COMMON #include "avx512-64bit-common.h" -#include #include #include +#include using argtype = zmm_vector; using argzmm_t = typename argtype::zmm_t; template -void avx512_argsort(T *arr, int64_t* arg, int64_t arrsize); +void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize); template std::vector avx512_argsort(T *arr, int64_t arrsize); @@ -35,17 +35,13 @@ std::vector avx512_argsort(T *arr, int64_t arrsize); template -static inline zmm_t cmp_merge(zmm_t in1, - zmm_t in2, - argzmm_t& arg1, - argzmm_t arg2, - opmask_t mask) +static inline zmm_t +cmp_merge(zmm_t in1, zmm_t in2, argzmm_t &arg1, argzmm_t arg2, opmask_t mask) { typename vtype::opmask_t le_mask = vtype::le(in1, in2); opmask_t temp = vtype::kxor_opmask(le_mask, mask); arg1 = vtype::mask_mov(arg2, temp, arg1); // 0 -> min, 1 -> max return vtype::mask_mov(in2, temp, in1); // 0 -> min, 1 -> max - } /* * Parition one ZMM register based on the pivot and returns the index of the @@ -66,8 +62,7 @@ static inline int32_t partition_vec(type_t *arg, int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); vtype::mask_compressstoreu( arg + left, vtype::knot_opmask(gt_mask), arg_vec); - vtype::mask_compressstoreu( - arg + right - amount_gt_pivot, gt_mask, arg_vec); + vtype::mask_compressstoreu(arg + right - amount_gt_pivot, gt_mask, arg_vec); *smallest_vec = vtype::min(curr_vec, *smallest_vec); *biggest_vec = vtype::max(curr_vec, *biggest_vec); return amount_gt_pivot; @@ -78,7 +73,7 @@ static inline int32_t partition_vec(type_t *arg, */ template static inline int64_t partition_avx512(type_t *arr, - int64_t* arg, + int64_t *arg, int64_t left, int64_t right, type_t 
pivot, @@ -123,9 +118,11 @@ static inline int64_t partition_avx512(type_t *arr, // first and last vtype::numlanes values are partitioned at the end argzmm_t argvec_left = argtype::loadu(arg + left); - zmm_t vec_left = vtype::template i64gather(argvec_left, arr); + zmm_t vec_left + = vtype::template i64gather(argvec_left, arr); argzmm_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); - zmm_t vec_right = vtype::template i64gather(argvec_right, arr); + zmm_t vec_right + = vtype::template i64gather(argvec_right, arr); // store points of the vectors int64_t r_store = right - vtype::numlanes; int64_t l_store = left; @@ -192,7 +189,7 @@ template static inline int64_t partition_avx512_unrolled(type_t *arr, - int64_t* arg, + int64_t *arg, int64_t left, int64_t right, type_t pivot, @@ -204,7 +201,8 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, arr, arg, left, right, pivot, smallest, biggest); } /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = ((right - left) % (num_unroll*vtype::numlanes)); i > 0; --i) { + for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; + --i) { *smallest = std::min(*smallest, arr[arg[left]], comparison_func); *biggest = std::max(*biggest, arr[arg[left]], comparison_func); if (!comparison_func(arr[arg[left]], pivot)) { @@ -228,10 +226,13 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, argzmm_t argvec_left[num_unroll], argvec_right[num_unroll]; #pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes*ii); - vec_left[ii] = vtype::template i64gather(argvec_left[ii], arr); - argvec_right[ii] = argtype::loadu(arg + (right - vtype::numlanes*(num_unroll-ii))); - vec_right[ii] = vtype::template i64gather(argvec_right[ii], arr); + argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); + vec_left[ii] = vtype::template i64gather( + argvec_left[ii], arr); + argvec_right[ii] = argtype::loadu( + arg + (right - vtype::numlanes * (num_unroll - ii))); + vec_right[ii] = vtype::template i64gather( + argvec_right[ii], arr); } // store points of the vectors int64_t r_store = right - vtype::numlanes; @@ -251,15 +252,17 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, right -= num_unroll * vtype::numlanes; #pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] = vtype::loadu(arg + right + ii*vtype::numlanes); - curr_vec[ii] = vtype::template i64gather(arg_vec[ii], arr); + arg_vec[ii] = vtype::loadu(arg + right + ii * vtype::numlanes); + curr_vec[ii] = vtype::template i64gather( + arg_vec[ii], arr); } } else { #pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] = vtype::loadu(arg + left + ii*vtype::numlanes); - curr_vec[ii] = vtype::template i64gather(arg_vec[ii], arr); + arg_vec[ii] = vtype::loadu(arg + left + ii * vtype::numlanes); + curr_vec[ii] = vtype::template i64gather( + arg_vec[ii], arr); } left += num_unroll * vtype::numlanes; } @@ -283,27 +286,29 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, /* partition and save vec_left and vec_right */ #pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left[ii], - vec_left[ii], - pivot_vec, - &min_vec, - &max_vec); + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left[ii], + vec_left[ii], + pivot_vec, + &min_vec, + 
&max_vec); l_store += (vtype::numlanes - amount_gt_pivot); r_store -= amount_gt_pivot; } #pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_right[ii], - vec_right[ii], - pivot_vec, - &min_vec, - &max_vec); + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_right[ii], + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); r_store -= amount_gt_pivot; } diff --git a/tests/test_argsort.cpp b/tests/test_argsort.cpp index a757fbbd..b74d9550 100644 --- a/tests/test_argsort.cpp +++ b/tests/test_argsort.cpp @@ -6,15 +6,17 @@ #include "avx512-64bit-argsort.hpp" #include "cpuinfo.h" #include "rand_array.h" +#include #include #include -#include -template -std::vector std_argsort(const std::vector &array) { +template +std::vector std_argsort(const std::vector &array) +{ std::vector indices(array.size()); std::iota(indices.begin(), indices.end(), 0); - std::sort(indices.begin(), indices.end(), + std::sort(indices.begin(), + indices.end(), [&array](int left, int right) -> bool { // sort indices according to corresponding array sizeent return array[left] < array[right]; @@ -31,11 +33,12 @@ TEST(avx512_argsort_64bit, test_random) arrsizes.push_back(ii); } std::vector arr; - for (auto & size : arrsizes) { + for (auto &size : arrsizes) { /* Random array */ arr = get_uniform_rand_array(size); std::vector inx1 = std_argsort(arr); - std::vector inx2 = avx512_argsort(arr.data(), arr.size()); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); std::vector sort1, sort2; for (size_t jj = 0; jj < size; ++jj) { sort1.push_back(arr[inx1[jj]]); @@ -58,14 +61,15 @@ TEST(avx512_argsort_64bit, test_constant) arrsizes.push_back(ii); } std::vector arr; - for (auto & size : arrsizes) { + for (auto &size : arrsizes) { /* constant array */ auto elem = get_uniform_rand_array(1)[0]; for (int64_t jj = 0; jj < size; ++jj) { arr.push_back(elem); } std::vector inx1 = std_argsort(arr); - std::vector inx2 = avx512_argsort(arr.data(), arr.size()); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); std::vector sort1, sort2; for (size_t jj = 0; jj < size; ++jj) { sort1.push_back(arr[inx1[jj]]); @@ -88,11 +92,12 @@ TEST(avx512_argsort_64bit, test_small_range) arrsizes.push_back(ii); } std::vector arr; - for (auto & size : arrsizes) { + for (auto &size : arrsizes) { /* array with a smaller range of values */ arr = get_uniform_rand_array(size, 20, 1); std::vector inx1 = std_argsort(arr); - std::vector inx2 = avx512_argsort(arr.data(), arr.size()); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); std::vector sort1, sort2; for (size_t jj = 0; jj < size; ++jj) { sort1.push_back(arr[inx1[jj]]); @@ -115,11 +120,12 @@ TEST(avx512_argsort_64bit, test_sorted) arrsizes.push_back(ii); } std::vector arr; - for (auto & size : arrsizes) { + for (auto &size : arrsizes) { arr = get_uniform_rand_array(size); std::sort(arr.begin(), arr.end()); std::vector inx1 = std_argsort(arr); - std::vector inx2 = avx512_argsort(arr.data(), arr.size()); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); std::vector sort1, sort2; for (size_t jj = 0; jj < size; ++jj) { sort1.push_back(arr[inx1[jj]]); @@ -142,12 +148,13 @@ TEST(avx512_argsort_64bit, test_reverse) arrsizes.push_back(ii); } std::vector arr; - for (auto & size : arrsizes) { + for (auto &size : arrsizes) { arr = get_uniform_rand_array(size); std::sort(arr.begin(), 
arr.end()); std::reverse(arr.begin(), arr.end()); std::vector inx1 = std_argsort(arr); - std::vector inx2 = avx512_argsort(arr.data(), arr.size()); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); std::vector sort1, sort2; for (size_t jj = 0; jj < size; ++jj) { sort1.push_back(arr[inx1[jj]]); From 6cd93a46e1c23a76fd5bb3dc92c95600c123f121 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 1 May 2023 09:38:30 -0700 Subject: [PATCH 05/14] Add missing header file --- src/avx512-64bit-argsort.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 6a3e458c..514972f8 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -7,8 +7,9 @@ #ifndef AVX512_ARGSORT_64BIT #define AVX512_ARGSORT_64BIT -#include "avx512-64bit-keyvalue-networks.hpp" +#include "avx512-64bit-common.h" #include "avx512-common-argsort.h" +#include "avx512-64bit-keyvalue-networks.hpp" /* argsort using std::sort */ template From 8b3c862ff085b161f0aed250f80eb07654099ada Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 1 May 2023 11:41:41 -0700 Subject: [PATCH 06/14] Re-write qsort benchmarks --- benchmarks/bench_argsort.hpp | 21 --------------------- benchmarks/bench_qsort.hpp | 18 +++++++++--------- 2 files changed, 9 insertions(+), 30 deletions(-) diff --git a/benchmarks/bench_argsort.hpp b/benchmarks/bench_argsort.hpp index bb6c6da4..cd511ab0 100644 --- a/benchmarks/bench_argsort.hpp +++ b/benchmarks/bench_argsort.hpp @@ -15,15 +15,6 @@ std::vector stdargsort(const std::vector &array) return indices; } -#define MY_BENCHMARK_CAPTURE(func, T, test_case_name, ...) \ - BENCHMARK_PRIVATE_DECLARE(func) \ - = (::benchmark::internal::RegisterBenchmarkInternal( \ - new ::benchmark::internal::FunctionBenchmark( \ - #func "/" #test_case_name "/" #T, \ - [](::benchmark::State &st) { \ - func(st, __VA_ARGS__); \ - }))) - template static void stdargsort(benchmark::State &state, Args &&...args) { @@ -93,17 +84,5 @@ static void avx512argsort(benchmark::State &state, Args &&...args) } } -#define BENCH(func, type) \ - MY_BENCHMARK_CAPTURE( \ - func, type, random_10000, 10000, std::string("random")); \ - MY_BENCHMARK_CAPTURE( \ - func, type, random_100000, 100000, std::string("random")); \ - MY_BENCHMARK_CAPTURE( \ - func, type, sorted_10000, 100000, std::string("sorted")); \ - MY_BENCHMARK_CAPTURE( \ - func, type, constant_10000, 100000, std::string("constant")); \ - MY_BENCHMARK_CAPTURE( \ - func, type, reverse_10000, 100000, std::string("reverse")); - BENCH(avx512argsort, int64_t) BENCH(stdargsort, int64_t) diff --git a/benchmarks/bench_qsort.hpp b/benchmarks/bench_qsort.hpp index 0f9c3c48..c301d6b8 100644 --- a/benchmarks/bench_qsort.hpp +++ b/benchmarks/bench_qsort.hpp @@ -80,15 +80,15 @@ static void avx512qsort(benchmark::State &state, Args &&...args) } } -#define BENCH_ALL(type)\ +#define BENCH_BOTH(type)\ BENCH(avx512qsort, type)\ BENCH(stdsort, type) -BENCH_ALL(uint64_t) -BENCH_ALL(int64_t) -BENCH_ALL(uint32_t) -BENCH_ALL(int32_t) -BENCH_ALL(uint16_t) -BENCH_ALL(int16_t) -BENCH_ALL(float) -BENCH_ALL(double) +BENCH_BOTH(uint64_t) +BENCH_BOTH(int64_t) +BENCH_BOTH(uint32_t) +BENCH_BOTH(int32_t) +BENCH_BOTH(uint16_t) +BENCH_BOTH(int16_t) +BENCH_BOTH(float) +BENCH_BOTH(double) From c4bb73485754d11beed5e9c58c886800f4cfb180 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 1 May 2023 12:41:25 -0700 Subject: [PATCH 07/14] Add tests for uint64_t and double argsort --- 
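A minimal sketch (not part of the diff below, wrapper name is illustrative) of the call pattern the reworked typed tests exercise, assuming the double/uint64_t specializations of avx512_argsort that the next patch in this series adds; get_uniform_rand_array comes from rand_array.h:

    #include "avx512-64bit-argsort.hpp"
    #include "rand_array.h"
    #include <cassert>
    #include <cstdint>
    #include <vector>

    static void check_argsort_double()
    {
        std::vector<double> arr = get_uniform_rand_array<double>(128);
        // indices that order arr ascending, computed by the AVX512 kernel
        std::vector<int64_t> idx = avx512_argsort<double>(arr.data(), arr.size());
        for (size_t jj = 1; jj < arr.size(); ++jj) {
            assert(arr[idx[jj - 1]] <= arr[idx[jj]]);
        }
    }

The tests themselves additionally cross-check the gathered values against a scalar std::sort based argsort.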
tests/test_argsort.cpp | 68 ++++++++++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/tests/test_argsort.cpp b/tests/test_argsort.cpp index b74d9550..66d9a322 100644 --- a/tests/test_argsort.cpp +++ b/tests/test_argsort.cpp @@ -10,6 +10,11 @@ #include #include +template +class avx512argsort : public ::testing::Test { +}; +TYPED_TEST_SUITE_P(avx512argsort); + template std::vector std_argsort(const std::vector &array) { @@ -25,21 +30,21 @@ std::vector std_argsort(const std::vector &array) return indices; } -TEST(avx512_argsort_64bit, test_random) +TYPED_TEST_P(avx512argsort, test_random) { if (cpu_has_avx512bw()) { std::vector arrsizes; for (int64_t ii = 0; ii <= 1024; ++ii) { arrsizes.push_back(ii); } - std::vector arr; + std::vector arr; for (auto &size : arrsizes) { /* Random array */ - arr = get_uniform_rand_array(size); + arr = get_uniform_rand_array(size); std::vector inx1 = std_argsort(arr); std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; for (size_t jj = 0; jj < size; ++jj) { sort1.push_back(arr[inx1[jj]]); sort2.push_back(arr[inx2[jj]]); @@ -53,24 +58,24 @@ TEST(avx512_argsort_64bit, test_random) } } -TEST(avx512_argsort_64bit, test_constant) +TYPED_TEST_P(avx512argsort, test_constant) { if (cpu_has_avx512bw()) { std::vector arrsizes; for (int64_t ii = 0; ii <= 1024; ++ii) { arrsizes.push_back(ii); } - std::vector arr; + std::vector arr; for (auto &size : arrsizes) { /* constant array */ - auto elem = get_uniform_rand_array(1)[0]; + auto elem = get_uniform_rand_array(1)[0]; for (int64_t jj = 0; jj < size; ++jj) { arr.push_back(elem); } std::vector inx1 = std_argsort(arr); std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; for (size_t jj = 0; jj < size; ++jj) { sort1.push_back(arr[inx1[jj]]); sort2.push_back(arr[inx2[jj]]); @@ -84,21 +89,21 @@ TEST(avx512_argsort_64bit, test_constant) } } -TEST(avx512_argsort_64bit, test_small_range) +TYPED_TEST_P(avx512argsort, test_small_range) { if (cpu_has_avx512bw()) { std::vector arrsizes; for (int64_t ii = 0; ii <= 1024; ++ii) { arrsizes.push_back(ii); } - std::vector arr; + std::vector arr; for (auto &size : arrsizes) { /* array with a smaller range of values */ - arr = get_uniform_rand_array(size, 20, 1); + arr = get_uniform_rand_array(size, 20, 1); std::vector inx1 = std_argsort(arr); std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; for (size_t jj = 0; jj < size; ++jj) { sort1.push_back(arr[inx1[jj]]); sort2.push_back(arr[inx2[jj]]); @@ -112,21 +117,21 @@ TEST(avx512_argsort_64bit, test_small_range) } } -TEST(avx512_argsort_64bit, test_sorted) +TYPED_TEST_P(avx512argsort, test_sorted) { if (cpu_has_avx512bw()) { std::vector arrsizes; for (int64_t ii = 0; ii <= 1024; ++ii) { arrsizes.push_back(ii); } - std::vector arr; + std::vector arr; for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); + arr = get_uniform_rand_array(size); std::sort(arr.begin(), arr.end()); std::vector inx1 = std_argsort(arr); std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; for (size_t jj = 0; jj < size; ++jj) { sort1.push_back(arr[inx1[jj]]); 
sort2.push_back(arr[inx2[jj]]); @@ -140,22 +145,22 @@ TEST(avx512_argsort_64bit, test_sorted) } } -TEST(avx512_argsort_64bit, test_reverse) +TYPED_TEST_P(avx512argsort, test_reverse) { if (cpu_has_avx512bw()) { std::vector arrsizes; for (int64_t ii = 0; ii <= 1024; ++ii) { arrsizes.push_back(ii); } - std::vector arr; + std::vector arr; for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); + arr = get_uniform_rand_array(size); std::sort(arr.begin(), arr.end()); std::reverse(arr.begin(), arr.end()); std::vector inx1 = std_argsort(arr); std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; for (size_t jj = 0; jj < size; ++jj) { sort1.push_back(arr[inx1[jj]]); sort2.push_back(arr[inx2[jj]]); @@ -168,3 +173,16 @@ TEST(avx512_argsort_64bit, test_reverse) GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; } } + +REGISTER_TYPED_TEST_SUITE_P(avx512argsort, + test_random, + test_reverse, + test_constant, + test_sorted, + test_small_range); + +using ArgSortTestTypes = testing::Types; + +INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512argsort, ArgSortTestTypes); From 469a4e439caaa6789a3ef0f83a75b3c104b43773 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 1 May 2023 12:42:08 -0700 Subject: [PATCH 08/14] Add argsort for uint64_t and double --- src/avx512-64bit-argsort.hpp | 38 +++++++++++++++++++++++++++++++++++- src/avx512-64bit-common.h | 13 +++++++++++- src/avx512-common-argsort.h | 11 ++++++----- 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 514972f8..5c21393b 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -32,7 +32,7 @@ X86_SIMD_SORT_INLINE void argsort_8_64bit(type_t *arr, int64_t *arg, int32_t N) zmm_t arrzmm = vtype::template mask_i64gather( vtype::zmm_max(), load_mask, argzmm, arr); arrzmm = sort_zmm_64bit(arrzmm, argzmm); - vtype::mask_storeu(arg, load_mask, argzmm); + argtype::mask_storeu(arg, load_mask, argzmm); } template @@ -270,6 +270,42 @@ inline void argsort_64bit_(type_t *arr, argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); } +template <> +void avx512_argsort(double *arr, int64_t *arg, int64_t arrsize) +{ + if (arrsize > 1) { + argsort_64bit_, double>( + arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +std::vector avx512_argsort(double *arr, int64_t arrsize) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx512_argsort(arr, indices.data(), arrsize); + return indices; +} + +template <> +void avx512_argsort(uint64_t *arr, int64_t *arg, int64_t arrsize) +{ + if (arrsize > 1) { + argsort_64bit_, uint64_t>( + arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +std::vector avx512_argsort(uint64_t *arr, int64_t arrsize) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx512_argsort(arr, indices.data(), arrsize); + return indices; +} + template <> void avx512_argsort(int64_t *arr, int64_t *arg, int64_t arrsize) { diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 7fa3bb03..903ac45f 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -175,7 +175,12 @@ struct zmm_vector { { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } - + template + static zmm_t + mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + { + return 
_mm512_mask_i64gather_epi64(src, mask, index, base, scale); + } template static zmm_t i64gather(__m512i index, void const *base) { @@ -295,6 +300,12 @@ struct zmm_vector { return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); } template + static zmm_t + mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + { + return _mm512_mask_i64gather_pd(src, mask, index, base, scale); + } + template static zmm_t i64gather(__m512i index, void const *base) { return _mm512_i64gather_pd(index, base, scale); diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h index ceabf5ac..4f586959 100644 --- a/src/avx512-common-argsort.h +++ b/src/avx512-common-argsort.h @@ -60,9 +60,9 @@ static inline int32_t partition_vec(type_t *arg, /* which elements are larger than the pivot */ typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); - vtype::mask_compressstoreu( + argtype::mask_compressstoreu( arg + left, vtype::knot_opmask(gt_mask), arg_vec); - vtype::mask_compressstoreu(arg + right - amount_gt_pivot, gt_mask, arg_vec); + argtype::mask_compressstoreu(arg + right - amount_gt_pivot, gt_mask, arg_vec); *smallest_vec = vtype::min(curr_vec, *smallest_vec); *biggest_vec = vtype::max(curr_vec, *biggest_vec); return amount_gt_pivot; @@ -130,7 +130,8 @@ static inline int64_t partition_avx512(type_t *arr, left += vtype::numlanes; right -= vtype::numlanes; while (right - left != 0) { - zmm_t arg_vec, curr_vec; + argzmm_t arg_vec; + zmm_t curr_vec; /* * if fewer elements are stored on the right side of the array, * then next elements are loaded from the right side, @@ -252,7 +253,7 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, right -= num_unroll * vtype::numlanes; #pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] = vtype::loadu(arg + right + ii * vtype::numlanes); + arg_vec[ii] = argtype::loadu(arg + right + ii * vtype::numlanes); curr_vec[ii] = vtype::template i64gather( arg_vec[ii], arr); } @@ -260,7 +261,7 @@ static inline int64_t partition_avx512_unrolled(type_t *arr, else { #pragma GCC unroll 8 for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] = vtype::loadu(arg + left + ii * vtype::numlanes); + arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); curr_vec[ii] = vtype::template i64gather( arg_vec[ii], arr); } From be5b1c2ade644bcd98657bd4b04925bb585c328e Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 1 May 2023 12:44:39 -0700 Subject: [PATCH 09/14] Add benchmarks for uint64_t and double argsort --- benchmarks/bench_argsort.hpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_argsort.hpp b/benchmarks/bench_argsort.hpp index cd511ab0..b21618d7 100644 --- a/benchmarks/bench_argsort.hpp +++ b/benchmarks/bench_argsort.hpp @@ -84,5 +84,11 @@ static void avx512argsort(benchmark::State &state, Args &&...args) } } -BENCH(avx512argsort, int64_t) -BENCH(stdargsort, int64_t) +#define BENCH_BOTH(type)\ + BENCH(avx512argsort, type)\ + BENCH(stdargsort, type)\ + +BENCH_BOTH(int64_t) +BENCH_BOTH(uint64_t) +BENCH_BOTH(double) + From c2f2423bd805c7e2d2d27999c4c7e8ebb75e36bd Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 1 May 2023 13:13:35 -0700 Subject: [PATCH 10/14] Use templates to write avx512_argsort functions --- src/avx512-64bit-argsort.hpp | 48 +++++------------------------------- 1 file changed, 6 insertions(+), 42 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp 
b/src/avx512-64bit-argsort.hpp index 5c21393b..df2f3933 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -270,57 +270,21 @@ inline void argsort_64bit_(type_t *arr, argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); } -template <> -void avx512_argsort(double *arr, int64_t *arg, int64_t arrsize) -{ - if (arrsize > 1) { - argsort_64bit_, double>( - arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } -} - -template <> -std::vector avx512_argsort(double *arr, int64_t arrsize) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx512_argsort(arr, indices.data(), arrsize); - return indices; -} - -template <> -void avx512_argsort(uint64_t *arr, int64_t *arg, int64_t arrsize) -{ - if (arrsize > 1) { - argsort_64bit_, uint64_t>( - arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); - } -} - -template <> -std::vector avx512_argsort(uint64_t *arr, int64_t arrsize) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx512_argsort(arr, indices.data(), arrsize); - return indices; -} - -template <> -void avx512_argsort(int64_t *arr, int64_t *arg, int64_t arrsize) +template +void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { - argsort_64bit_, int64_t>( + argsort_64bit_>( arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } -template <> -std::vector avx512_argsort(int64_t *arr, int64_t arrsize) +template +std::vector avx512_argsort(T* arr, int64_t arrsize) { std::vector indices(arrsize); std::iota(indices.begin(), indices.end(), 0); - avx512_argsort(arr, indices.data(), arrsize); + avx512_argsort(arr, indices.data(), arrsize); return indices; } From 677cb4abbd81fedf7428dd1d70d0e900c890f25c Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 1 May 2023 14:22:02 -0700 Subject: [PATCH 11/14] Add argsort for 32-bit data type --- src/avx512-64bit-argsort.hpp | 27 ++ src/avx512-64bit-common.h | 465 +++++++++++++++++++++++-- src/avx512-64bit-keyvalue-networks.hpp | 151 ++++---- src/avx512-common-argsort.h | 22 -- src/avx512-common-qsort.h | 3 + tests/test_argsort.cpp | 5 +- 6 files changed, 548 insertions(+), 125 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index df2f3933..1b002d71 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -279,6 +279,33 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize) } } +template <> +void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize) +{ + if (arrsize > 1) { + argsort_64bit_>( + arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize) +{ + if (arrsize > 1) { + argsort_64bit_>( + arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize) +{ + if (arrsize > 1) { + argsort_64bit_>( + arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + template std::vector avx512_argsort(T* arr, int64_t arrsize) { diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 903ac45f..e3377ffd 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -19,11 +19,416 @@ #define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 #define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 +template <> +struct ymm_vector { + using type_t = float; + using zmm_t = __m256; + using zmmi_t = __m256i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static 
type_t type_max() + { + return X86_SIMD_SORT_INFINITYF; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITYF; + } + static zmm_t zmm_max() + { + return _mm256_set1_ps(type_max()); + } + + static zmmi_t seti(int v1, + int v2, + int v3, + int v4, + int v5, + int v6, + int v7, + int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _kxor_mask8(x, y); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t le(zmm_t x, zmm_t y) + { + return _mm256_cmp_ps_mask(x, y, _CMP_LE_OQ); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm256_cmp_ps_mask(x, y, _CMP_GE_OQ); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm256_cmp_ps_mask(x, y, _CMP_EQ_OQ); + } + template + static zmm_t + mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + { + return _mm512_mask_i64gather_ps(src, mask, index, base, scale); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_ps(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm256_loadu_ps((float*) mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm256_max_ps(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm256_mask_compressstoreu_ps(mem, mask, x); + } + static zmm_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskz_loadu_ps(mask, mem); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm256_mask_loadu_ps(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm256_mask_mov_ps(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm256_mask_storeu_ps(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm256_min_ps(x, y); + } + static zmm_t permutexvar(__m256i idx, zmm_t zmm) + { + return _mm256_permutexvar_ps(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + __m128 v128 = _mm_max_ps(_mm256_castps256_ps128(v), _mm256_extractf32x4_ps (v, 1)); + __m128 v64 = _mm_max_ps(v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128 v32 = _mm_max_ps(v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1))); + return _mm_cvtss_f32(v32); + } + static type_t reducemin(zmm_t v) + { + __m128 v128 = _mm_min_ps(_mm256_castps256_ps128(v), _mm256_extractf32x4_ps(v, 1)); + __m128 v64 = _mm_min_ps(v128, _mm_shuffle_ps(v128, v128,_MM_SHUFFLE(1, 0, 3, 2))); + __m128 v32 = _mm_min_ps(v64, _mm_shuffle_ps(v64, v64,_MM_SHUFFLE(0, 0, 0, 1))); + return _mm_cvtss_f32(v32); + } + static zmm_t set1(type_t v) + { + return _mm256_set1_ps(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + /* Hack!: have to make shuffles within 128-bit lanes work for both + * 32-bit and 64-bit */ + if constexpr (mask == 0b01010101) { + return _mm256_shuffle_ps(zmm, zmm, 0b10110001); + } + else { + /* Not used, so far */ + return _mm256_shuffle_ps(zmm, zmm, mask); + } + } + static void storeu(void *mem, zmm_t x) + { + return _mm256_storeu_ps((float*)mem, x); + } +}; +template <> +struct ymm_vector { + using type_t = uint32_t; + using zmm_t = __m256i; + using zmmi_t = __m256i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT32; + } + static type_t type_min() + { + return 0; + } + static zmm_t zmm_max() + { + return _mm256_set1_epi32(type_max()); + } + + static 
zmmi_t seti(int v1, + int v2, + int v3, + int v4, + int v5, + int v6, + int v7, + int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _kxor_mask8(x, y); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t le(zmm_t x, zmm_t y) + { + return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_LE); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_EQ); + } + template + static zmm_t + mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + { + return _mm512_mask_i64gather_epi32(src, mask, index, base, scale); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi32(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm256_loadu_epi32(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm256_max_epu32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm256_mask_compressstoreu_epi32(mem, mask, x); + } + static zmm_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskz_loadu_epi32(mask, mem); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm256_mask_loadu_epi32(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm256_mask_mov_epi32(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm256_mask_storeu_epi32(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm256_min_epu32(x, y); + } + static zmm_t permutexvar(__m256i idx, zmm_t zmm) + { + return _mm256_permutexvar_epi32(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + __m128i v128 = _mm_max_epu32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); + __m128i v64 = _mm_max_epu32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128i v32 = _mm_max_epu32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); + return (type_t)_mm_cvtsi128_si32(v32); + } + static type_t reducemin(zmm_t v) + { + __m128i v128 = _mm_min_epu32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); + __m128i v64 = _mm_min_epu32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128i v32 = _mm_min_epu32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); + return (type_t)_mm_cvtsi128_si32(v32); + } + static zmm_t set1(type_t v) + { + return _mm256_set1_epi32(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + /* Hack!: have to make shuffles within 128-bit lanes work for both + * 32-bit and 64-bit */ + if constexpr (mask == 0b01010101) { + return _mm256_shuffle_epi32(zmm, 0b10110001); + } + else { + /* Not used, so far */ + return _mm256_shuffle_epi32(zmm, mask); + } + } + static void storeu(void *mem, zmm_t x) + { + return _mm256_storeu_epi32(mem, x); + } +}; +template <> +struct ymm_vector { + using type_t = int32_t; + using zmm_t = __m256i; + using zmmi_t = __m256i; + using opmask_t = __mmask8; + static const uint8_t numlanes = 8; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT32; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT32; + } + static zmm_t zmm_max() + { + return _mm256_set1_epi32(type_max()); + } // TODO: this should broadcast bits as is? 
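+    /* The seti() helper below builds the 8-lane permutation-index vectors
+     * (the NETWORK_64BIT_* patterns) consumed by permutexvar() in the
+     * sorting networks; like _mm256_set_epi32(), arguments are given with
+     * the highest lane first. */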
+ + static zmmi_t seti(int v1, + int v2, + int v3, + int v4, + int v5, + int v6, + int v7, + int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _kxor_mask8(x, y); + } + static opmask_t knot_opmask(opmask_t x) + { + return _knot_mask8(x); + } + static opmask_t le(zmm_t x, zmm_t y) + { + return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_LE); + } + static opmask_t ge(zmm_t x, zmm_t y) + { + return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); + } + static opmask_t eq(zmm_t x, zmm_t y) + { + return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_EQ); + } + template + static zmm_t + mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base) + { + return _mm512_mask_i64gather_epi32(src, mask, index, base, scale); + } + template + static zmm_t i64gather(__m512i index, void const *base) + { + return _mm512_i64gather_epi32(index, base, scale); + } + static zmm_t loadu(void const *mem) + { + return _mm256_loadu_epi32(mem); + } + static zmm_t max(zmm_t x, zmm_t y) + { + return _mm256_max_epi32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x) + { + return _mm256_mask_compressstoreu_epi32(mem, mask, x); + } + static zmm_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskz_loadu_epi32(mask, mem); + } + static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem) + { + return _mm256_mask_loadu_epi32(x, mask, mem); + } + static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y) + { + return _mm256_mask_mov_epi32(x, mask, y); + } + static void mask_storeu(void *mem, opmask_t mask, zmm_t x) + { + return _mm256_mask_storeu_epi32(mem, mask, x); + } + static zmm_t min(zmm_t x, zmm_t y) + { + return _mm256_min_epi32(x, y); + } + static zmm_t permutexvar(__m256i idx, zmm_t zmm) + { + return _mm256_permutexvar_epi32(idx, zmm); + } + static type_t reducemax(zmm_t v) + { + __m128i v128 = _mm_max_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); + __m128i v64 = _mm_max_epi32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128i v32 = _mm_max_epi32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); + return (type_t)_mm_cvtsi128_si32(v32); + } + static type_t reducemin(zmm_t v) + { + __m128i v128 = _mm_min_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1)); + __m128i v64 = _mm_min_epi32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2))); + __m128i v32 = _mm_min_epi32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); + return (type_t)_mm_cvtsi128_si32(v32); + } + static zmm_t set1(type_t v) + { + return _mm256_set1_epi32(v); + } + template + static zmm_t shuffle(zmm_t zmm) + { + /* Hack!: have to make shuffles within 128-bit lanes work for both + * 32-bit and 64-bit */ + if constexpr (mask == 0b01010101) { + return _mm256_shuffle_epi32(zmm, 0b10110001); + } + else { + /* Not used, so far */ + return _mm256_shuffle_epi32(zmm, mask); + } + } + static void storeu(void *mem, zmm_t x) + { + return _mm256_storeu_epi32(mem, x); + } +}; template <> struct zmm_vector { using type_t = int64_t; using zmm_t = __m512i; - using argzmm_t = __m512i; + using zmmi_t = __m512i; using ymm_t = __m512i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -41,14 +446,14 @@ struct zmm_vector { return _mm512_set1_epi64(type_max()); } // TODO: this should broadcast bits as is? 
- static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) + static zmmi_t seti(int v1, + int v2, + int v3, + int v4, + int v5, + int v6, + int v7, + int v8) { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } @@ -147,6 +552,7 @@ template <> struct zmm_vector { using type_t = uint64_t; using zmm_t = __m512i; + using zmmi_t = __m512i; using ymm_t = __m512i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -164,14 +570,14 @@ struct zmm_vector { return _mm512_set1_epi64(type_max()); } - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) + static zmmi_t seti(int v1, + int v2, + int v3, + int v4, + int v5, + int v6, + int v7, + int v8) { return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } @@ -258,6 +664,7 @@ template <> struct zmm_vector { using type_t = double; using zmm_t = __m512d; + using zmmi_t = __m512i; using ymm_t = __m512d; using opmask_t = __mmask8; static const uint8_t numlanes = 8; @@ -275,16 +682,16 @@ struct zmm_vector { return _mm512_set1_pd(type_max()); } - static zmm_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4, - type_t v5, - type_t v6, - type_t v7, - type_t v8) + static zmmi_t seti(int v1, + int v2, + int v3, + int v4, + int v5, + int v6, + int v7, + int v8) { - return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8); + return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8); } static opmask_t knot_opmask(opmask_t x) @@ -395,19 +802,19 @@ replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count) template X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm) { - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + const typename vtype::zmmi_t rev_index = vtype::seti(NETWORK_64BIT_2); zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAA); zmm = cmp_merge( zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), zmm), + vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm), 0xCC); zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAA); zmm = cmp_merge(zmm, vtype::permutexvar(rev_index, zmm), 0xF0); zmm = cmp_merge( zmm, - vtype::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), zmm), + vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm), 0xCC); zmm = cmp_merge( zmm, vtype::template shuffle(zmm), 0xAA); diff --git a/src/avx512-64bit-keyvalue-networks.hpp b/src/avx512-64bit-keyvalue-networks.hpp index ff2f6da6..69d50c52 100644 --- a/src/avx512-64bit-keyvalue-networks.hpp +++ b/src/avx512-64bit-keyvalue-networks.hpp @@ -5,7 +5,8 @@ template X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) { - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + const typename vtype1::zmmi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); + const typename vtype2::zmmi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); key_zmm = cmp_merge( key_zmm, vtype1::template shuffle(key_zmm), @@ -14,9 +15,9 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) 0xAA); key_zmm = cmp_merge( key_zmm, - vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), key_zmm), + vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_1), key_zmm), index_zmm, - vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_1), index_zmm), + vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_1), index_zmm), 0xCC); key_zmm = cmp_merge( key_zmm, @@ -26,15 +27,15 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t key_zmm, index_type &index_zmm) 0xAA); key_zmm = cmp_merge( key_zmm, - 
vtype1::permutexvar(rev_index, key_zmm), + vtype1::permutexvar(rev_index1, key_zmm), index_zmm, - vtype2::permutexvar(rev_index, index_zmm), + vtype2::permutexvar(rev_index2, index_zmm), 0xF0); key_zmm = cmp_merge( key_zmm, - vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_3), key_zmm), index_zmm, - vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_3), index_zmm), 0xCC); key_zmm = cmp_merge( key_zmm, @@ -56,16 +57,16 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t key_zmm, // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 key_zmm = cmp_merge( key_zmm, - vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), key_zmm), + vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_4), key_zmm), index_zmm, - vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_4), index_zmm), + vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_4), index_zmm), 0xF0); // 2) half_cleaner[4] key_zmm = cmp_merge( key_zmm, - vtype1::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), key_zmm), + vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_3), key_zmm), index_zmm, - vtype2::permutexvar(_mm512_set_epi64(NETWORK_64BIT_3), index_zmm), + vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_3), index_zmm), 0xCC); // 3) half_cleaner[1] key_zmm = cmp_merge( @@ -86,10 +87,11 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1, index_type &index_zmm1, index_type &index_zmm2) { - const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + const typename vtype1::zmmi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); + const typename vtype2::zmmi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - key_zmm2 = vtype1::permutexvar(rev_index, key_zmm2); - index_zmm2 = vtype2::permutexvar(rev_index, index_zmm2); + key_zmm2 = vtype1::permutexvar(rev_index1, key_zmm2); + index_zmm2 = vtype2::permutexvar(rev_index2, index_zmm2); zmm_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); zmm_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); @@ -114,12 +116,13 @@ template (key_zmm_t1, key_zmm_t3, index_zmm_t1, index_zmm_t3); COEX(key_zmm_t2, key_zmm_t4, index_zmm_t2, index_zmm_t4); @@ -261,24 +265,25 @@ template (key_zmm_t1, key_zmm_t5, index_zmm_t1, index_zmm_t5); COEX(key_zmm_t2, key_zmm_t6, index_zmm_t2, index_zmm_t6); diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h index 4f586959..e0dcaccc 100644 --- a/src/avx512-common-argsort.h +++ b/src/avx512-common-argsort.h @@ -21,28 +21,6 @@ void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize); template std::vector avx512_argsort(T *arr, int64_t arrsize); -/* - * COEX == Compare and Exchange two registers by swapping min and max values - */ -//template -//static void COEX(mm_t &a, mm_t &b) -//{ -// mm_t temp = a; -// a = vtype::min(a, b); -// b = vtype::max(temp, b); -//} -// -template -static inline zmm_t -cmp_merge(zmm_t in1, zmm_t in2, argzmm_t &arg1, argzmm_t arg2, opmask_t mask) -{ - typename vtype::opmask_t le_mask = vtype::le(in1, in2); - opmask_t temp = vtype::kxor_opmask(le_mask, mask); - arg1 = vtype::mask_mov(arg2, temp, arg1); // 0 -> min, 1 -> max - return vtype::mask_mov(in2, temp, in1); // 0 -> min, 1 -> max -} /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. 
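The cmp_merge helper removed just above is the primitive that keeps each arg index glued to its key: the lane-wise key comparison is XORed with the direction mask, and the resulting mask then moves key and index together (the copy here is dropped, presumably because the key-value networks already provide the same routine). A per-lane scalar sketch of its behaviour follows; the name scalar_cmp_merge and the std::array operands are illustrative stand-ins for the ZMM/opmask registers, not part of the patch.

    #include <array>
    #include <cstdint>

    /* Scalar model of cmp_merge: bit `lane` of `mask` picks the direction
     * (0 -> keep the smaller key, 1 -> keep the larger key). Key and index
     * are kept or replaced together, so indices always follow their keys.
     * In the AVX512 code, key2/arg2 are a permuted copy of key1/arg1. */
    template <typename T>
    void scalar_cmp_merge(std::array<T, 8> &key1,
                          const std::array<T, 8> &key2,
                          std::array<int64_t, 8> &arg1,
                          const std::array<int64_t, 8> &arg2,
                          uint8_t mask)
    {
        for (int lane = 0; lane < 8; ++lane) {
            bool le = key1[lane] <= key2[lane]; /* vtype::le           */
            bool dir = (mask >> lane) & 1;      /* requested direction */
            bool keep1 = (le != dir);           /* kxor_opmask         */
            if (!keep1) {                       /* mask_mov            */
                key1[lane] = key2[lane];
                arg1[lane] = arg2[lane];
            }
        }
    }

The sorting networks invoke this with masks such as 0xAA, 0xCC and 0xF0 so that alternating groups of lanes keep the smaller or the larger key, which is the bitonic compare-exchange pattern used by sort_zmm_64bit above.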
diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index b07b34d2..959352e6 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -88,6 +88,9 @@ template struct zmm_vector; +template +struct ymm_vector; + // Regular quicksort routines: template void avx512_qsort(T *arr, int64_t arrsize); diff --git a/tests/test_argsort.cpp b/tests/test_argsort.cpp index 66d9a322..b9393d20 100644 --- a/tests/test_argsort.cpp +++ b/tests/test_argsort.cpp @@ -181,7 +181,10 @@ REGISTER_TYPED_TEST_SUITE_P(avx512argsort, test_sorted, test_small_range); -using ArgSortTestTypes = testing::Types; From ae29f55171bec84fe6acff78384de0cc6139cd54 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 2 May 2023 14:08:49 -0700 Subject: [PATCH 12/14] Add benchmarks for 32-bit argsort --- benchmarks/bench_argsort.hpp | 4 +++- benchmarks/bench_qsort.hpp | 18 +++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/benchmarks/bench_argsort.hpp b/benchmarks/bench_argsort.hpp index b21618d7..24ffea76 100644 --- a/benchmarks/bench_argsort.hpp +++ b/benchmarks/bench_argsort.hpp @@ -91,4 +91,6 @@ static void avx512argsort(benchmark::State &state, Args &&...args) BENCH_BOTH(int64_t) BENCH_BOTH(uint64_t) BENCH_BOTH(double) - +BENCH_BOTH(int32_t) +BENCH_BOTH(uint32_t) +BENCH_BOTH(float) diff --git a/benchmarks/bench_qsort.hpp b/benchmarks/bench_qsort.hpp index c301d6b8..ae02ac9d 100644 --- a/benchmarks/bench_qsort.hpp +++ b/benchmarks/bench_qsort.hpp @@ -80,15 +80,15 @@ static void avx512qsort(benchmark::State &state, Args &&...args) } } -#define BENCH_BOTH(type)\ +#define BENCH_BOTH_QSORT(type)\ BENCH(avx512qsort, type)\ BENCH(stdsort, type) -BENCH_BOTH(uint64_t) -BENCH_BOTH(int64_t) -BENCH_BOTH(uint32_t) -BENCH_BOTH(int32_t) -BENCH_BOTH(uint16_t) -BENCH_BOTH(int16_t) -BENCH_BOTH(float) -BENCH_BOTH(double) +BENCH_BOTH_QSORT(uint64_t) +BENCH_BOTH_QSORT(int64_t) +BENCH_BOTH_QSORT(uint32_t) +BENCH_BOTH_QSORT(int32_t) +BENCH_BOTH_QSORT(uint16_t) +BENCH_BOTH_QSORT(int16_t) +BENCH_BOTH_QSORT(float) +BENCH_BOTH_QSORT(double) From bab65af523ef72a1ccb7772a6f4454955665eb10 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 2 May 2023 14:33:17 -0700 Subject: [PATCH 13/14] Get rid of if constexpr --- src/avx512-64bit-common.h | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index e3377ffd..7f4c5bb5 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -136,18 +136,18 @@ struct ymm_vector { { return _mm256_set1_ps(v); } - template + template static zmm_t shuffle(zmm_t zmm) { /* Hack!: have to make shuffles within 128-bit lanes work for both * 32-bit and 64-bit */ - if constexpr (mask == 0b01010101) { - return _mm256_shuffle_ps(zmm, zmm, 0b10110001); - } - else { - /* Not used, so far */ - return _mm256_shuffle_ps(zmm, zmm, mask); - } + return _mm256_shuffle_ps(zmm, zmm, 0b10110001); + //if constexpr (mask == 0b01010101) { + //} + //else { + // /* Not used, so far */ + // return _mm256_shuffle_ps(zmm, zmm, mask); + //} } static void storeu(void *mem, zmm_t x) { @@ -271,18 +271,12 @@ struct ymm_vector { { return _mm256_set1_epi32(v); } - template + template static zmm_t shuffle(zmm_t zmm) { /* Hack!: have to make shuffles within 128-bit lanes work for both * 32-bit and 64-bit */ - if constexpr (mask == 0b01010101) { - return _mm256_shuffle_epi32(zmm, 0b10110001); - } - else { - /* Not used, so far */ - return _mm256_shuffle_epi32(zmm, 
mask); - } + return _mm256_shuffle_epi32(zmm, 0b10110001); } static void storeu(void *mem, zmm_t x) { @@ -406,18 +400,12 @@ struct ymm_vector { { return _mm256_set1_epi32(v); } - template + template static zmm_t shuffle(zmm_t zmm) { /* Hack!: have to make shuffles within 128-bit lanes work for both * 32-bit and 64-bit */ - if constexpr (mask == 0b01010101) { - return _mm256_shuffle_epi32(zmm, 0b10110001); - } - else { - /* Not used, so far */ - return _mm256_shuffle_epi32(zmm, mask); - } + return _mm256_shuffle_epi32(zmm, 0b10110001); } static void storeu(void *mem, zmm_t x) { From 353dc3c7ff239d35ee25bcaa015e35b0253a97aa Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 2 May 2023 14:37:14 -0700 Subject: [PATCH 14/14] Use _mm256_loadu_si256 to make it work on g++-10 --- Makefile | 2 +- src/avx512-64bit-common.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 6169ec1b..7d06d931 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -CXX = g++-12 +CXX ?= g++-12 SRCDIR = ./src TESTDIR = ./tests BENCHDIR = ./benchmarks diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 7f4c5bb5..fb3c1f17 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -219,7 +219,7 @@ struct ymm_vector { } static zmm_t loadu(void const *mem) { - return _mm256_loadu_epi32(mem); + return _mm256_loadu_si256((__m256i*) mem); } static zmm_t max(zmm_t x, zmm_t y) { @@ -348,7 +348,7 @@ struct ymm_vector { } static zmm_t loadu(void const *mem) { - return _mm256_loadu_epi32(mem); + return _mm256_loadu_si256((__m256i*) mem); } static zmm_t max(zmm_t x, zmm_t y) {