Skip to content

Commit

Permalink
Add HIP specialization for sort-by-key
Browse files Browse the repository at this point in the history
  • Loading branch information
Rombur committed Mar 11, 2024
1 parent 35ad698 commit e5126e9
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions algorithms/src/sorting/impl/Kokkos_SortByKeyImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@

#endif

#if defined(KOKKOS_ENABLE_ROCTHRUST)
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#endif

#if defined(KOKKOS_ENABLE_ONEDPL) && \
(ONEDPL_VERSION_MAJOR > 2022 || \
(ONEDPL_VERSION_MAJOR == 2022 && ONEDPL_VERSION_MINOR >= 2))
Expand Down Expand Up @@ -117,6 +122,26 @@ void sort_by_key_cudathrust(
}
#endif

#if defined(KOKKOS_ENABLE_ROCTHRUST)
template <class Layout>
inline constexpr bool sort_on_device_v<Kokkos::HIP, Layout> = true;

template <class KeysDataType, class... KeysProperties, class ValuesDataType,
class... ValuesProperties, class... MaybeComparator>
void sort_by_key_rocthrust(
const Kokkos::HIP& exec,
const Kokkos::View<KeysDataType, KeysProperties...>& keys,
const Kokkos::View<ValuesDataType, ValuesProperties...>& values,
MaybeComparator&&... maybeComparator) {
const auto policy = thrust::hip::par.on(exec.hip_stream());
auto keys_first = ::Kokkos::Experimental::begin(keys);
auto keys_last = ::Kokkos::Experimental::end(keys);
auto values_first = ::Kokkos::Experimental::begin(values);
thrust::sort_by_key(policy, keys_first, keys_last, values_first,
std::forward<MaybeComparator>(maybeComparator)...);
}
#endif

#if defined(KOKKOS_ENABLE_ONEDPL)
template <class Layout>
inline constexpr bool sort_on_device_v<Kokkos::Experimental::SYCL, Layout> =
Expand Down Expand Up @@ -272,6 +297,17 @@ void sort_by_key_device_view_without_comparator(
}
#endif

#if defined(KOKKOS_ENABLE_ROCTHRUST)
template <class KeysDataType, class... KeysProperties, class ValuesDataType,
class... ValuesProperties>
void sort_by_key_device_view_without_comparator(
const Kokkos::HIP& exec,
const Kokkos::View<KeysDataType, KeysProperties...>& keys,
const Kokkos::View<ValuesDataType, ValuesProperties...>& values) {
sort_by_key_rocthrust(exec, keys, values);
}
#endif

#if defined(KOKKOS_ENABLE_ONEDPL)
template <class KeysDataType, class... KeysProperties, class ValuesDataType,
class... ValuesProperties>
Expand Down Expand Up @@ -317,6 +353,18 @@ void sort_by_key_device_view_with_comparator(
}
#endif

#if defined(KOKKOS_ENABLE_ROCTHRUST)
template <class ComparatorType, class KeysDataType, class... KeysProperties,
class ValuesDataType, class... ValuesProperties>
void sort_by_key_device_view_with_comparator(
const Kokkos::HIP& exec,
const Kokkos::View<KeysDataType, KeysProperties...>& keys,
const Kokkos::View<ValuesDataType, ValuesProperties...>& values,
const ComparatorType& comparator) {
sort_by_key_rocthrust(exec, keys, values, comparator);
}
#endif

#if defined(KOKKOS_ENABLE_ONEDPL)
template <class ComparatorType, class KeysDataType, class... KeysProperties,
class ValuesDataType, class... ValuesProperties>
Expand Down

0 comments on commit e5126e9

Please sign in to comment.