Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/avx512-16bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,13 @@ replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
}

template <>
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
bool is_a_nan<uint16_t>(uint16_t elem)
{
    // `elem` is examined as an IEEE 754 binary16 bit pattern:
    //   0x7c00 - the five exponent bits
    //   0x03ff - the ten mantissa bits
    // A NaN has every exponent bit set AND a non-zero mantissa. Exponent all
    // ones with a zero mantissa is +/-infinity, which is an ordinary orderable
    // value and must not be reported as a NaN.
    return ((elem & 0x7c00) == 0x7c00) && ((elem & 0x03ff) != 0);
}

template <>
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
{
if (arrsize > 1) {
qselect_16bit_<zmm_vector<int16_t>, int16_t>(
Expand All @@ -415,21 +421,23 @@ void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
}

template <>
void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)
void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
{
if (arrsize > 1) {
qselect_16bit_<zmm_vector<uint16_t>, uint16_t>(
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
}
}

void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
{
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
int64_t indx_last_elem = arrsize - 1;
if (UNLIKELY(hasnan)) {
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
}
if (indx_last_elem >= k) {
qselect_16bit_<zmm_vector<float16>, uint16_t>(
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
}
}

Expand Down
16 changes: 9 additions & 7 deletions src/avx512-32bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
}

template <>
void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize)
void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
{
if (arrsize > 1) {
qselect_32bit_<zmm_vector<int32_t>, int32_t>(
Expand All @@ -724,7 +724,7 @@ void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize)
}

template <>
void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize)
void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
{
if (arrsize > 1) {
qselect_32bit_<zmm_vector<uint32_t>, uint32_t>(
Expand All @@ -733,13 +733,15 @@ void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize)
}

// float specialization of avx512_qselect.
//   arr     - array to operate on (modified in place)
//   k       - index forwarded to qselect_32bit_
//   arrsize - number of elements in arr
//   hasnan  - when true, NaNs are first moved to the tail of arr with
//             move_nans_to_end_of_array and only the leading non-NaN range
//             is handed to qselect_32bit_
template <>
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
{
    int64_t indx_last_elem = arrsize - 1;
    if (UNLIKELY(hasnan)) {
        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
    }
    // indx_last_elem > 0: a range of zero or one element needs no selection,
    // and skipping it keeps log2() away from 0 (log2(0) is -inf, whose
    // conversion to int64_t is undefined).
    if (indx_last_elem >= k && indx_last_elem > 0) {
        qselect_32bit_<zmm_vector<float>, float>(
                arr, k, 0, indx_last_elem,
                2 * static_cast<int64_t>(log2(indx_last_elem)));
    }
}

Expand Down
16 changes: 9 additions & 7 deletions src/avx512-64bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ static void qselect_64bit_(type_t *arr,
}

template <>
void avx512_qselect<int64_t>(int64_t *arr, int64_t k, int64_t arrsize)
void avx512_qselect<int64_t>(int64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
{
if (arrsize > 1) {
qselect_64bit_<zmm_vector<int64_t>, int64_t>(
Expand All @@ -793,7 +793,7 @@ void avx512_qselect<int64_t>(int64_t *arr, int64_t k, int64_t arrsize)
}

template <>
void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize)
void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
{
if (arrsize > 1) {
qselect_64bit_<zmm_vector<uint64_t>, uint64_t>(
Expand All @@ -802,13 +802,15 @@ void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize)
}

// double specialization of avx512_qselect.
//   arr     - array to operate on (modified in place)
//   k       - index forwarded to qselect_64bit_
//   arrsize - number of elements in arr
//   hasnan  - when true, NaNs are first moved to the tail of arr with
//             move_nans_to_end_of_array and only the leading non-NaN range
//             is handed to qselect_64bit_
template <>
void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize, bool hasnan)
{
    int64_t indx_last_elem = arrsize - 1;
    if (UNLIKELY(hasnan)) {
        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
    }
    // indx_last_elem > 0: a range of zero or one element needs no selection,
    // and skipping it keeps log2() away from 0 (log2(0) is -inf, whose
    // conversion to int64_t is undefined).
    if (indx_last_elem >= k && indx_last_elem > 0) {
        qselect_64bit_<zmm_vector<double>, double>(
                arr, k, 0, indx_last_elem,
                2 * static_cast<int64_t>(log2(indx_last_elem)));
    }
}

Expand Down
44 changes: 38 additions & 6 deletions src/avx512-common-qsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@
#define X86_SIMD_SORT_FINLINE static
#endif

// Branch-prediction hints for conditions.
// `!!(x)` collapses the operand to 0 or 1 so that it matches the expected
// value passed to __builtin_expect and so that pointer-like operands compile.
// __builtin_expect is a GCC/Clang builtin; other compilers fall back to the
// plain truth value of the condition.
#if defined(__GNUC__) || defined(__clang__)
#define LIKELY(x) __builtin_expect(!!(x), 1)
#define UNLIKELY(x) __builtin_expect(!!(x), 0)
#else
#define LIKELY(x) (!!(x))
#define UNLIKELY(x) (!!(x))
#endif

template <typename type>
struct zmm_vector;

Expand All @@ -97,25 +100,54 @@ void avx512_qsort(T *arr, int64_t arrsize);
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);

// qselect entry points.
//   arr     - array to operate on
//   k       - index passed through to the type-specific implementation
//   arrsize - number of elements in arr
//   hasnan  - defaults to false, so existing three-argument calls keep
//             compiling; the floating-point specializations use it to decide
//             whether NaNs are moved to the end of arr before selecting
template <typename T>
void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);

template <typename T>
inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize)
inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
{
avx512_qselect<T>(arr, k - 1, arrsize);
avx512_qselect<T>(arr, k - 1, arrsize, hasnan);
avx512_qsort<T>(arr, k - 1);
}
inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
{
avx512_qselect_fp16(arr, k - 1, arrsize);
avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
avx512_qsort_fp16(arr, k - 1);
}

// key-value sort routines
template <typename T>
void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);

