Skip to content

Commit

Permalink
change impl of is_sorted_until to use reduce (kokkos#6097)
Browse files Browse the repository at this point in the history
* change impl of is_sorted_until to use reduce

* address comments
  • Loading branch information
fnrizzi committed May 4, 2023
1 parent d6944df commit 26ae798
Showing 1 changed file with 36 additions and 48 deletions.
84 changes: 36 additions & 48 deletions algorithms/src/std_algorithms/impl/Kokkos_IsSortedUntil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,30 @@ namespace Kokkos {
namespace Experimental {
namespace Impl {

template <class IteratorType, class IndicatorViewType, class ComparatorType>
template <class IteratorType, class ComparatorType, class ReducerType>
struct StdIsSortedUntilFunctor {
using index_type = typename IteratorType::difference_type;
using value_type = typename ReducerType::value_type;

IteratorType m_first;
IndicatorViewType m_indicator;
ComparatorType m_comparator;
ReducerType m_reducer;

KOKKOS_FUNCTION
void operator()(const index_type i, int& update, const bool final) const {
void operator()(const index_type i, value_type& reduction_result) const {
const auto& val_i = m_first[i];
const auto& val_ip1 = m_first[i + 1];

if (m_comparator(val_ip1, val_i)) {
++update;
}

if (final) {
m_indicator(i) = update;
m_reducer.join(reduction_result, i);
}
}

KOKKOS_FUNCTION
StdIsSortedUntilFunctor(IteratorType _first1, IndicatorViewType indicator,
ComparatorType comparator)
: m_first(std::move(_first1)),
m_indicator(std::move(indicator)),
m_comparator(std::move(comparator)) {}
StdIsSortedUntilFunctor(IteratorType first, ComparatorType comparator,
ReducerType reducer)
: m_first(std::move(first)),
m_comparator(std::move(comparator)),
m_reducer(std::move(reducer)) {}
};

template <class ExecutionSpace, class IteratorType, class ComparatorType>
Expand All @@ -73,40 +70,31 @@ IteratorType is_sorted_until_impl(const std::string& label,
}

/*
use scan and a helper "indicator" view
such that we scan the data and fill the indicator with
partial sum that is always 0 unless we find a pair that
breaks the sorting, so in that case the indicator will
have a 1 starting at the location where the sorting breaks.
So finding that 1 means finding the location we want.
*/

// aliases
using indicator_value_type = std::size_t;
using indicator_view_type =
::Kokkos::View<indicator_value_type*, ExecutionSpace>;
using functor_type =
StdIsSortedUntilFunctor<IteratorType, indicator_view_type,
ComparatorType>;

// do scan
// use num_elements-1 because each index handles i and i+1
const auto num_elements_minus_one = num_elements - 1;
indicator_view_type indicator("is_sorted_until_indicator_helper",
num_elements_minus_one);
::Kokkos::parallel_scan(
label, RangePolicy<ExecutionSpace>(ex, 0, num_elements_minus_one),
functor_type(first, indicator, std::move(comp)));

// try to find the first sentinel value, which indicates
// where the sorting condition breaks
namespace KE = ::Kokkos::Experimental;
constexpr indicator_value_type sentinel_value = 1;
auto r =
KE::find(ex, KE::cbegin(indicator), KE::cend(indicator), sentinel_value);
const auto shift = r - ::Kokkos::Experimental::cbegin(indicator);

return first + (shift + 1);
Do a par_reduce computing the *min* index that breaks the sorting.
If such an index is found, then the range is sorted until that element.
If no such index is found, then the range is sorted until the end.
*/
using index_type = typename IteratorType::difference_type;
index_type reduction_result;
::Kokkos::Min<index_type> reducer(reduction_result);
::Kokkos::parallel_reduce(
label,
// use num_elements-1 because each index handles i and i+1
RangePolicy<ExecutionSpace>(ex, 0, num_elements - 1),
// use CTAD
StdIsSortedUntilFunctor(first, comp, reducer), reducer);

/* If the reduction result is equal to the initial value,
it means the range is sorted until the end */
index_type reduction_result_init;
reducer.init(reduction_result_init);
if (reduction_result == reduction_result_init) {
return last;
} else {
/* If such an index is found, then the range is sorted until there and
we need to return an iterator past the element found so do +1 */
return first + (reduction_result + 1);
}
}

template <class ExecutionSpace, class IteratorType>
Expand Down

0 comments on commit 26ae798

Please sign in to comment.