Skip to content

Commit

Permalink
kokkos#5635: SYCL: Add parallel_scan overload with value for ThreadVectorRange
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable authored and cz4rs committed Oct 6, 2023
1 parent e52b957 commit c5bf870
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
29 changes: 28 additions & 1 deletion core/src/SYCL/Kokkos_SYCL_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ parallel_scan(const Impl::ThreadVectorRangeBoundariesStruct<
// This sets i's val to i-1's contribution to make the latter shfl_up an
// exclusive scan -- the final accumulation of i's val will be included in
// the second closure call later.
if (i < loop_boundaries.end && tidx1 > 0) closure(i - 1, val, false);
if (i - 1 < loop_boundaries.end && tidx1 > 0) closure(i - 1, val, false);

// Bottom up exclusive scan in triangular pattern where each SYCL thread is
// the root of a reduction tree from the zeroth "lane" to itself.
Expand All @@ -847,6 +847,7 @@ parallel_scan(const Impl::ThreadVectorRangeBoundariesStruct<
if (i < loop_boundaries.end) closure(i, val, true);
accum = sg.shuffle(val, mask + vector_offset);
}
reducer.reference() = accum;
}

/** \brief Intra-thread vector parallel exclusive prefix sum.
Expand All @@ -869,6 +870,32 @@ KOKKOS_INLINE_FUNCTION void parallel_scan(
parallel_scan(loop_boundaries, closure, Kokkos::Sum<value_type>{dummy});
}

/** \brief Intra-thread vector parallel exclusive prefix sum that also hands
 *  the reducer's result back to the caller.
 *
 * Executes closure(iType i, ValueType & val, bool final) for each i=[0..N)
 *
 * The range [0..N) is mapped to all vector lanes in the thread and a scan
 * operation is performed. The last call to closure has final == true.
 *
 * \param loop_boundaries  ThreadVectorRange boundaries describing [0..N).
 * \param closure          Callable invoked as described above. Its scan value
 *                         type, as deduced by FunctorAnalysis, must be exactly
 *                         ValueType; a mismatch is a compile-time error.
 * \param return_val       Overwritten with whatever the reducer-based overload
 *                         leaves in a Kokkos::Sum<ValueType> reducer
 *                         (presumably the sum over the whole range -- confirm
 *                         against that overload).
 */
template <typename iType, class Closure, typename ValueType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
    const Impl::ThreadVectorRangeBoundariesStruct<iType, Impl::SYCLTeamMember>&
        loop_boundaries,
    const Closure& closure, ValueType& return_val) {
  // Value type the closure itself operates on, as reported by FunctorAnalysis
  // for the SCAN pattern.
  using functor_value_type = typename Kokkos::Impl::FunctorAnalysis<
      Kokkos::Impl::FunctorPatternInterface::SCAN, void, Closure,
      void>::value_type;
  static_assert(std::is_same<ValueType, functor_value_type>::value,
                "Non-matching value types of closure and return type");

  // Run the scan into a local and publish it afterwards. The Sum reducer is
  // deliberately passed as a temporary: an rvalue cannot bind to this
  // overload's non-const `ValueType&` parameter, so the call below cannot
  // resolve back to this function.
  ValueType scan_result;
  parallel_scan(loop_boundaries, closure,
                Kokkos::Sum<ValueType>{scan_result});

  return_val = scan_result;
}

} // namespace Kokkos

namespace Kokkos {
Expand Down
8 changes: 4 additions & 4 deletions core/unit_test/TestTeamVector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,8 @@ struct functor_vec_scan {

// Temporary: This condition will be progressively relaxed as parallel_scan
// with a return value is implemented for more backends.
#if !defined(KOKKOS_ENABLE_OPENACC) && !defined(KOKKOS_ENABLE_SYCL) && \
!defined(KOKKOS_ENABLE_OPENMPTARGET) && !defined(KOKKOS_ENABLE_HPX)
#if !defined(KOKKOS_ENABLE_OPENACC) && !defined(KOKKOS_ENABLE_OPENMPTARGET) && \
!defined(KOKKOS_ENABLE_HPX)
template <typename Scalar, class ExecutionSpace>
struct functor_vec_scan_ret_val {
using policy_type = Kokkos::TeamPolicy<ExecutionSpace>;
Expand Down Expand Up @@ -735,8 +735,8 @@ bool test_scalar(int nteams, int team_size, int test) {
} else if (test == 12) {
// Temporary: This condition will be progressively relaxed as parallel_scan
// with a return value is implemented for more backends.
#if !defined(KOKKOS_ENABLE_OPENACC) && !defined(KOKKOS_ENABLE_SYCL) && \
!defined(KOKKOS_ENABLE_OPENMPTARGET) && !defined(KOKKOS_ENABLE_HPX)
#if !defined(KOKKOS_ENABLE_OPENACC) && !defined(KOKKOS_ENABLE_OPENMPTARGET) && \
!defined(KOKKOS_ENABLE_HPX)
Kokkos::parallel_for(
Kokkos::TeamPolicy<ExecutionSpace>(nteams, team_size, 8),
functor_vec_scan_ret_val<Scalar, ExecutionSpace>(d_flag, team_size));
Expand Down

0 comments on commit c5bf870

Please sign in to comment.