Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ int32_t, double, uint64_t, int64_t]`

## Key-value sort routines on pairs of arrays
```cpp
void x86simdsort::keyvalue_qsort(T1* key, T2* val, size_t size, bool hasnan);
void x86simdsort::keyvalue_qsort(T1* key, T2* val, size_t size, bool hasnan, bool descending);
void x86simdsort::keyvalue_select(T1* key, T2* val, size_t k, size_t size, bool hasnan, bool descending);
void x86simdsort::keyvalue_partial_sort(T1* key, T2* val, size_t k, size_t size, bool hasnan, bool descending);
```
Supported datatypes: `T1`, `T2` $\in$ `[float, uint32_t, int32_t, double,
uint64_t, int64_t]` Note that keyvalue sort is not yet supported for 16-bit
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/bench-keyvalue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ static void scalarkvsort(benchmark::State &state, Args &&...args)
std::vector<T> key_bkp = key;
// benchmark
for (auto _ : state) {
xss::scalar::keyvalue_qsort(key.data(), val.data(), arrsize, false);
xss::scalar::keyvalue_qsort(
key.data(), val.data(), arrsize, false, false);
state.PauseTiming();
key = key_bkp;
state.ResumeTiming();
Expand Down
54 changes: 32 additions & 22 deletions lib/x86simdsort-avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,38 +34,48 @@
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
}

#define DEFINE_KEYVALUE_METHODS(type) \
template <> \
void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
} \
#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \
template <> \
void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \
void keyvalue_qsort(type1 *key, \
type2 *val, \
size_t arrsize, \
bool hasnan, \
bool descending) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
x86simdsortStatic::keyvalue_qsort( \
key, val, arrsize, hasnan, descending); \
} \
template <> \
void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \
void keyvalue_select(type1 *key, \
type2 *val, \
size_t k, \
size_t arrsize, \
bool hasnan, \
bool descending) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
x86simdsortStatic::keyvalue_select( \
key, val, k, arrsize, hasnan, descending); \
} \
template <> \
void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \
void keyvalue_partial_sort(type1 *key, \
type2 *val, \
size_t k, \
size_t arrsize, \
bool hasnan, \
bool descending) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
x86simdsortStatic::keyvalue_partial_sort( \
key, val, k, arrsize, hasnan, descending); \
}

#define DEFINE_KEYVALUE_METHODS(type) \
DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \
DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \
DEFINE_KEYVALUE_METHODS_BASE(type, double) \
DEFINE_KEYVALUE_METHODS_BASE(type, uint32_t) \
DEFINE_KEYVALUE_METHODS_BASE(type, int32_t) \
DEFINE_KEYVALUE_METHODS_BASE(type, float)

namespace xss {
namespace avx2 {
DEFINE_ALL_METHODS(uint32_t)
Expand Down
69 changes: 63 additions & 6 deletions lib/x86simdsort-internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,41 @@ namespace avx512 {
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
// key-value quicksort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
T2 *val,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// quickselect
template <typename T>
XSS_HIDE_SYMBOL void qselect(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value select
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// partial sort
template <typename T>
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value partial sort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argsort
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
Expand All @@ -46,22 +65,41 @@ namespace avx2 {
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
// key-value quicksort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
T2 *val,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// quickselect
template <typename T>
XSS_HIDE_SYMBOL void qselect(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value select
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// partial sort
template <typename T>
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value partial sort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argsort
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
Expand All @@ -80,22 +118,41 @@ namespace scalar {
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
// key-value quicksort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
T2 *val,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// quickselect
template <typename T>
XSS_HIDE_SYMBOL void qselect(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value select
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// partial sort
template <typename T>
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// key-value partial sort
template <typename T1, typename T2>
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argsort
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
Expand Down
29 changes: 27 additions & 2 deletions lib/x86simdsort-scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,37 @@ namespace scalar {
return arg;
}
template <typename T1, typename T2>
void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan)
void keyvalue_qsort(
T1 *key, T2 *val, size_t arrsize, bool hasnan, bool descending)
{
std::vector<size_t> arg = argsort(key, arrsize, hasnan, false);
std::vector<size_t> arg = argsort(key, arrsize, hasnan, descending);
utils::apply_permutation_in_place(key, arg);
utils::apply_permutation_in_place(val, arg);
}
template <typename T1, typename T2>
void keyvalue_select(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan,
bool descending)
{
// Note that this does a full kv-sort
UNUSED(k);
keyvalue_qsort(key, val, arrsize, hasnan, descending);
}
template <typename T1, typename T2>
void keyvalue_partial_sort(T1 *key,
T2 *val,
size_t k,
size_t arrsize,
bool hasnan,
bool descending)
{
// Note that this does a full kv-sort
UNUSED(k);
keyvalue_qsort(key, val, arrsize, hasnan, descending);
}

} // namespace scalar
} // namespace xss
54 changes: 32 additions & 22 deletions lib/x86simdsort-skx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,38 +34,48 @@
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
}

#define DEFINE_KEYVALUE_METHODS(type) \
template <> \
void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
} \
#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \
template <> \
void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \
void keyvalue_qsort(type1 *key, \
type2 *val, \
size_t arrsize, \
bool hasnan, \
bool descending) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
x86simdsortStatic::keyvalue_qsort( \
key, val, arrsize, hasnan, descending); \
} \
template <> \
void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \
void keyvalue_select(type1 *key, \
type2 *val, \
size_t k, \
size_t arrsize, \
bool hasnan, \
bool descending) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
x86simdsortStatic::keyvalue_select( \
key, val, k, arrsize, hasnan, descending); \
} \
template <> \
void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \
void keyvalue_partial_sort(type1 *key, \
type2 *val, \
size_t k, \
size_t arrsize, \
bool hasnan, \
bool descending) \
{ \
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
x86simdsortStatic::keyvalue_partial_sort( \
key, val, k, arrsize, hasnan, descending); \
}

#define DEFINE_KEYVALUE_METHODS(type) \
DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \
DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \
DEFINE_KEYVALUE_METHODS_BASE(type, double) \
DEFINE_KEYVALUE_METHODS_BASE(type, uint32_t) \
DEFINE_KEYVALUE_METHODS_BASE(type, int32_t) \
DEFINE_KEYVALUE_METHODS_BASE(type, float)

namespace xss {
namespace avx512 {
DEFINE_ALL_METHODS(uint32_t)
Expand Down
Loading