Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions meson.build
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
project('x86-simd-sort', 'cpp',
version : '1.0.0',
license : 'BSD 3-clause')
version : '2.0.0',
license : 'BSD 3-clause',
default_options : ['cpp_std=c++17'])
cpp = meson.get_compiler('cpp')
src = include_directories('src')
bench = include_directories('benchmarks')
Expand Down
80 changes: 35 additions & 45 deletions src/avx512-16bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
//return npy_half_to_float(a) < npy_half_to_float(b);
}

X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
template <>
int64_t replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr,
int64_t arrsize)
{
int64_t nan_count = 0;
Expand All @@ -396,77 +397,66 @@ X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
return nan_count;
}

X86_SIMD_SORT_INLINE void
replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
{
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
arr[ii] = 0xFFFF;
nan_count -= 1;
}
}

template <>
bool is_a_nan<uint16_t>(uint16_t elem)
{
return (elem & 0x7c00) == 0x7c00;
}

/* Specialized template function for 16-bit qsort_ funcs*/
template <>
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
void qsort_<zmm_vector<int16_t>>(int16_t *arr,
int64_t left,
int64_t right,
int64_t maxiters)
{
if (arrsize > 1) {
qselect_16bit_<zmm_vector<int16_t>, int16_t>(
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
}
qsort_16bit_<zmm_vector<int16_t>>(arr, left, right, maxiters);
}

template <>
void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
void qsort_<zmm_vector<uint16_t>>(uint16_t *arr,
int64_t left,
int64_t right,
int64_t maxiters)
{
if (arrsize > 1) {
qselect_16bit_<zmm_vector<uint16_t>, uint16_t>(
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
}
qsort_16bit_<zmm_vector<uint16_t>>(arr, left, right, maxiters);
}

void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
{
int64_t indx_last_elem = arrsize - 1;
if (UNLIKELY(hasnan)) {
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
}
if (indx_last_elem >= k) {
qselect_16bit_<zmm_vector<float16>, uint16_t>(
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(
arr, arrsize);
qsort_16bit_<zmm_vector<float16>, uint16_t>(
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
}
}

/* Specialized template function for 16-bit qselect_ funcs*/
template <>
void avx512_qsort(int16_t *arr, int64_t arrsize)
void qselect_<zmm_vector<int16_t>>(
int16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
{
if (arrsize > 1) {
qsort_16bit_<zmm_vector<int16_t>, int16_t>(
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
}
qselect_16bit_<zmm_vector<int16_t>>(arr, k, left, right, maxiters);
}

template <>
void avx512_qsort(uint16_t *arr, int64_t arrsize)
void qselect_<zmm_vector<uint16_t>>(
uint16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
{
if (arrsize > 1) {
qsort_16bit_<zmm_vector<uint16_t>, uint16_t>(
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
}
qselect_16bit_<zmm_vector<uint16_t>>(arr, k, left, right, maxiters);
}

void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
{
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
qsort_16bit_<zmm_vector<float16>, uint16_t>(
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
int64_t indx_last_elem = arrsize - 1;
if (UNLIKELY(hasnan)) {
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
}
if (indx_last_elem >= k) {
qselect_16bit_<zmm_vector<float16>, uint16_t>(
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
}
}

#endif // AVX512_QSORT_16BIT
107 changes: 39 additions & 68 deletions src/avx512-32bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,15 @@ struct zmm_vector<float> {
{
return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
}
static opmask_t get_partial_loadmask(int size)
{
return (0x0001 << size) - 0x0001;
}
template <int type>
static opmask_t fpclass(zmm_t x)
{
return _mm512_fpclass_ps_mask(x, type);
}
template <int scale>
static ymm_t i64gather(__m512i index, void const *base)
{
Expand All @@ -279,6 +288,10 @@ struct zmm_vector<float> {
{
return _mm512_mask_compressstoreu_ps(mem, mask, x);
}
static zmm_t maskz_loadu(opmask_t mask, void const *mem)
{
return _mm512_maskz_loadu_ps(mask, mem);
}
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
{
return _mm512_mask_loadu_ps(x, mask, mem);
Expand Down Expand Up @@ -689,95 +702,53 @@ static void qselect_32bit_(type_t *arr,
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
}

X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
{
int64_t nan_count = 0;
__mmask16 loadmask = 0xFFFF;
while (arrsize > 0) {
if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
__m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr);
__mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
nan_count += _mm_popcnt_u32((int32_t)nanmask);
_mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT);
arr += 16;
arrsize -= 16;
}
return nan_count;
}

X86_SIMD_SORT_INLINE void
replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
{
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
arr[ii] = std::nanf("1");
nan_count -= 1;
}
}

/* Specialized template function for 32-bit qselect_ funcs*/
template <>
void avx512_qselect<int32_t>(int32_t *arr,
int64_t k,
int64_t arrsize,
bool hasnan)
void qselect_<zmm_vector<int32_t>>(
int32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
{
if (arrsize > 1) {
qselect_32bit_<zmm_vector<int32_t>, int32_t>(
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
}
qselect_32bit_<zmm_vector<int32_t>>(arr, k, left, right, maxiters);
}

template <>
void avx512_qselect<uint32_t>(uint32_t *arr,
int64_t k,
int64_t arrsize,
bool hasnan)
void qselect_<zmm_vector<uint32_t>>(
uint32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
{
if (arrsize > 1) {
qselect_32bit_<zmm_vector<uint32_t>, uint32_t>(
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
}
qselect_32bit_<zmm_vector<uint32_t>>(arr, k, left, right, maxiters);
}

template <>
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
void qselect_<zmm_vector<float>>(
float *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
{
int64_t indx_last_elem = arrsize - 1;
if (UNLIKELY(hasnan)) {
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
}
if (indx_last_elem >= k) {
qselect_32bit_<zmm_vector<float>, float>(
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
}
qselect_32bit_<zmm_vector<float>>(arr, k, left, right, maxiters);
}

/* Specialized template function for 32-bit qsort_ funcs*/
template <>
void avx512_qsort<int32_t>(int32_t *arr, int64_t arrsize)
void qsort_<zmm_vector<int32_t>>(int32_t *arr,
int64_t left,
int64_t right,
int64_t maxiters)
{
if (arrsize > 1) {
qsort_32bit_<zmm_vector<int32_t>, int32_t>(
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
}
qsort_32bit_<zmm_vector<int32_t>>(arr, left, right, maxiters);
}

template <>
void avx512_qsort<uint32_t>(uint32_t *arr, int64_t arrsize)
void qsort_<zmm_vector<uint32_t>>(uint32_t *arr,
int64_t left,
int64_t right,
int64_t maxiters)
{
if (arrsize > 1) {
qsort_32bit_<zmm_vector<uint32_t>, uint32_t>(
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
}
qsort_32bit_<zmm_vector<uint32_t>>(arr, left, right, maxiters);
}

template <>
void avx512_qsort<float>(float *arr, int64_t arrsize)
void qsort_<zmm_vector<float>>(float *arr,
int64_t left,
int64_t right,
int64_t maxiters)
{
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
qsort_32bit_<zmm_vector<float>, float>(
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
}
qsort_32bit_<zmm_vector<float>>(arr, left, right, maxiters);
}

#endif //AVX512_QSORT_32BIT
Loading