Skip to content

Commit

Permalink
kokkos#5635: SYCL: Add parallel_scan overload with return value
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable authored and cz4rs committed Sep 27, 2023
1 parent e4eb204 commit 7d817b8
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions core/src/SYCL/Kokkos_SYCL_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,28 +573,30 @@ parallel_reduce(const Impl::TeamThreadRangeBoundariesStruct<
* final == true.
*/
// This is the same code as in CUDA and largely the same as in OpenMPTarget
template <typename iType, typename FunctorType>
template <typename iType, typename FunctorType, typename ValueType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
const Impl::TeamThreadRangeBoundariesStruct<iType, Impl::SYCLTeamMember>&
loop_bounds,
const FunctorType& lambda) {
// Extract value_type from lambda
using value_type = typename Kokkos::Impl::FunctorAnalysis<
const FunctorType& lambda, ValueType& return_val) {
// Extract ValueType from the Closure
using closure_value_type = typename Kokkos::Impl::FunctorAnalysis<
Kokkos::Impl::FunctorPatternInterface::SCAN, void, FunctorType,
void>::value_type;
static_assert(std::is_same<closure_value_type, ValueType>::value,
"Non-matching value types of closure and return type");

const auto start = loop_bounds.start;
const auto end = loop_bounds.end;
auto& member = loop_bounds.member;
const auto team_size = member.team_size();
const auto team_rank = member.team_rank();
const auto nchunk = (end - start + team_size - 1) / team_size;
value_type accum = 0;
ValueType accum = 0;
// each team has to process one or more chunks of the prefix scan
for (iType i = 0; i < nchunk; ++i) {
auto ii = start + i * team_size + team_rank;
// local accumulation for this chunk
value_type local_accum = 0;
ValueType local_accum = 0;
// user updates value with prefix value
if (ii < loop_bounds.end) lambda(ii, local_accum, false);
// perform team scan
Expand All @@ -608,6 +610,21 @@ KOKKOS_INLINE_FUNCTION void parallel_scan(
// broadcast last value to rest of the team
member.team_broadcast(accum, team_size - 1);
}

return_val = accum;
}

// Team-level scan over a TeamThreadRange for callers that do not need the
// value produced by the scan.
//
// Deduces the value type of the functor's scan signature and delegates to the
// parallel_scan overload that takes an output reference. The value written to
// that reference is dropped when this function returns.
template <typename iType, class FunctorType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
    const Impl::TeamThreadRangeBoundariesStruct<iType, Impl::SYCLTeamMember>&
        loop_bounds,
    const FunctorType& lambda) {
  // Value type that FunctorAnalysis reports for the functor in SCAN mode.
  using functor_value_type = typename Kokkos::Impl::FunctorAnalysis<
      Kokkos::Impl::FunctorPatternInterface::SCAN, void, FunctorType,
      void>::value_type;

  // Output target for the delegated overload; it is never read here.
  functor_value_type unused_result;
  parallel_scan(loop_bounds, lambda, unused_result);
}

template <typename iType, class Closure>
Expand Down

0 comments on commit 7d817b8

Please sign in to comment.