Skip to content

Commit

Permalink
make constraints on Kokkos::sort more visible/clear (kokkos#6234)
Browse files Browse the repository at this point in the history
  • Loading branch information
fnrizzi committed Jun 23, 2023
1 parent 9e6befc commit 5b6cb80
Showing 1 changed file with 64 additions and 21 deletions.
85 changes: 64 additions & 21 deletions algorithms/src/sorting/Kokkos_SortPublicAPI.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,32 @@

namespace Kokkos {

// ---------------------------------------------------------------
// basic overloads
// ---------------------------------------------------------------

// clang-format off
template <class ExecutionSpace, class DataType, class... Properties>
std::enable_if_t<(Kokkos::is_execution_space<ExecutionSpace>::value) &&
(!SpaceAccessibility<
HostSpace, typename Kokkos::View<DataType, Properties...>::
memory_space>::accessible)>
std::enable_if_t<
(Kokkos::is_execution_space<ExecutionSpace>::value) &&
(!SpaceAccessibility<
HostSpace, typename Kokkos::View<DataType, Properties...>::memory_space
>::accessible)
>
// clang-format on
sort(const ExecutionSpace& exec,
const Kokkos::View<DataType, Properties...>& view) {
// Although we are using BinSort below, which could work on rank-2 views,
// for now view must be rank-1 because the Impl::min_max_functor
// used below only works for rank-1 views
using ViewType = Kokkos::View<DataType, Properties...>;
static_assert(ViewType::rank == 1,
"Kokkos::sort: currently only supports rank-1 Views.");

if (view.extent(0) == 0) {
return;
}

using ViewType = Kokkos::View<DataType, Properties...>;
using CompType = BinOp1D<ViewType>;

Kokkos::MinMaxScalar<typename ViewType::non_const_value_type> result;
Kokkos::MinMax<typename ViewType::non_const_value_type> reducer(result);
parallel_reduce("Kokkos::Sort::FindExtent",
Expand Down Expand Up @@ -114,6 +126,7 @@ sort(const ExecutionSpace& exec,
static_cast<double>(result.min_val)));
}

using CompType = BinOp1D<ViewType>;
BinSort<ViewType, CompType> bin_sort(
view, CompType(max_bins, result.min_val, result.max_val), sort_in_bins);
bin_sort.create_permute_vector(exec);
Expand All @@ -124,37 +137,45 @@ sort(const ExecutionSpace& exec,
template <class DataType, class... Properties>
void sort(const Experimental::SYCL& space,
const Kokkos::View<DataType, Properties...>& view) {
if (view.extent(0) == 0) {
return;
}

using ViewType = Kokkos::View<DataType, Properties...>;
static_assert(SpaceAccessibility<Experimental::SYCL,
typename ViewType::memory_space>::accessible,
"SYCL execution space is not able to access the memory space "
"of the View argument!");

auto queue = space.sycl_queue();
auto policy = oneapi::dpl::execution::make_device_policy(queue);

// Can't use Experimental::begin/end here since the oneDPL then assumes that
// the data is on the host.
static_assert(
ViewType::rank == 1 &&
(std::is_same<typename ViewType::array_layout, LayoutRight>::value ||
std::is_same<typename ViewType::array_layout, LayoutLeft>::value),
"SYCL sort only supports contiguous 1D Views.");
"SYCL sort only supports contiguous rank-1 Views.");

if (view.extent(0) == 0) {
return;
}

auto queue = space.sycl_queue();
auto policy = oneapi::dpl::execution::make_device_policy(queue);
const int n = view.extent(0);
oneapi::dpl::sort(policy, view.data(), view.data() + n);
}
#endif

// clang-format off
template <class ExecutionSpace, class DataType, class... Properties>
std::enable_if_t<(Kokkos::is_execution_space<ExecutionSpace>::value) &&
(SpaceAccessibility<
HostSpace, typename Kokkos::View<DataType, Properties...>::
memory_space>::accessible)>
std::enable_if_t<
(Kokkos::is_execution_space<ExecutionSpace>::value) &&
(SpaceAccessibility<
HostSpace, typename Kokkos::View<DataType, Properties...>::memory_space
>::accessible)
>
// clang-format on
sort(const ExecutionSpace&, const Kokkos::View<DataType, Properties...>& view) {
using ViewType = Kokkos::View<DataType, Properties...>;
static_assert(ViewType::rank == 1,
"Kokkos::sort: currently only supports rank-1 Views.");

if (view.extent(0) == 0) {
return;
}
Expand All @@ -167,6 +188,10 @@ sort(const ExecutionSpace&, const Kokkos::View<DataType, Properties...>& view) {
template <class DataType, class... Properties>
void sort(const Cuda& space,
const Kokkos::View<DataType, Properties...>& view) {
using ViewType = Kokkos::View<DataType, Properties...>;
static_assert(ViewType::rank == 1,
"Kokkos::sort: currently only supports rank-1 Views.");

if (view.extent(0) == 0) {
return;
}
Expand All @@ -177,8 +202,12 @@ void sort(const Cuda& space,
}
#endif

template <class ViewType>
void sort(ViewType const& view) {
template <class DataType, class... Properties>
void sort(const Kokkos::View<DataType, Properties...>& view) {
using ViewType = Kokkos::View<DataType, Properties...>;
static_assert(ViewType::rank == 1,
"Kokkos::sort: currently only supports rank-1 Views.");

Kokkos::fence("Kokkos::sort: before");

if (view.extent(0) == 0) {
Expand All @@ -190,10 +219,20 @@ void sort(ViewType const& view) {
exec.fence("Kokkos::sort: fence after sorting");
}

// ---------------------------------------------------------------
// overloads for sorting a view with a subrange
// specified via integers begin, end
// ---------------------------------------------------------------

template <class ExecutionSpace, class ViewType>
std::enable_if_t<Kokkos::is_execution_space<ExecutionSpace>::value> sort(
const ExecutionSpace& exec, ViewType view, size_t const begin,
size_t const end) {
// view must be rank-1 because the Impl::min_max_functor
// used below only works for rank-1 views for now
static_assert(ViewType::rank == 1,
"Kokkos::sort: currently only supports rank-1 Views.");

if (view.extent(0) == 0) {
return;
}
Expand All @@ -219,6 +258,10 @@ std::enable_if_t<Kokkos::is_execution_space<ExecutionSpace>::value> sort(

template <class ViewType>
void sort(ViewType view, size_t const begin, size_t const end) {
// same constraints as the overload above which this gets dispatched to
static_assert(ViewType::rank == 1,
"Kokkos::sort: currently only supports rank-1 Views.");

Kokkos::fence("Kokkos::sort: before");

if (view.extent(0) == 0) {
Expand Down

0 comments on commit 5b6cb80

Please sign in to comment.