Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions src/avx512-16bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,12 +349,6 @@ struct zmm_vector<uint16_t> {
}
};

/*
 * NaN test for an IEEE 754 half-precision value held as raw uint16_t bits.
 * A half is NaN only when all five exponent bits (0x7c00) are set AND the
 * ten mantissa bits (0x03ff) are not all zero. An all-ones exponent with a
 * zero mantissa encodes +/-infinity, which must not be reported as NaN.
 */
template <>
bool is_a_nan<uint16_t>(uint16_t elem)
{
    const bool exponent_all_ones = (elem & 0x7c00) == 0x7c00;
    const bool mantissa_nonzero = (elem & 0x03ff) != 0;
    return exponent_all_ones && mantissa_nonzero;
}

template <>
bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
{
Expand Down Expand Up @@ -383,6 +377,34 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
//return npy_half_to_float(a) < npy_half_to_float(b);
}

/*
 * Walk a half-precision array (raw uint16_t bit patterns) 16 elements at a
 * time, overwrite every NaN with YMM_MAX_HALF, and return how many NaNs
 * were overwritten. NOTE(review): YMM_MAX_HALF is presumably the +inf bit
 * pattern, going by the function name -- confirm against its definition.
 */
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
                                                  int64_t arrsize)
{
    int64_t replaced = 0;
    __mmask16 lanes = 0xFFFF;
    for (; arrsize > 0; arrsize -= 16, arr += 16) {
        // Fewer than 16 elements left: enable only the low `arrsize` lanes.
        if (arrsize < 16) { lanes = (0x0001 << arrsize) - 0x0001; }
        const __m256i halves = _mm256_maskz_loadu_epi16(lanes, arr);
        // Widen to single precision so the float compare can be used.
        const __m512 widened = _mm512_cvtph_ps(halves);
        // Self-comparison with "not equal, unordered" flags the NaN lanes.
        const __mmask16 is_nan
                = _mm512_cmp_ps_mask(widened, widened, _CMP_NEQ_UQ);
        replaced += _mm_popcnt_u32((int32_t)is_nan);
        _mm256_mask_storeu_epi16(arr, is_nan, YMM_MAX_HALF);
    }
    return replaced;
}

/*
 * Write the bit pattern 0xFFFF into the last nan_count slots of the array,
 * working backwards from the final element.
 */
X86_SIMD_SORT_INLINE void
replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
{
    int64_t slot = arrsize - 1;
    while (nan_count > 0) {
        arr[slot] = 0xFFFF;
        --slot;
        --nan_count;
    }
}

template <>
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
{
Expand All @@ -403,10 +425,11 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)

// Quickselect entry point for half-precision data stored as uint16_t; it
// forwards k to qselect_16bit_ using the float16 vector type.
// NOTE(review): this diff hunk is rendered without +/- markers. The lines
// that use move_nans_to_end_of_array / indx_last_elem are the pre-change
// version; the lines that use replace_nan_with_inf / replace_inf_with_nan
// and the full range [0, arrsize - 1] are the post-change version. Read them
// as two alternatives, not as a single function body.
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
{
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
if (indx_last_elem >= k) {
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
qselect_16bit_<zmm_vector<float16>, uint16_t>(
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
}
}

Expand All @@ -430,10 +453,11 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize)