// Generic NaN predicate: defers to std::isnan for the element type.
// Types that need a different test provide their own specialization.
template <typename T>
bool is_a_nan(T elem)
{
    const bool nan_detected = std::isnan(elem);
    return nan_detected;
}

/*
 * Pushes every element for which is_a_nan() is true to the tail of arr and
 * returns the index of the last element that is not a NaN (-1 when there is
 * no such element). Elements are exchanged with std::swap, so the relative
 * order of the array is not preserved.
 */
template <typename T>
int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize)
{
    int64_t front = 0;
    int64_t back = arrsize - 1;
    int64_t nans_found = 0;
    while (front <= back) {
        if (!is_a_nan(arr[front])) {
            ++front;
            continue;
        }
        // Park the NaN in the current tail slot. The value swapped into
        // `front` has not been inspected yet, so `front` does not advance.
        std::swap(arr[front], arr[back]);
        --back;
        ++nans_found;
    }
    return arrsize - 1 - nans_found;
}

template <typename vtype, typename T = typename vtype::type_t>
bool comparison_func(const T &a, const T &b)
{
Expand Down
20 changes: 15 additions & 5 deletions src/avx512fp16-16bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,23 @@ replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
}

template <>
void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize)
bool is_a_nan<_Float16>(_Float16 elem)
{
if (arrsize > 1) {
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
Fp16Bits temp;
temp.f_ = elem;
return (temp.i_ & 0x7c00) == 0x7c00;
}

template <>
void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
{
int64_t indx_last_elem = arrsize - 1;
if (UNLIKELY(hasnan)) {
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
}
if (indx_last_elem >= k) {
qselect_16bit_<zmm_vector<_Float16>, _Float16>(
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
}
}

Expand Down