Skip to content

Commit

Permalink
kokkos#5635: Add parallel_scan changes for CUDA and TeamThreadRange
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable authored and cz4rs committed Sep 14, 2023
1 parent 6a95b5f commit 1fb6f4a
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions core/src/Cuda/Kokkos_Cuda_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,28 +689,30 @@ parallel_reduce(Impl::ThreadVectorRangeBoundariesStruct<
* final == true.
*/
// This is the same code as in HIP and largely the same as in OpenMPTarget
template <typename iType, typename FunctorType>
template <typename iType, typename FunctorType, typename ValueType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
const Impl::TeamThreadRangeBoundariesStruct<iType, Impl::CudaTeamMember>&
loop_bounds,
const FunctorType& lambda) {
// Extract value_type from lambda
using value_type = typename Kokkos::Impl::FunctorAnalysis<
const FunctorType& lambda, ValueType& return_val) {
// Extract ValueType from the Functor
using functor_value_type = typename Kokkos::Impl::FunctorAnalysis<
Kokkos::Impl::FunctorPatternInterface::SCAN, void, FunctorType,
void>::value_type;
static_assert(std::is_same<functor_value_type, ValueType>::value,
"Non-matching value types of functor and return type");

const auto start = loop_bounds.start;
const auto end = loop_bounds.end;
auto& member = loop_bounds.member;
const auto team_size = member.team_size();
const auto team_rank = member.team_rank();
const auto nchunk = (end - start + team_size - 1) / team_size;
value_type accum = 0;
ValueType accum = 0;
// each team has to process one or more chunks of the prefix scan
for (iType i = 0; i < nchunk; ++i) {
auto ii = start + i * team_size + team_rank;
// local accumulation for this chunk
value_type local_accum = 0;
ValueType local_accum = 0;
// user updates value with prefix value
if (ii < loop_bounds.end) lambda(ii, local_accum, false);
// perform team scan
Expand All @@ -724,6 +726,29 @@ KOKKOS_INLINE_FUNCTION void parallel_scan(
// broadcast last value to rest of the team
member.team_broadcast(accum, team_size - 1);
}
return_val = accum;
}

/** \brief Inter-thread parallel exclusive prefix sum.
*
* Executes closure(iType i, ValueType & val, bool final) for each i=[0..N)
*
* The range [0..N) is mapped to each rank in the team (whose global rank is
* less than N) and a scan operation is performed. The last call to closure has
* final == true.
*/
template <typename iType, typename FunctorType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
const Impl::TeamThreadRangeBoundariesStruct<iType, Impl::CudaTeamMember>&
loop_bounds,
const FunctorType& lambda) {
// Extract value_type from functor
using value_type = typename Kokkos::Impl::FunctorAnalysis<
Kokkos::Impl::FunctorPatternInterface::SCAN, void, FunctorType,
void>::value_type;

value_type dummy;
parallel_scan(loop_bounds, lambda, dummy);
}

//----------------------------------------------------------------------------
Expand Down

0 comments on commit 1fb6f4a

Please sign in to comment.