// Sort entry point for half-precision data stored as uint16_t; it forwards
// to qsort_16bit_ using the float16 vector type.
// NOTE(review): this diff hunk is rendered without +/- markers. The lines
// that use move_nans_to_end_of_array / indx_last_elem are the pre-change
// version; the lines that use replace_nan_with_inf / replace_inf_with_nan
// and the full range [0, arrsize - 1] are the post-change version. Read them
// as two alternatives, not as a single function body.
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
{
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
if (indx_last_elem > 0) {
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
qsort_16bit_<zmm_vector<float16>, uint16_t>(
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
}
}

Expand Down
39 changes: 33 additions & 6 deletions src/avx512-32bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,31 @@ static void qselect_32bit_(type_t *arr,
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
}

/*
 * Scan a float array 16 lanes at a time, overwrite each NaN with
 * ZMM_MAX_FLOAT, and return the number of NaNs that were overwritten.
 */
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
{
    int64_t replaced = 0;
    __mmask16 lanes = 0xFFFF;
    for (; arrsize > 0; arrsize -= 16, arr += 16) {
        // Fewer than 16 elements left: enable only the low `arrsize` lanes.
        if (arrsize < 16) { lanes = (0x0001 << arrsize) - 0x0001; }
        const __m512 chunk = _mm512_maskz_loadu_ps(lanes, arr);
        // Self-comparison with "not equal, unordered" flags the NaN lanes.
        const __mmask16 is_nan = _mm512_cmp_ps_mask(chunk, chunk, _CMP_NEQ_UQ);
        replaced += _mm_popcnt_u32((int32_t)is_nan);
        _mm512_mask_storeu_ps(arr, is_nan, ZMM_MAX_FLOAT);
    }
    return replaced;
}

/*
 * Write a quiet NaN (std::nanf("1")) into the last nan_count slots of the
 * array, working backwards from the final element.
 */
X86_SIMD_SORT_INLINE void
replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
{
    // The NaN payload is the same for every slot, so build it once instead
    // of re-parsing the tag string on each iteration.
    const float nan_value = std::nanf("1");
    for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
        arr[ii] = nan_value;
        nan_count -= 1;
    }
}

template <>
void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize)
{
Expand All @@ -710,10 +735,11 @@ void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize)
// Quickselect entry point for float arrays; forwards k to qselect_32bit_.
// NOTE(review): this diff hunk is rendered without +/- markers. The lines
// that use move_nans_to_end_of_array / indx_last_elem are the pre-change
// version; the lines that use replace_nan_with_inf / replace_inf_with_nan
// and the full range [0, arrsize - 1] are the post-change version. Read them
// as two alternatives, not as a single function body.
template <>
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize)
{
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
if (indx_last_elem >= k) {
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
qselect_32bit_<zmm_vector<float>, float>(
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
}
}

Expand All @@ -738,10 +764,11 @@ void avx512_qsort<uint32_t>(uint32_t *arr, int64_t arrsize)
// Sort entry point for float arrays; forwards to qsort_32bit_.
// NOTE(review): this diff hunk is rendered without +/- markers. The lines
// that use move_nans_to_end_of_array / indx_last_elem are the pre-change
// version; the lines that use replace_nan_with_inf / replace_inf_with_nan
// and the full range [0, arrsize - 1] are the post-change version. Read them
// as two alternatives, not as a single function body.
template <>
void avx512_qsort<float>(float *arr, int64_t arrsize)
{
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
if (indx_last_elem > 0) {
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
qsort_32bit_<zmm_vector<float>, float>(
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
}
}

Expand Down
14 changes: 8 additions & 6 deletions src/avx512-64bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,10 +804,11 @@ void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize)
// Quickselect entry point for double arrays; forwards k to qselect_64bit_.
// NOTE(review): this diff hunk is rendered without +/- markers. The lines
// that use move_nans_to_end_of_array / indx_last_elem are the pre-change
// version; the lines that use replace_nan_with_inf / replace_inf_with_nan
// and the full range [0, arrsize - 1] are the post-change version. Read them
// as two alternatives, not as a single function body.
template <>
void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize)
{
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
if (indx_last_elem >= k) {
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
qselect_64bit_<zmm_vector<double>, double>(
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
}
}

Expand All @@ -832,10 +833,11 @@ void avx512_qsort<uint64_t>(uint64_t *arr, int64_t arrsize)
// Sort entry point for double arrays; forwards to qsort_64bit_.
// NOTE(review): this diff hunk is rendered without +/- markers. The lines
// that use move_nans_to_end_of_array / indx_last_elem are the pre-change
// version; the lines that use replace_nan_with_inf / replace_inf_with_nan
// and the full range [0, arrsize - 1] are the post-change version. Read them
// as two alternatives, not as a single function body.
template <>
void avx512_qsort<double>(double *arr, int64_t arrsize)
{
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
if (indx_last_elem > 0) {
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
qsort_64bit_<zmm_vector<double>, double>(
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
}
}
#endif // AVX512_QSORT_64BIT
29 changes: 0 additions & 29 deletions src/avx512-common-qsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,35 +116,6 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
template <typename T>
void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);

/*
 * Generic NaN predicate: defers to std::isnan for the element type.
 */
template <typename T>
bool is_a_nan(T elem)
{
    const bool elem_is_nan = std::isnan(elem);
    return elem_is_nan;
}

/*
 * Partition the array in place so that every NaN ends up after all the
 * non-NaN elements, and return the index of the last element that is not a
 * NaN (this is -1 when no such element exists).
 */
template <typename T>
int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize)
{
    int64_t last_valid = arrsize - 1;
    int64_t nans_seen = 0;
    for (int64_t cursor = 0; cursor <= last_valid;) {
        if (!is_a_nan(arr[cursor])) {
            ++cursor;
            continue;
        }
        // Swap the NaN to the back and re-examine whatever was swapped in.
        std::swap(arr[cursor], arr[last_valid]);
        --last_valid;
        ++nans_seen;
    }
    return arrsize - nans_seen - 1;
}

