Skip to content

Commit

Permalink
Cleanup OpenMPTaget ParallelReduce
Browse files Browse the repository at this point in the history
  • Loading branch information
masterleinad committed Mar 17, 2023
1 parent 65aa95e commit d5244e1
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
using pointer_type = typename ReducerType::pointer_type;
using reference_type = typename ReducerType::reference_type;

static constexpr int FunctorHasJoin =
static constexpr bool FunctorHasJoin =
Impl::FunctorAnalysis<Impl::FunctorPatternInterface::REDUCE, Policy,
FunctorType>::has_join_member_function;
static constexpr int UseReducer =
FunctorType>::Reducer::has_join_member_function();
static constexpr bool UseReducer =
!std::is_same_v<FunctorType, typename ReducerType::functor_type>;
static constexpr int IsArray = std::is_pointer<reference_type>::value;
static constexpr bool IsArray = std::is_pointer_v<reference_type>;

using ParReduceSpecialize =
ParallelReduceSpecialize<FunctorType, Policy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,12 +453,12 @@ class ParallelReduce<CombinedFunctorReducerType,
bool m_result_ptr_on_device;
const int m_result_ptr_num_elems;

static constexpr int FunctorHasJoin =
static constexpr bool FunctorHasJoin =
Impl::FunctorAnalysis<Impl::FunctorPatternInterface::REDUCE, Policy,
FunctorType>::has_join_member_function;
static constexpr int UseReducer =
FunctorType>::Reducer::has_join_member_function();
static constexpr bool UseReducer =
!std::is_same_v<FunctorType, typename ReducerType::functor_type>;
static constexpr int IsArray = std::is_pointer<reference_type>::value;
static constexpr bool IsArray = std::is_pointer_v<reference_type>;

