Skip to content

Commit

Permalink
Fix fallback implementation for sort_by_key (kokkos#6856)
Browse files Browse the repository at this point in the history
* Fix fallback implementation for sort_by_key

* Guard with KOKKOS_ENABLE_ONEDPL

* Drop sort_on_device

* Improve wording

* Improve comment
  • Loading branch information
masterleinad committed Mar 5, 2024
1 parent 99c7e1b commit 9feb104
Showing 1 changed file with 80 additions and 39 deletions.
119 changes: 80 additions & 39 deletions algorithms/src/sorting/impl/Kokkos_SortByKeyImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,20 @@ static_assert_is_admissible_to_kokkos_sort_by_key(const ViewType& /* view */) {
"LayoutRight, LayoutLeft or LayoutStride.");
}

// For the fallback implementation for sort_by_key using Kokkos::sort, we need
// to consider if Kokkos::sort defers to the fallback implementation that copies
// the array to the host and uses std::sort, see
// copy_to_host_run_stdsort_copy_back() in impl/Kokkos_SortImpl.hpp. If
// sort_on_device_v is true, we assume that Kokkos::sort sorts on the device
// and doesn't copy data to the host. Otherwise, we manually copy all data to
// the host and provide Kokkos::sort with a host execution space.
//
// Primary template: defaults to false (i.e. the manual host-copy path) for
// every execution space/layout pair; backends opt in to the device path
// through specializations.
template <class ExecutionSpace, class Layout>
inline constexpr bool sort_on_device_v = false;

#if defined(KOKKOS_ENABLE_CUDA)
// Specialization for Kokkos::Cuda: true for every Layout, so the sort_by_key
// fallback calls Kokkos::sort directly on the device execution space instead
// of copying to the host first.
template <class Layout>
inline constexpr bool sort_on_device_v<Kokkos::Cuda, Layout> = true;

template <class KeysDataType, class... KeysProperties, class ValuesDataType,
class... ValuesProperties, class... MaybeComparator>
void sort_by_key_cudathrust(
Expand All @@ -104,6 +117,12 @@ void sort_by_key_cudathrust(
}
#endif

#if defined(KOKKOS_ENABLE_ONEDPL)
// Specialization for Kokkos::Experimental::SYCL (only available when oneDPL is
// enabled): true only for LayoutLeft and LayoutRight. Any other layout keeps
// the value false and therefore takes the manual host-copy path.
template <class Layout>
inline constexpr bool sort_on_device_v<Kokkos::Experimental::SYCL, Layout> =
std::is_same_v<Layout, Kokkos::LayoutLeft> ||
std::is_same_v<Layout, Kokkos::LayoutRight>;

#ifdef KOKKOS_ONEDPL_HAS_SORT_BY_KEY
template <class KeysDataType, class... KeysProperties, class ValuesDataType,
class... ValuesProperties, class... MaybeComparator>
Expand All @@ -126,6 +145,7 @@ void sort_by_key_onedpl(
std::forward<MaybeComparator>(maybeComparator)...);
}
#endif
#endif

