diff --git a/core/src/HIP/Kokkos_HIP_Parallel_MDRange.hpp b/core/src/HIP/Kokkos_HIP_Parallel_MDRange.hpp index f6b4954dec3..92e02637467 100644 --- a/core/src/HIP/Kokkos_HIP_Parallel_MDRange.hpp +++ b/core/src/HIP/Kokkos_HIP_Parallel_MDRange.hpp @@ -188,6 +188,23 @@ class ParallelReduce 4 bytes in size, indexing into shared/global memory relies + // on the block and grid dimensions to ensure that we index at the correct + // offset rather than at every 4 byte word; such that, when the join is + // performed, we have the correct data that was copied over in chunks of 4 + // bytes. + using word_size_type = std::conditional_t< + sizeof(value_type) < sizeof(Kokkos::HIP::size_type), + std::conditional_t, + Kokkos::HIP::size_type>; + // Algorithmic constraints: blockSize is a power of two AND blockDim.y == // blockDim.z == 1 @@ -195,8 +212,8 @@ class ParallelReduce; @@ -211,12 +228,12 @@ class ParallelReduce - word_count(reducer.value_size() / sizeof(size_type)); + word_size_type, ReducerType::static_value_size() / sizeof(word_size_type)> + word_count(reducer.value_size() / sizeof(word_size_type)); { reference_type value = reducer.init(reinterpret_cast( - kokkos_impl_hip_shared_memory() + + kokkos_impl_hip_shared_memory() + threadIdx.y * word_count.value)); // Number of blocks is bounded so that the reduction can be limited to two @@ -232,14 +249,14 @@ class ParallelReduce( reducer, blockIdx.x, gridDim.x, - kokkos_impl_hip_shared_memory(), m_scratch_space, + kokkos_impl_hip_shared_memory(), m_scratch_space, m_scratch_flags)) { // This is the final block with the final result at the final threads' // location - size_type* const shared = kokkos_impl_hip_shared_memory() + + word_size_type* const shared = kokkos_impl_hip_shared_memory() + (blockDim.y - 1) * word_count.value; - size_type* const global = m_result_ptr_device_accessible - ? reinterpret_cast(m_result_ptr) + word_size_type* const global = m_result_ptr_device_accessible + ? reinterpret_cast(m_result_ptr) : m_scratch_space; if (threadIdx.y == 0) { @@ -294,7 +311,7 @@ class ParallelReduce