#5641: HIP: Fix MDRange parallel_reduce over values smaller than int
PhilMiller committed Mar 2, 2023
1 parent 854e264 commit 00f951a
Showing 1 changed file with 27 additions and 10 deletions.
core/src/HIP/Kokkos_HIP_Parallel_MDRange.hpp (37 changes: 27 additions & 10 deletions)
@@ -188,15 +188,32 @@ class ParallelReduce<CombinedFunctorReducerType,
using functor_type = FunctorType;
using size_type = HIP::size_type;

+  // Conditionally set word_size_type to int16_t or int8_t if value_type is
+  // smaller than int32_t (Kokkos::HIP::size_type)
+  // word_size_type is used to determine the word count, shared memory buffer
+  // size, and global memory buffer size before the reduction is performed.
+  // Within the reduction, the word count is recomputed based on word_size_type
+  // and when calculating indexes into the shared/global memory buffers for
+  // performing the reduction, word_size_type is used again.
+  // For scalars > 4 bytes in size, indexing into shared/global memory relies
+  // on the block and grid dimensions to ensure that we index at the correct
+  // offset rather than at every 4 byte word; such that, when the join is
+  // performed, we have the correct data that was copied over in chunks of 4
+  // bytes.
+  using word_size_type = std::conditional_t<
+      sizeof(value_type) < sizeof(Kokkos::HIP::size_type),
+      std::conditional_t<sizeof(value_type) == 2, int16_t, int8_t>,
+      Kokkos::HIP::size_type>;
+
// Algorithmic constraints: blockSize is a power of two AND blockDim.y ==
// blockDim.z == 1

const CombinedFunctorReducerType m_functor_reducer;
const Policy m_policy; // used for workrange and nwork
const pointer_type m_result_ptr;
const bool m_result_ptr_device_accessible;
-  size_type* m_scratch_space;
-  size_type* m_scratch_flags;
+  word_size_type* m_scratch_space;
+  HIP::size_type* m_scratch_flags;

using DeviceIteratePattern = typename Kokkos::Impl::Reduce::DeviceIterateTile<
Policy::rank, Policy, FunctorType, WorkTag, reference_type>;
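
The selection logic introduced in this hunk can be checked in isolation. A minimal standalone sketch (not part of the commit), assuming Kokkos::HIP::size_type is a 4-byte unsigned integer and using a hypothetical stand-in alias hip_size_type:

#include <cstdint>
#include <type_traits>

// Stand-in for Kokkos::HIP::size_type, assumed here to be 4 bytes wide.
using hip_size_type = std::uint32_t;

// Same selection as the word_size_type alias in the hunk above, written as a
// template over the reduction's value_type.
template <class ValueType>
using word_size_type_for = std::conditional_t<
    sizeof(ValueType) < sizeof(hip_size_type),
    std::conditional_t<sizeof(ValueType) == 2, std::int16_t, std::int8_t>,
    hip_size_type>;

// 1- and 2-byte scalars get a matching small word type; anything at least as
// wide as hip_size_type keeps the 4-byte word used before this fix.
static_assert(std::is_same_v<word_size_type_for<std::int8_t>, std::int8_t>);
static_assert(std::is_same_v<word_size_type_for<std::int16_t>, std::int16_t>);
static_assert(std::is_same_v<word_size_type_for<std::int32_t>, hip_size_type>);
static_assert(std::is_same_v<word_size_type_for<double>, hip_size_type>);

int main() { return 0; }
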
@@ -211,12 +228,12 @@ class ParallelReduce<CombinedFunctorReducerType,
const ReducerType& reducer = m_functor_reducer.get_reducer();

const integral_nonzero_constant<
-        size_type, ReducerType::static_value_size() / sizeof(size_type)>
-        word_count(reducer.value_size() / sizeof(size_type));
+        word_size_type, ReducerType::static_value_size() / sizeof(word_size_type)>
+        word_count(reducer.value_size() / sizeof(word_size_type));

{
reference_type value = reducer.init(reinterpret_cast<pointer_type>(
-          kokkos_impl_hip_shared_memory<size_type>() +
+          kokkos_impl_hip_shared_memory<word_size_type>() +
threadIdx.y * word_count.value));

// Number of blocks is bounded so that the reduction can be limited to two
@@ -232,14 +249,14 @@ class ParallelReduce<CombinedFunctorReducerType,
// Problem: non power-of-two blockDim
if (::Kokkos::Impl::hip_single_inter_block_reduce_scan<false>(
reducer, blockIdx.x, gridDim.x,
-            kokkos_impl_hip_shared_memory<size_type>(), m_scratch_space,
+            kokkos_impl_hip_shared_memory<word_size_type>(), m_scratch_space,
m_scratch_flags)) {
// This is the final block with the final result at the final threads'
// location
-      size_type* const shared = kokkos_impl_hip_shared_memory<size_type>() +
+      word_size_type* const shared = kokkos_impl_hip_shared_memory<word_size_type>() +
                                 (blockDim.y - 1) * word_count.value;
-      size_type* const global = m_result_ptr_device_accessible
-                                    ? reinterpret_cast<size_type*>(m_result_ptr)
+      word_size_type* const global = m_result_ptr_device_accessible
+                                    ? reinterpret_cast<word_size_type*>(m_result_ptr)
: m_scratch_space;

if (threadIdx.y == 0) {
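
The two hunks above size and index the scratch buffers in units of word_size_type rather than the fixed 4-byte size_type. A hedged host-side sketch of that arithmetic (not part of the commit; Kokkos::HIP::size_type is assumed to be a 4-byte unsigned integer and the reduction value a 2-byte int16_t):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  using value_type     = std::int16_t;   // reduction value smaller than int
  using old_word_type  = std::uint32_t;  // stand-in for Kokkos::HIP::size_type
  using word_size_type = std::int16_t;   // what the new alias selects here

  constexpr std::size_t value_size = sizeof(value_type);  // 2 bytes

  // Old sizing: 2 / 4 truncates to 0, so no words were reserved per thread.
  constexpr std::size_t old_word_count = value_size / sizeof(old_word_type);
  // New sizing: 2 / 2 == 1, one word per per-thread partial result.
  constexpr std::size_t word_count = value_size / sizeof(word_size_type);
  std::printf("old word_count = %zu, new word_count = %zu\n", old_word_count,
              word_count);

  // Slots in the shared buffer are addressed in word_size_type units,
  // mirroring `threadIdx.y * word_count.value` and
  // `(blockDim.y - 1) * word_count.value` above.
  constexpr unsigned block_dim_y = 4;  // illustrative thread count per block
  std::vector<word_size_type> shared(block_dim_y * word_count, 0);
  for (unsigned y = 0; y < block_dim_y; ++y) {
    shared[y * word_count] = static_cast<word_size_type>(y + 1);  // partials
  }
  // The final reduced value ends up in the last thread's slot.
  std::printf("last slot = %d\n",
              static_cast<int>(shared[(block_dim_y - 1) * word_count]));
  return 0;
}

With the pre-fix 4-byte word, old_word_count is 0, leaving no per-thread scratch slots at all; presumably this is the failure mode for values smaller than int that the commit title refers to.
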
@@ -294,7 +311,7 @@ class ParallelReduce<CombinedFunctorReducerType,
: suggested_blocksize; // Note: block_size must be less
// than or equal to 512

-    m_scratch_space = hip_internal_scratch_space(
+    m_scratch_space = (word_size_type*)hip_internal_scratch_space(
m_policy.space(), reducer.value_size() *
block_size /* block_size == max block_count */);
m_scratch_flags =
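
For the host-side setup in the last hunk, the global scratch buffer is sized in bytes as reducer.value_size() * block_size and the returned pointer is reinterpreted as word_size_type*. A hedged sketch of that sizing (not part of the commit; hip_internal_scratch_space is only referenced, not reproduced, and malloc stands in for it with illustrative numbers):

#include <cstdint>
#include <cstdlib>

int main() {
  using word_size_type = std::int16_t;  // assumed 2-byte value_type

  // One partial value per block; block_size also bounds the block count
  // (and must be <= 512 per the comment in the hunk above).
  constexpr std::size_t value_size = sizeof(std::int16_t);
  constexpr std::size_t block_size = 256;

  // Byte count requested from the scratch allocator, as in
  // reducer.value_size() * block_size above (malloc stands in here).
  void* raw = std::malloc(value_size * block_size);

  // Reinterpreting as word_size_type* makes later indexing advance in
  // 2-byte words, matching the cast applied to hip_internal_scratch_space.
  auto* scratch_space = static_cast<word_size_type*>(raw);
  (void)scratch_space;

  std::free(raw);
  return 0;
}
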
