Skip to content

Commit

Permalink
add overload for TeamThreadRange
Browse files Browse the repository at this point in the history
  • Loading branch information
fnrizzi committed Oct 10, 2023
1 parent d97f16f commit fdfeaf9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
33 changes: 26 additions & 7 deletions core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelScan_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,17 @@
namespace Kokkos {

// This is largely the same code as in HIP and CUDA except for the member name
template <typename iType, class FunctorType>
template <typename iType, class FunctorType, class ValueType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
const Impl::TeamThreadRangeBoundariesStruct<
iType, Impl::OpenMPTargetExecTeamMember>& loop_bounds,
const FunctorType& lambda) {
using Analysis = Impl::FunctorAnalysis<Impl::FunctorPatternInterface::SCAN,
const FunctorType& lambda, ValueType& return_val) {
using Analysis = Impl::FunctorAnalysis<Impl::FunctorPatternInterface::SCAN,
TeamPolicy<Experimental::OpenMPTarget>,
FunctorType, void>;
using value_type = typename Analysis::value_type;
using analysis_value_type = typename Analysis::value_type;
static_assert(std::is_same_v<analysis_value_type, ValueType>,
"Non-matching value types of functor and return type");

const auto start = loop_bounds.start;
const auto end = loop_bounds.end;
Expand All @@ -50,24 +52,27 @@ KOKKOS_INLINE_FUNCTION void parallel_scan(
const auto team_rank = member.team_rank();

#if defined(KOKKOS_IMPL_TEAM_SCAN_WORKAROUND)
value_type scan_val = value_type();
ValueType scan_val = {};

if (team_rank == 0) {
for (iType i = start; i < end; ++i) {
lambda(i, scan_val, true);
}
}
member.team_broadcast(scan_val, 0);
return_val = scan_val;

#pragma omp barrier
#else
const auto team_size = member.team_size();
const auto nchunk = (end - start + team_size - 1) / team_size;
value_type accum = 0;
ValueType accum = {};
// each team has to process one or
// more chunks of the prefix scan
for (iType i = 0; i < nchunk; ++i) {
auto ii = start + i * team_size + team_rank;
// local accumulation for this chunk
value_type local_accum = 0;
ValueType local_accum = {};
// user updates value with prefix value
if (ii < loop_bounds.end) lambda(ii, local_accum, false);
// perform team scan
Expand All @@ -81,9 +86,23 @@ KOKKOS_INLINE_FUNCTION void parallel_scan(
// broadcast last value to rest of the team
member.team_broadcast(accum, team_size - 1);
}
return_val = accum;

#endif
}

// Team-level parallel_scan over a TeamThreadRange for callers that do not
// need the final scan value.
//
// The scan value type is deduced from the functor through FunctorAnalysis
// and the work is delegated to the overload taking a return-value reference;
// the value written into that reference is simply dropped here.
//
//   loop_bounds : TeamThreadRange boundaries forwarded unchanged.
//   lambda      : scan functor forwarded unchanged.
template <typename iType, class FunctorType>
KOKKOS_INLINE_FUNCTION void parallel_scan(
    const Impl::TeamThreadRangeBoundariesStruct<
        iType, Impl::OpenMPTargetExecTeamMember>& loop_bounds,
    const FunctorType& lambda) {
  // Value type of the scan as reported by the functor analysis.
  using scan_value_type = typename Impl::FunctorAnalysis<
      Impl::FunctorPatternInterface::SCAN,
      TeamPolicy<Experimental::OpenMPTarget>, FunctorType, void>::value_type;

  // Destination for the delegated overload's output; not read afterwards.
  scan_value_type discarded_total;
  parallel_scan(loop_bounds, lambda, discarded_total);
}
} // namespace Kokkos

namespace Kokkos {
Expand Down
3 changes: 1 addition & 2 deletions core/unit_test/TestTeamScan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ TEST(TEST_CATEGORY, team_scan) {

// Temporary: this condition will be progressively reduced as parallel_scan
// with a return value is implemented for more backends.
#if !defined(KOKKOS_ENABLE_OPENACC) && !defined(KOKKOS_ENABLE_OPENMPTARGET) && \
!defined(KOKKOS_ENABLE_HPX)
#if !defined(KOKKOS_ENABLE_OPENACC) && !defined(KOKKOS_ENABLE_HPX)
template <class ExecutionSpace, class DataType>
struct TestTeamScanRetVal {
using execution_space = ExecutionSpace;
Expand Down

0 comments on commit fdfeaf9

Please sign in to comment.