Skip to content

Commit

Permalink
BUG: fix dispatch for avx512_qsort and qselect float16 functions
Browse files Browse the repository at this point in the history
  • Loading branch information
r-devulap committed Dec 13, 2023
1 parent 3889327 commit adb9a59
Showing 1 changed file with 3 additions and 44 deletions.
47 changes: 3 additions & 44 deletions numpy/_core/src/npysort/x86_simd_qsort_16bit.dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,9 @@

#if defined(NPY_HAVE_AVX512_SPR)
#include "x86-simd-sort/src/avx512fp16-16bit-qsort.hpp"
/*
* Wrapper function declarations to avoid multiple definitions of
* avx512_qsort<uint16_t> and avx512_qsort<int16_t>
*/
void avx512_qsort_uint16(uint16_t*, npy_intp);
void avx512_qsort_int16(int16_t*, npy_intp);
void avx512_qselect_uint16(uint16_t*, npy_intp, npy_intp);
void avx512_qselect_int16(int16_t*, npy_intp, npy_intp);

#include "x86-simd-sort/src/avx512-16bit-qsort.hpp"
#elif defined(NPY_HAVE_AVX512_ICL)
#include "x86-simd-sort/src/avx512-16bit-qsort.hpp"
/* Wrapper function definitions here: */
/*
 * Plain (non-template) entry point for sorting a uint16_t buffer.
 * Forwards `arr` and `size` unchanged to avx512_qsort so that other
 * translation units can call this symbol instead of instantiating
 * avx512_qsort<uint16_t> themselves.
 */
void avx512_qsort_uint16(uint16_t* arr, npy_intp size)
{
    avx512_qsort(arr, size);
}
/*
 * Plain (non-template) entry point for sorting an int16_t buffer.
 * Forwards `arr` and `size` unchanged to avx512_qsort so that other
 * translation units can call this symbol instead of instantiating
 * avx512_qsort<int16_t> themselves.
 */
void avx512_qsort_int16(int16_t* arr, npy_intp size)
{
    avx512_qsort(arr, size);
}
/*
 * Plain (non-template) entry point for selection on a uint16_t buffer.
 * Passes (arr, kth, size) through to avx512_qselect in that order.
 *
 * NOTE(review): the trailing `true` is presumably avx512_qselect's
 * NaN-handling flag -- confirm against the x86-simd-sort header.
 */
void avx512_qselect_uint16(uint16_t* arr, npy_intp kth, npy_intp size)
{
    avx512_qselect(arr, kth, size, true);
}
/*
 * Plain (non-template) entry point for selection on an int16_t buffer.
 * Passes (arr, kth, size) through to avx512_qselect in that order.
 *
 * NOTE(review): the trailing `true` is presumably avx512_qselect's
 * NaN-handling flag -- confirm against the x86-simd-sort header.
 */
void avx512_qselect_int16(int16_t* arr, npy_intp kth, npy_intp size)
{
    avx512_qselect(arr, kth, size, true);
}
#endif

namespace np { namespace qsort_simd {
Expand All @@ -50,20 +25,12 @@ template<> void NPY_CPU_DISPATCH_CURFX(QSelect)(Half *arr, npy_intp num, npy_int

/*
 * QSelect specialization for uint16_t.
 *
 * Argument order matters: this function receives (arr, num, kth), but the
 * selection routines it dispatches to take kth *before* the element count
 * (see avx512_qselect_uint16(arr, kth, size)). `num` is therefore forwarded
 * as the size argument and `kth` as the selection index.
 *
 * Exactly one selection call is made per invocation. An additional
 * unconditional avx512_qselect(arr, num, kth) after the dispatch would both
 * repeat the work and swap the index/size arguments.
 */
template<> void NPY_CPU_DISPATCH_CURFX(QSelect)(uint16_t *arr, npy_intp num, npy_intp kth)
{
#if defined(NPY_HAVE_AVX512_SPR)
    /* SPR build: go through the non-template wrapper declared above. */
    avx512_qselect_uint16(arr, kth, num);
#else
    avx512_qselect(arr, kth, num);
#endif
}

/*
 * QSelect specialization for int16_t.
 *
 * Argument order matters: this function receives (arr, num, kth), but the
 * selection routines it dispatches to take kth *before* the element count
 * (see avx512_qselect_int16(arr, kth, size)). `num` is therefore forwarded
 * as the size argument and `kth` as the selection index.
 *
 * Exactly one selection call is made per invocation. An additional
 * unconditional avx512_qselect(arr, num, kth) after the dispatch would both
 * repeat the work and swap the index/size arguments.
 */
template<> void NPY_CPU_DISPATCH_CURFX(QSelect)(int16_t *arr, npy_intp num, npy_intp kth)
{
#if defined(NPY_HAVE_AVX512_SPR)
    /* SPR build: go through the non-template wrapper declared above. */
    avx512_qselect_int16(arr, kth, num);
#else
    avx512_qselect(arr, kth, num);
#endif
}

/*
Expand All @@ -79,19 +46,11 @@ template<> void NPY_CPU_DISPATCH_CURFX(QSort)(Half *arr, npy_intp size)
}
/*
 * QSort specialization for uint16_t.
 *
 * When NPY_HAVE_AVX512_SPR is defined the sort is routed through the
 * non-template wrapper avx512_qsort_uint16; otherwise avx512_qsort is
 * called directly. Both paths receive (arr, size) unchanged.
 */
template<> void NPY_CPU_DISPATCH_CURFX(QSort)(uint16_t *arr, npy_intp size)
{
#ifdef NPY_HAVE_AVX512_SPR
    avx512_qsort_uint16(arr, size);
#else
    avx512_qsort(arr, size);
#endif
}
/*
 * QSort specialization for int16_t.
 *
 * When NPY_HAVE_AVX512_SPR is defined the sort is routed through the
 * non-template wrapper avx512_qsort_int16; otherwise avx512_qsort is
 * called directly. Both paths receive (arr, size) unchanged.
 */
template<> void NPY_CPU_DISPATCH_CURFX(QSort)(int16_t *arr, npy_intp size)
{
#ifdef NPY_HAVE_AVX512_SPR
    avx512_qsort_int16(arr, size);
#else
    avx512_qsort(arr, size);
#endif
}
#endif // NPY_HAVE_AVX512_ICL || SPR

Expand Down

0 comments on commit adb9a59

Please sign in to comment.