template <typename ExecutionSpace, typename PermutationView, typename ViewType>
void applyPermutation(const ExecutionSpace& space,
Expand All @@ -152,6 +172,8 @@ void sort_by_key_via_sort(
const Kokkos::View<KeysDataType, KeysProperties...>& keys,
const Kokkos::View<ValuesDataType, ValuesProperties...>& values,
MaybeComparator&&... maybeComparator) {
static_assert(sizeof...(MaybeComparator) <= 1);

auto const n = keys.size();

Kokkos::View<unsigned int*, ExecutionSpace> permute(
Expand All @@ -165,48 +187,67 @@ void sort_by_key_via_sort(
Kokkos::RangePolicy<ExecutionSpace>(exec, 0, n),
KOKKOS_LAMBDA(int i) { permute(i) = i; });

// FIXME OPENMPTARGET The sort happens on the host so we have to copy keys there
#ifdef KOKKOS_ENABLE_OPENMPTARGET
auto keys_in_comparator = Kokkos::create_mirror_view(
Kokkos::view_alloc(Kokkos::HostSpace{}, Kokkos::WithoutInitializing),
keys);
Kokkos::deep_copy(exec, keys_in_comparator, keys);
#else
auto keys_in_comparator = keys;
#endif

static_assert(sizeof...(MaybeComparator) <= 1);
if constexpr (sizeof...(MaybeComparator) == 0) {
#ifdef KOKKOS_ENABLE_SYCL
auto* raw_keys_in_comparator = keys_in_comparator.data();
auto stride = keys_in_comparator.stride(0);
Kokkos::sort(
exec, permute, KOKKOS_LAMBDA(int i, int j) {
return raw_keys_in_comparator[i * stride] <
raw_keys_in_comparator[j * stride];
});
#else
Kokkos::sort(
exec, permute, KOKKOS_LAMBDA(int i, int j) {
return keys_in_comparator(i) < keys_in_comparator(j);
});
#endif
using Layout =
typename Kokkos::View<unsigned int*, ExecutionSpace>::array_layout;
if constexpr (!sort_on_device_v<ExecutionSpace, Layout>) {
auto host_keys = Kokkos::create_mirror_view(
Kokkos::view_alloc(Kokkos::HostSpace{}, Kokkos::WithoutInitializing),
keys);
auto host_permute = Kokkos::create_mirror_view(
Kokkos::view_alloc(Kokkos::HostSpace{}, Kokkos::WithoutInitializing),
permute);
Kokkos::deep_copy(exec, host_keys, keys);
Kokkos::deep_copy(exec, host_permute, permute);

exec.fence("Kokkos::Impl::sort_by_key_via_sort: before host sort");
Kokkos::DefaultHostExecutionSpace host_exec;

if constexpr (sizeof...(MaybeComparator) == 0) {
Kokkos::sort(
host_exec, host_permute,
KOKKOS_LAMBDA(int i, int j) { return host_keys(i) < host_keys(j); });
} else {
auto keys_comparator =
std::get<0>(std::tuple<MaybeComparator...>(maybeComparator...));
Kokkos::sort(
host_exec, host_permute, KOKKOS_LAMBDA(int i, int j) {
return keys_comparator(host_keys(i), host_keys(j));
});
}
host_exec.fence("Kokkos::Impl::sort_by_key_via_sort: after host sort");
Kokkos::deep_copy(exec, permute, host_permute);
} else {
auto keys_comparator =
std::get<0>(std::tuple<MaybeComparator...>(maybeComparator...));
#ifdef KOKKOS_ENABLE_SYCL
auto* raw_keys_in_comparator = keys_in_comparator.data();
auto stride = keys_in_comparator.stride(0);
Kokkos::sort(
exec, permute, KOKKOS_LAMBDA(int i, int j) {
return keys_comparator(raw_keys_in_comparator[i * stride],
raw_keys_in_comparator[j * stride]);
});
auto* raw_keys_in_comparator = keys.data();
auto stride = keys.stride(0);
if constexpr (sizeof...(MaybeComparator) == 0) {
Kokkos::sort(
exec, permute, KOKKOS_LAMBDA(int i, int j) {
return raw_keys_in_comparator[i * stride] <
raw_keys_in_comparator[j * stride];
});
} else {
auto keys_comparator =
std::get<0>(std::tuple<MaybeComparator...>(maybeComparator...));
Kokkos::sort(
exec, permute, KOKKOS_LAMBDA(int i, int j) {
return keys_comparator(raw_keys_in_comparator[i * stride],
raw_keys_in_comparator[j * stride]);
});
}
#else
Kokkos::sort(
exec, permute, KOKKOS_LAMBDA(int i, int j) {
return keys_comparator(keys_in_comparator(i), keys_in_comparator(j));
});
if constexpr (sizeof...(MaybeComparator) == 0) {
Kokkos::sort(
exec, permute,
KOKKOS_LAMBDA(int i, int j) { return keys(i) < keys(j); });
} else {
auto keys_comparator =
std::get<0>(std::tuple<MaybeComparator...>(maybeComparator...));
Kokkos::sort(
exec, permute, KOKKOS_LAMBDA(int i, int j) {
return keys_comparator(keys(i), keys(j));
});
}
#endif
}

Expand Down

0 comments on commit 9feb104

Please sign in to comment.