using ParReduceSpecialize =
ParallelReduceSpecialize<FunctorType, Policy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>,
local_offset_value = element_values(team_id, i - 1);
// FIXME_OPENMPTARGET We seem to access memory illegaly on AMD GPUs
#ifdef KOKKOS_ARCH_VEGA
if constexpr (Analysis::has_join_member_function) {
if constexpr (Analysis::Reducer::has_join_member_function()) {
if constexpr (std::is_void_v<WorkTag>)
a_functor.join(local_offset_value, offset_value);
else
Expand Down
66 changes: 17 additions & 49 deletions core/src/OpenMPTarget/Kokkos_OpenMPTarget_Parallel_Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ struct ParallelReduceSpecialize {
PointerType /*result_ptr*/) {
constexpr int FunctorHasJoin =
Impl::FunctorAnalysis<Impl::FunctorPatternInterface::REDUCE, PolicyType,
FunctorType>::has_join_member_function;
constexpr int UseReducerType = is_reducer<ReducerType>::value;
FunctorType>::Reducer::has_join_member_function();
constexpr int UseReducerType = is_reducer_v<ReducerType>;

std::stringstream error_message;
error_message << "Error: Invalid Specialization " << FunctorHasJoin << ' '
Expand Down Expand Up @@ -198,7 +198,6 @@ struct ParallelReduceSpecialize<FunctorType, Kokkos::RangePolicy<PolicyArgs...>,
using FunctorAnalysis =
Impl::FunctorAnalysis<Impl::FunctorPatternInterface::REDUCE, PolicyType,
FunctorType>;
constexpr int HasInit = FunctorAnalysis::has_init_member_function;

// Initialize the result pointer.

Expand All @@ -220,27 +219,16 @@ struct ParallelReduceSpecialize<FunctorType, Kokkos::RangePolicy<PolicyArgs...>,
ValueType* scratch_ptr =
static_cast<ValueType*>(OpenMPTargetExec::get_scratch_ptr());

#pragma omp target map(to : f) is_device_ptr(scratch_ptr)
{
typename FunctorAnalysis::Reducer final_reducer(f);
// Enter this loop if the functor has an `init`
if constexpr (HasInit) {
// The `init` routine needs to be called on the device since it might
// need device members.
final_reducer.init(scratch_ptr);
final_reducer.final(scratch_ptr);
} else {
for (int i = 0; i < value_count; ++i) {
static_cast<ValueType*>(scratch_ptr)[i] = ValueType();
}
typename FunctorAnalysis::Reducer final_reducer(f);

if (end <= begin) {
#pragma omp target map(to : final_reducer) is_device_ptr(scratch_ptr)
{
// If there is no work to be done, copy back the initialized values and
// exit.
final_reducer.init(scratch_ptr);
final_reducer.final(scratch_ptr);
}
}

if (end <= begin) {
// If there is no work to be done, copy back the initialized values and
// exit.
if (!ptr_on_device)
KOKKOS_IMPL_OMPT_SAFE_CALL(omp_target_memcpy(
ptr, scratch_ptr, value_count * sizeof(ValueType), 0, 0,
Expand All @@ -255,9 +243,8 @@ struct ParallelReduceSpecialize<FunctorType, Kokkos::RangePolicy<PolicyArgs...>,

#pragma omp target teams num_teams(max_teams) thread_limit(max_team_threads) \
map(to \
: f) is_device_ptr(scratch_ptr)
: final_reducer) is_device_ptr(scratch_ptr)
{
typename FunctorAnalysis::Reducer final_reducer(f);
#pragma omp parallel
{
const int team_num = omp_get_team_num();
Expand Down Expand Up @@ -304,7 +291,6 @@ struct ParallelReduceSpecialize<FunctorType, Kokkos::RangePolicy<PolicyArgs...>,
is_device_ptr(scratch_ptr)
for (int i = 0; i < max_teams - tree_neighbor_offset;
i += 2 * tree_neighbor_offset) {
typename FunctorAnalysis::Reducer final_reducer(f);
ValueType* team_scratch = scratch_ptr;
const int team_offset = max_team_threads * value_count;
final_reducer.join(
Expand Down Expand Up @@ -538,7 +524,6 @@ struct ParallelReduceSpecialize<FunctorType, TeamPolicyInternal<PolicyArgs...>,
using FunctorAnalysis =
Impl::FunctorAnalysis<Impl::FunctorPatternInterface::REDUCE, PolicyType,
FunctorType>;
constexpr int HasInit = FunctorAnalysis::has_init_member_function;

const int league_size = p.league_size();
const int team_size = p.team_size();
Expand Down Expand Up @@ -568,32 +553,17 @@ struct ParallelReduceSpecialize<FunctorType, TeamPolicyInternal<PolicyArgs...>,
OpenMPTargetExec::resize_scratch(1, 0, value_count * sizeof(ValueType),
league_size);
void* scratch_ptr = OpenMPTargetExec::get_scratch_ptr();
typename FunctorAnalysis::Reducer final_reducer(f);

// Enter this loop if the functor has an `init`
if constexpr (HasInit) {
// The `init` routine needs to be called on the device since it might need
// device members.
#pragma omp target map(to : f) is_device_ptr(scratch_ptr)
if (end <= begin) {
// If there is no work to be done, copy back the initialized values and
// exit.
#pragma omp target map(to : final_reducer) is_device_ptr(scratch_ptr)
{
typename FunctorAnalysis::Reducer final_reducer(f);
final_reducer.init(scratch_ptr);
final_reducer.final(scratch_ptr);
}
} else {
#pragma omp target map(to : f) is_device_ptr(scratch_ptr)
{
for (int i = 0; i < value_count; ++i) {
static_cast<ValueType*>(scratch_ptr)[i] = ValueType();
}

typename FunctorAnalysis::Reducer final_reducer(f);
final_reducer.final(static_cast<ValueType*>(scratch_ptr));
}
}

if (end <= begin) {
// If there is no work to be done, copy back the initialized values and
// exit.
if (!ptr_on_device)
KOKKOS_IMPL_OMPT_SAFE_CALL(omp_target_memcpy(
ptr, scratch_ptr, value_count * sizeof(ValueType), 0, 0,
Expand All @@ -616,7 +586,6 @@ struct ParallelReduceSpecialize<FunctorType, TeamPolicyInternal<PolicyArgs...>,
const int num_teams = omp_get_num_teams();
ValueType* team_scratch = static_cast<ValueType*>(scratch_ptr) +
team_num * team_size * value_count;
typename FunctorAnalysis::Reducer final_reducer(f);
ReferenceType result = final_reducer.init(&team_scratch[0]);

for (int league_id = team_num; league_id < league_size;
Expand All @@ -635,14 +604,13 @@ struct ParallelReduceSpecialize<FunctorType, TeamPolicyInternal<PolicyArgs...>,

int tree_neighbor_offset = 1;
do {
#pragma omp target teams distribute parallel for simd map(to \
: f) \
#pragma omp target teams distribute parallel for simd map(to \
: final_reducer) \
is_device_ptr(scratch_ptr)
for (int i = 0; i < nteams - tree_neighbor_offset;
i += 2 * tree_neighbor_offset) {
ValueType* team_scratch = static_cast<ValueType*>(scratch_ptr);
const int team_offset = team_size * value_count;
typename FunctorAnalysis::Reducer final_reducer(f);
final_reducer.join(
&team_scratch[i * team_offset],
&team_scratch[(i + tree_neighbor_offset) * team_offset]);
Expand Down

0 comments on commit d5244e1

Please sign in to comment.