Skip to content

Commit

Permalink
OpenMPTarget: Guard scratch memory usage in ParallelReduce
Browse files: browse the repository at this point in the history
  • Loading branch information
masterleinad committed Nov 8, 2023
1 parent 6fc7a49 commit fcb0452
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions core/src/OpenMPTarget/Kokkos_OpenMPTarget_Exec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ int* OpenMPTargetExec::m_lock_array = nullptr;
// Out-of-class definitions of OpenMPTargetExec static data members.
uint64_t OpenMPTargetExec::m_lock_size = 0;
uint32_t* OpenMPTargetExec::m_uniquetoken_ptr = nullptr;
int OpenMPTargetExec::MAX_ACTIVE_THREADS = 0;
// Mutex guarding use of the scratch memory. The OpenMPTarget ParallelReduce
// specializations hold a std::scoped_lock on it as a data member, so it is
// acquired when a ParallelReduce object is constructed and released when that
// object is destroyed.
std::mutex OpenMPTargetExec::m_mutex_scratch_ptr;

void OpenMPTargetExec::clear_scratch() {
Kokkos::Experimental::OpenMPTargetSpace space;
Expand Down
1 change: 1 addition & 0 deletions core/src/OpenMPTarget/Kokkos_OpenMPTarget_Parallel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ class OpenMPTargetExec {
int64_t thread_local_bytes, int64_t league_size);

static void* m_scratch_ptr;
// Serializes use of the scratch memory: ParallelReduce instances lock this
// mutex in their constructor and release it in their destructor, so only one
// instance at a time uses the scratch memory.
static std::mutex m_mutex_scratch_ptr;
static int64_t m_scratch_size;
static int* m_lock_array;
static uint64_t m_lock_size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
const pointer_type m_result_ptr;
bool m_result_ptr_on_device;
const int m_result_ptr_num_elems;
// Only let one ParallelReduce instance at a time use the scratch memory.
// The constructor acquires the mutex which is released in the destructor.
std::scoped_lock<std::mutex> m_scratch_memory_lock;
using TagType = typename Policy::work_tag;

public:
Expand Down Expand Up @@ -105,7 +108,8 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
m_result_ptr_on_device(
MemorySpaceAccess<Kokkos::Experimental::OpenMPTargetSpace,
typename ViewType::memory_space>::accessible),
m_result_ptr_num_elems(arg_result_view.size()) {}
m_result_ptr_num_elems(arg_result_view.size()),
m_scratch_memory_lock(OpenMPTargetExec::m_mutex_scratch_ptr) {}
};

} // namespace Impl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ class ParallelReduce<CombinedFunctorReducerType,
const pointer_type m_result_ptr;
const size_t m_shmem_size;

// Only let one ParallelReduce instance at a time use the scratch memory.
// The constructor acquires the mutex which is released in the destructor.
std::scoped_lock<std::mutex> m_scratch_memory_lock;

public:
void execute() const {
const FunctorType& functor = m_functor_reducer.get_functor();
Expand Down Expand Up @@ -517,7 +521,8 @@ class ParallelReduce<CombinedFunctorReducerType,
m_shmem_size(
arg_policy.scratch_size(0) + arg_policy.scratch_size(1) +
FunctorTeamShmemSize<FunctorType>::value(
arg_functor_reducer.get_functor(), arg_policy.team_size())) {}
arg_functor_reducer.get_functor(), arg_policy.team_size())),
m_scratch_memory_lock(OpenMPTargetExec::m_mutex_scratch_ptr) {}
};

} // namespace Impl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,10 @@ class ParallelReduce<CombinedFunctorReducerType,

bool m_result_ptr_on_device;

// Only let one ParallelReduce instance at a time use the scratch memory.
// The constructor acquires the mutex which is released in the destructor.
std::scoped_lock<std::mutex> m_scratch_memory_lock;

public:
inline void execute() const {
execute_tile<Policy::rank, typename ReducerType::value_type>(
Expand All @@ -452,7 +456,8 @@ class ParallelReduce<CombinedFunctorReducerType,
m_policy(arg_policy),
m_result_ptr_on_device(
MemorySpaceAccess<Kokkos::Experimental::OpenMPTargetSpace,
typename ViewType::memory_space>::accessible) {}
typename ViewType::memory_space>::accessible),
m_scratch_memory_lock(OpenMPTargetExec::m_mutex_scratch_ptr) {}

template <int Rank, class ValueType>
inline std::enable_if_t<Rank == 2> execute_tile(const FunctorType& functor,
Expand Down

0 comments on commit fcb0452

Please sign in to comment.