template <typename vtype, typename T = typename vtype::type_t>
bool comparison_func(const T &a, const T &b)
{
Expand Down
46 changes: 35 additions & 11 deletions src/avx512fp16-16bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,31 +114,55 @@ struct zmm_vector<_Float16> {
}
};

template <>
bool is_a_nan<_Float16>(_Float16 elem)
/*
 * Walk a _Float16 array 32 elements at a time, overwrite every NaN with
 * ZMM_MAX_HALF, and return how many NaNs were overwritten. A final chunk
 * shorter than 32 elements is loaded through a lane mask; full chunks use an
 * unmasked load.
 */
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr,
                                                  int64_t arrsize)
{
    int64_t nan_count = 0;
    __mmask32 loadmask = 0xFFFFFFFF;
    __m512h in_zmm;
    while (arrsize > 0) {
        if (arrsize < 32) {
            // Build the mask in unsigned arithmetic: with arrsize == 31 the
            // signed form (1 << 31) - 1 overflows int, which is undefined.
            loadmask = (0x00000001U << arrsize) - 0x00000001U;
            in_zmm = _mm512_castsi512_ph(
                    _mm512_maskz_loadu_epi16(loadmask, arr));
        }
        else {
            in_zmm = _mm512_loadu_ph(arr);
        }
        // Self-comparison with "not equal, unordered" flags the NaN lanes.
        __mmask32 nanmask = _mm512_cmp_ph_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
        nan_count += _mm_popcnt_u32((int32_t)nanmask);
        _mm512_mask_storeu_epi16(arr, nanmask, ZMM_MAX_HALF);
        arr += 32;
        arrsize -= 32;
    }
    return nan_count;
}

// Fills the last nan_count elements of the array with 0xFF bytes via memset
// (nan_count * 2 bytes in total).
// NOTE(review): this diff hunk is rendered without +/- markers. The three
// lines that use Fp16Bits / temp / elem are leftovers of the removed
// is_a_nan<_Float16> body and do not belong to this function; only the
// memset line does.
X86_SIMD_SORT_INLINE void
replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
{
Fp16Bits temp;
temp.f_ = elem;
return (temp.i_ & 0x7c00) == 0x7c00;
memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
}

// Quickselect entry point for _Float16 arrays; forwards k to qselect_16bit_.
// NOTE(review): this diff hunk is rendered without +/- markers. The lines
// that use move_nans_to_end_of_array / indx_last_elem are the pre-change
// version; the lines that use replace_nan_with_inf / replace_inf_with_nan
// and the full range [0, arrsize - 1] are the post-change version. Read them
// as two alternatives, not as a single function body.
template <>
void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize)
{
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
if (indx_last_elem >= k) {
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
qselect_16bit_<zmm_vector<_Float16>, _Float16>(
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
}
}

// Sort entry point for _Float16 arrays; forwards to qsort_16bit_.
// NOTE(review): this diff hunk is rendered without +/- markers. The lines
// that use move_nans_to_end_of_array / indx_last_elem are the pre-change
// version; the lines that use replace_nan_with_inf / replace_inf_with_nan
// and the full range [0, arrsize - 1] are the post-change version. Read them
// as two alternatives, not as a single function body.
template <>
void avx512_qsort(_Float16 *arr, int64_t arrsize)
{
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
if (indx_last_elem > 0) {
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
qsort_16bit_<zmm_vector<_Float16>, _Float16>(
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
}
}
#endif // AVX512FP16_QSORT_16BIT