diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp
index 009819b4..81d7d00e 100644
--- a/lib/x86simdsort-avx2.cpp
+++ b/lib/x86simdsort-avx2.cpp
@@ -1,5 +1,6 @@
 // AVX2 specific routines:
 #include "avx2-32bit-qsort.hpp"
+#include "avx2-64bit-qsort.hpp"
 #include "x86simdsort-internal.h"
 
 #define DEFINE_ALL_METHODS(type) \
@@ -24,5 +25,8 @@ namespace avx2 {
     DEFINE_ALL_METHODS(uint32_t)
     DEFINE_ALL_METHODS(int32_t)
     DEFINE_ALL_METHODS(float)
+    DEFINE_ALL_METHODS(uint64_t)
+    DEFINE_ALL_METHODS(int64_t)
+    DEFINE_ALL_METHODS(double)
 } // namespace avx2
 } // namespace xss
diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp
index 4e9ef136..0ec54bef 100644
--- a/lib/x86simdsort.cpp
+++ b/lib/x86simdsort.cpp
@@ -150,15 +150,15 @@ DISPATCH(argselect, _Float16, ISA_LIST("none"))
 DISPATCH_ALL(qsort,
              (ISA_LIST("avx512_icl")),
              (ISA_LIST("avx512_skx", "avx2")),
-             (ISA_LIST("avx512_skx")))
+             (ISA_LIST("avx512_skx", "avx2")))
 DISPATCH_ALL(qselect,
              (ISA_LIST("avx512_icl")),
              (ISA_LIST("avx512_skx", "avx2")),
-             (ISA_LIST("avx512_skx")))
+             (ISA_LIST("avx512_skx", "avx2")))
 DISPATCH_ALL(partial_qsort,
              (ISA_LIST("avx512_icl")),
              (ISA_LIST("avx512_skx", "avx2")),
-             (ISA_LIST("avx512_skx")))
+             (ISA_LIST("avx512_skx", "avx2")))
 DISPATCH_ALL(argsort,
              (ISA_LIST("none")),
              (ISA_LIST("avx512_skx")),
diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp
index 5dd77a27..5512c310 100644
--- a/src/avx2-32bit-qsort.hpp
+++ b/src/avx2-32bit-qsort.hpp
@@ -135,7 +135,7 @@ struct avx2_vector {
     }
     static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x)
     {
-        return avx2_emu_mask_compressstoreu(mem, mask, x);
+        return avx2_emu_mask_compressstoreu32(mem, mask, x);
     }
     static reg_t maskz_loadu(opmask_t mask, void const *mem)
     {
@@ -289,7 +289,7 @@ struct avx2_vector {
     }
     static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x)
     {
-        return avx2_emu_mask_compressstoreu(mem, mask, x);
+        return avx2_emu_mask_compressstoreu32(mem, mask, x);
     }
     static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem)
     {
@@ -459,7 +459,7 @@ struct avx2_vector {
     }
     static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x)
     {
-        return avx2_emu_mask_compressstoreu(mem, mask, x);
+        return avx2_emu_mask_compressstoreu32(mem, mask, x);
     }
     static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem)
     {
diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp
new file mode 100644
index 00000000..2ef7f70a
--- /dev/null
+++ b/src/avx2-64bit-qsort.hpp
@@ -0,0 +1,586 @@
+/*******************************************************************
+ * Copyright (C) 2022 Intel Corporation
+ * SPDX-License-Identifier: BSD-3-Clause
+ * Authors: Raghuveer Devulapalli
+ *          Matthew Sterrett
+ * ****************************************************************/
+
+#ifndef AVX2_QSORT_64BIT
+#define AVX2_QSORT_64BIT
+
+#include "xss-common-qsort.h"
+#include "avx2-emu-funcs.hpp"
+
+/*
+ * Constants used in sorting 4 elements in a ymm register.
Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ +// ymm 3, 2, 1, 0 +#define NETWORK_64BIT_R 0, 1, 2, 3 +#define NETWORK_64BIT_1 1, 0, 3, 2 + +/* + * Assumes ymm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE reg_t sort_ymm_64bit(reg_t ymm) +{ + const typename vtype::opmask_t oxAA + = _mm256_set_epi64x(0xFFFFFFFFFFFFFFFF, 0, 0xFFFFFFFFFFFFFFFF, 0); + const typename vtype::opmask_t oxCC + = _mm256_set_epi64x(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0); + ymm = cmp_merge( + ymm, + vtype::template permutexvar(ymm), + oxAA); + ymm = cmp_merge( + ymm, + vtype::template permutexvar(ymm), + oxCC); + ymm = cmp_merge( + ymm, + vtype::template permutexvar(ymm), + oxAA); + return ymm; +} + +struct avx2_64bit_swizzle_ops; + +template <> +struct avx2_vector { + using type_t = int64_t; + using reg_t = __m256i; + using ymmi_t = __m256i; + using opmask_t = __m256i; + static const uint8_t numlanes = 4; + static constexpr int network_sort_threshold = 64; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_64bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT64; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT64; + } + static reg_t zmm_max() + { + return _mm256_set1_epi64x(type_max()); + } // TODO: this should broadcast bits as is? + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_64bit(mask); + } + static ymmi_t seti(int v1, int v2, int v3, int v4) + { + return _mm256_set_epi64x(v1, v2, v3, v4); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _mm256_xor_si256(x, y); + } + static opmask_t ge(reg_t x, reg_t y) + { + opmask_t equal = eq(x, y); + opmask_t greater = _mm256_cmpgt_epi64(x, y); + return _mm256_castpd_si256(_mm256_or_pd(_mm256_castsi256_pd(equal), + _mm256_castsi256_pd(greater))); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_cmpeq_epi64(x, y); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_epi64(src, base, index, mask, scale); + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i64gather_epi64((int64_t const *)base, index, scale); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_si256((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return avx2_emu_max(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu64(mem, mask, x); + } + static int32_t double_compressstore(void *left_addr, + void *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore64( + left_addr, right_addr, k, reg); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskload_epi64((const long long int *)mem, mask); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_epi64((long long int *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(x), + _mm256_castsi256_pd(y), + _mm256_castsi256_pd(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + 
{ + return _mm256_maskstore_epi64((long long int *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return avx2_emu_min(x, y); + } + template + static reg_t permutexvar(reg_t ymm) + { + return _mm256_permute4x64_epi64(ymm, idx); + } + template + static reg_t permutevar(reg_t ymm) + { + return _mm256_permute4x64_epi64(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const int32_t rev_index = SHUFFLE_MASK(0, 1, 2, 3); + return permutexvar(ymm); + } + template + static type_t extract(reg_t v) + { + return _mm256_extract_epi64(v, index); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max64(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min64(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_epi64x(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_castpd_si256( + _mm256_permute_pd(_mm256_castsi256_pd(ymm), mask)); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_si256((__m256i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_64bit>(x); + } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } +}; +template <> +struct avx2_vector { + using type_t = uint64_t; + using reg_t = __m256i; + using ymmi_t = __m256i; + using opmask_t = __m256i; + static const uint8_t numlanes = 4; + static constexpr int network_sort_threshold = 64; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_64bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT64; + } + static type_t type_min() + { + return 0; + } + static reg_t zmm_max() + { + return _mm256_set1_epi64x(type_max()); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_64bit(mask); + } + static ymmi_t seti(int v1, int v2, int v3, int v4) + { + return _mm256_set_epi64x(v1, v2, v3, v4); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_epi64(src, base, index, mask, scale); + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i64gather_epi64( + (long long int const *)base, index, scale); + } + static opmask_t ge(reg_t x, reg_t y) + { + opmask_t equal = eq(x, y); + + const __m256i offset = _mm256_set1_epi64x(0x8000000000000000); + x = _mm256_add_epi64(x, offset); + y = _mm256_add_epi64(y, offset); + + opmask_t greater = _mm256_cmpgt_epi64(x, y); + return _mm256_castpd_si256(_mm256_or_pd(_mm256_castsi256_pd(equal), + _mm256_castsi256_pd(greater))); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_cmpeq_epi64(x, y); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_si256((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return avx2_emu_max(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu64(mem, mask, x); + } + static int32_t double_compressstore(void *left_addr, + void *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore64( + left_addr, right_addr, k, reg); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_epi64((const long long int *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return 
_mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(x), + _mm256_castsi256_pd(y), + _mm256_castsi256_pd(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_epi64((long long int *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return avx2_emu_min(x, y); + } + template + static reg_t permutexvar(reg_t ymm) + { + return _mm256_permute4x64_epi64(ymm, idx); + } + template + static reg_t permutevar(reg_t ymm) + { + return _mm256_permute4x64_epi64(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const int32_t rev_index = SHUFFLE_MASK(0, 1, 2, 3); + return permutexvar(ymm); + } + template + static type_t extract(reg_t v) + { + return _mm256_extract_epi64(v, index); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max64(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min64(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_epi64x(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_castpd_si256( + _mm256_permute_pd(_mm256_castsi256_pd(ymm), mask)); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_si256((__m256i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_64bit>(x); + } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } +}; +template <> +struct avx2_vector { + using type_t = double; + using reg_t = __m256d; + using ymmi_t = __m256i; + using opmask_t = __m256i; + static const uint8_t numlanes = 4; + static constexpr int network_sort_threshold = 64; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_64bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITY; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITY; + } + static reg_t zmm_max() + { + return _mm256_set1_pd(type_max()); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_64bit(mask); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return convert_avx2_mask_to_int_64bit(mask); + } + template + static opmask_t fpclass(reg_t x) + { + if constexpr (type == (0x01 | 0x80)) { + return _mm256_castpd_si256(_mm256_cmp_pd(x, x, _CMP_UNORD_Q)); + } + else { + static_assert(type == (0x01 | 0x80), "should not reach here"); + } + } + static ymmi_t seti(int v1, int v2, int v3, int v4) + { + return _mm256_set_epi64x(v1, v2, v3, v4); + } + + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskload_pd((const double *)mem, mask); + } + static opmask_t ge(reg_t x, reg_t y) + { + return _mm256_castpd_si256(_mm256_cmp_pd(x, y, _CMP_GE_OQ)); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_castpd_si256(_mm256_cmp_pd(x, y, _CMP_EQ_OQ)); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_pd( + src, base, index, _mm256_castsi256_pd(mask), scale); + ; + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i64gather_pd((double *)base, index, scale); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_pd((double const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm256_max_pd(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu64(mem, mask, x); + } + static int32_t 
double_compressstore(void *left_addr, + void *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore64( + left_addr, right_addr, k, reg); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_pd((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_blendv_pd(x, y, _mm256_castsi256_pd(mask)); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_pd((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm256_min_pd(x, y); + } + template + static reg_t permutexvar(reg_t ymm) + { + return _mm256_permute4x64_pd(ymm, idx); + } + template + static reg_t permutevar(reg_t ymm) + { + return _mm256_permute4x64_pd(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const int32_t rev_index = SHUFFLE_MASK(0, 1, 2, 3); + return permutexvar(ymm); + } + template + static type_t extract(reg_t v) + { + int64_t x = _mm256_extract_epi64(_mm256_castpd_si256(v), index); + double y; + std::memcpy(&y, &x, sizeof(y)); + return y; + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max64(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min64(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_pd(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_permute_pd(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_pd((double *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_64bit>(x); + } + static reg_t cast_from(__m256i v) + { + return _mm256_castsi256_pd(v); + } + static __m256i cast_to(reg_t v) + { + return _mm256_castpd_si256(v); + } +}; + +struct avx2_64bit_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + v = _mm256_permute4x64_epi64(v, 0b10110001); + } + else if constexpr (scale == 4) { + v = _mm256_permute4x64_epi64(v, 0b01001110); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { + return vtype::reverse(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { + __m256d v1 = _mm256_castsi256_pd(vtype::cast_to(reg)); + __m256d v2 = _mm256_castsi256_pd(vtype::cast_to(other)); + + if constexpr (scale == 2) { v1 = _mm256_blend_pd(v1, v2, 0b0101); } + else if constexpr (scale == 4) { + v1 = _mm256_blend_pd(v1, v2, 0b0011); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(_mm256_castpd_si256(v1)); + } +}; + +#endif // AVX2_QSORT_32BIT diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 43eed316..0dd50c09 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -20,6 +20,21 @@ constexpr auto avx2_mask_helper_lut32 = [] { return lut; }(); +constexpr auto avx2_mask_helper_lut64 = [] { + std::array, 16> lut {}; + for (int64_t i = 0; i <= 0xF; i++) { + std::array entry {}; + for (int j = 0; j < 4; j++) { + if (((i >> j) & 1) == 1) 
+ entry[j] = 0xFFFFFFFFFFFFFFFF; + else + entry[j] = 0; + } + lut[i] = entry; + } + return lut; +}(); + constexpr auto avx2_compressstore_lut32_gen = [] { std::array, 256>, 2> lutPair {}; auto &permLut = lutPair[0]; @@ -50,6 +65,38 @@ constexpr auto avx2_compressstore_lut32_gen = [] { constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0]; constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1]; +constexpr auto avx2_compressstore_lut64_gen = [] { + std::array, 16> permLut {}; + std::array, 16> leftLut {}; + for (int64_t i = 0; i <= 0xF; i++) { + std::array indices {}; + std::array leftEntry = {0, 0, 0, 0}; + int right = 7; + int left = 0; + for (int j = 0; j < 4; j++) { + bool ge = (i >> j) & 1; + if (ge) { + indices[right] = 2 * j + 1; + indices[right - 1] = 2 * j; + right -= 2; + } + else { + indices[left + 1] = 2 * j + 1; + indices[left] = 2 * j; + leftEntry[left / 2] = 0xFFFFFFFFFFFFFFFF; + left += 2; + } + } + permLut[i] = indices; + leftLut[i] = leftEntry; + } + return std::make_pair(permLut, leftLut); +}(); +constexpr auto avx2_compressstore_lut64_perm + = avx2_compressstore_lut64_gen.first; +constexpr auto avx2_compressstore_lut64_left + = avx2_compressstore_lut64_gen.second; + X86_SIMD_SORT_INLINE __m256i convert_int_to_avx2_mask(int32_t m) { @@ -63,6 +110,19 @@ int32_t convert_avx2_mask_to_int(__m256i m) return _mm256_movemask_ps(_mm256_castsi256_ps(m)); } +X86_SIMD_SORT_INLINE +__m256i convert_int_to_avx2_mask_64bit(int32_t m) +{ + return _mm256_loadu_si256( + (const __m256i *)avx2_mask_helper_lut64[m].data()); +} + +X86_SIMD_SORT_INLINE +int32_t convert_avx2_mask_to_int_64bit(__m256i m) +{ + return _mm256_movemask_pd(_mm256_castsi256_pd(m)); +} + // Emulators for intrinsics missing from AVX2 compared to AVX512 template T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) @@ -95,9 +155,31 @@ T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) } template -void avx2_emu_mask_compressstoreu(void *base_addr, - typename avx2_vector::opmask_t k, - typename avx2_vector::reg_t reg) +T avx2_emu_reduce_max64(typename avx2_vector::reg_t x) +{ + using vtype = avx2_vector; + typename vtype::reg_t inter1 = vtype::max( + x, vtype::template permutexvar(x)); + T can1 = vtype::template extract<0>(inter1); + T can2 = vtype::template extract<2>(inter1); + return std::max(can1, can2); +} + +template +T avx2_emu_reduce_min64(typename avx2_vector::reg_t x) +{ + using vtype = avx2_vector; + typename vtype::reg_t inter1 = vtype::min( + x, vtype::template permutexvar(x)); + T can1 = vtype::template extract<0>(inter1); + T can2 = vtype::template extract<2>(inter1); + return std::min(can1, can2); +} + +template +void avx2_emu_mask_compressstoreu32(void *base_addr, + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) { using vtype = avx2_vector; @@ -114,6 +196,27 @@ void avx2_emu_mask_compressstoreu(void *base_addr, vtype::mask_storeu(leftStore, left, temp); } +template +void avx2_emu_mask_compressstoreu64(void *base_addr, + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) +{ + using vtype = avx2_vector; + + T *leftStore = (T *)base_addr; + + int32_t shortMask = convert_avx2_mask_to_int_64bit(k); + const __m256i &perm = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut64_perm[shortMask].data()); + const __m256i &left = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut64_left[shortMask].data()); + + typename vtype::reg_t temp = vtype::cast_from( + _mm256_permutevar8x32_epi32(vtype::cast_to(reg), 
perm));
+
+    vtype::mask_storeu(leftStore, left, temp);
+}
+
 template <typename T>
 int avx2_double_compressstore32(void *left_addr,
                                 void *right_addr,
@@ -139,6 +242,32 @@ int avx2_double_compressstore32(void *left_addr,
     return _mm_popcnt_u32(shortMask);
 }
 
+template <typename T>
+int32_t avx2_double_compressstore64(void *left_addr,
+                                    void *right_addr,
+                                    typename avx2_vector<T>::opmask_t k,
+                                    typename avx2_vector<T>::reg_t reg)
+{
+    using vtype = avx2_vector<T>;
+
+    T *leftStore = (T *)left_addr;
+    T *rightStore = (T *)right_addr;
+
+    int32_t shortMask = convert_avx2_mask_to_int_64bit(k);
+    const __m256i &perm = _mm256_loadu_si256(
+            (const __m256i *)avx2_compressstore_lut64_perm[shortMask].data());
+    const __m256i &left = _mm256_loadu_si256(
+            (const __m256i *)avx2_compressstore_lut64_left[shortMask].data());
+
+    typename vtype::reg_t temp = vtype::cast_from(
+            _mm256_permutevar8x32_epi32(vtype::cast_to(reg), perm));
+
+    vtype::mask_storeu(leftStore, left, temp);
+    vtype::mask_storeu(rightStore, ~left, temp);
+
+    return _mm_popcnt_u32(shortMask);
+}
+
 template <typename T>
 typename avx2_vector<T>::reg_t avx2_emu_max(typename avx2_vector<T>::reg_t x,
                                             typename avx2_vector<T>::reg_t y)