diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 387d8b57..5907d17a 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -95,6 +95,12 @@ struct ymm_vector { { return _mm512_mask_i64gather_ps(src, mask, index, base, scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mmask_i32gather_ps(src, mask, index, base, scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[7]], @@ -247,6 +253,12 @@ struct ymm_vector { { return _mm512_mask_i64gather_epi32(src, mask, index, base, scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mmask_i32gather_epi32(src, mask, index, base, scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[7]], @@ -393,6 +405,12 @@ struct ymm_vector { { return _mm512_mask_i64gather_epi32(src, mask, index, base, scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mmask_i32gather_epi32(src, mask, index, base, scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[7]], @@ -548,6 +566,12 @@ struct zmm_vector { { return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm512_mask_i32gather_epi64(src, mask, index, base, scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[7]], @@ -688,6 +712,12 @@ struct zmm_vector { { return _mm512_mask_i64gather_epi64(src, mask, index, base, scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm512_mask_i32gather_epi64(src, mask, index, base, scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[7]], @@ -864,6 +894,12 @@ struct zmm_vector { { return _mm512_mask_i64gather_pd(src, mask, index, base, scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm512_mask_i32gather_pd(src, mask, index, base, scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[7]], diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h index 357d143c..aa90c748 100644 --- a/src/avx512-common-argsort.h +++ b/src/avx512-common-argsort.h @@ -12,7 +12,9 @@ #include #include -using argtype = zmm_vector; +using argtype = typename std::conditional, + zmm_vector>::type; using argreg_t = typename argtype::reg_t; /* diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 7ecd1a13..fb2ef78c 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -103,7 +103,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort) for (auto type : this->arrtype) { for (auto size : this->arrsize) { // k should be at least 1 - size_t k = std::max(0x1ul, rand() % size); + size_t k = std::max((size_t)1, rand() % size); std::vector arr = get_array(type, size); std::vector sortedarr = arr; std::sort(sortedarr.begin(),