Skip to content

Commit

Permalink
Only pass one wrapper object in SYCL reductions
Browse files Browse the repository at this point in the history
  • Loading branch information
masterleinad committed Apr 24, 2023
1 parent c09dd1c commit a92e091
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 76 deletions.
100 changes: 50 additions & 50 deletions core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,11 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
p.space().impl_internal_space_instance()->m_mutexScratchSpace) {}

private:
template <typename PolicyType, typename FunctorWrapper,
typename ReducerWrapper>
template <typename PolicyType, typename CombinedFunctorReducerWrapper>
sycl::event sycl_direct_launch(
const PolicyType& policy, const FunctorWrapper& functor_wrapper,
const ReducerWrapper& reducer_wrapper,
const std::vector<sycl::event>& memcpy_events) const {
const PolicyType& policy,
const CombinedFunctorReducerWrapper& functor_reducer_wrapper,
const sycl::event& memcpy_event) const {
// Convenience references
const Kokkos::Experimental::SYCL& space = policy.space();
Kokkos::Experimental::Impl::SYCLInternal& instance =
Expand All @@ -241,11 +240,13 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,

auto parallel_reduce_event = q.submit([&](sycl::handler& cgh) {
const auto begin = policy.begin();
cgh.depends_on(memcpy_events);
cgh.depends_on(memcpy_event);
cgh.single_task([=]() {
const FunctorType& functor = functor_wrapper.get_functor();
const ReducerType& reducer = reducer_wrapper.get_functor();
reference_type update = reducer.init(results_ptr);
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();
reference_type update = reducer.init(results_ptr);
if (size == 1) {
if constexpr (std::is_void_v<WorkTag>)
functor(begin, update);
Expand Down Expand Up @@ -284,8 +285,10 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
const auto global_id =
wgroup_size * item.get_group_linear_id() * values_per_thread +
local_id;
const FunctorType& functor = functor_wrapper.get_functor();
const ReducerType& reducer = reducer_wrapper.get_functor();
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();

using index_type = typename Policy::index_type;
const auto upper_bound = std::min<index_type>(
Expand Down Expand Up @@ -423,7 +426,7 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
sycl::local_accessor<value_type> local_mem(
sycl::range<1>(wgroup_size) * std::max(value_count, 1u), cgh);

cgh.depends_on(memcpy_events);
cgh.depends_on(memcpy_event);

auto reduction_lambda =
reduction_lambda_factory(local_mem, num_teams_done, results_ptr);
Expand Down Expand Up @@ -455,19 +458,16 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
*m_policy.space().impl_internal_space_instance();
using IndirectKernelMem =
Kokkos::Experimental::Impl::SYCLInternal::IndirectKernelMem;
IndirectKernelMem& indirectKernelMem = instance.get_indirect_kernel_mem();
IndirectKernelMem& indirectReducerMem = instance.get_indirect_kernel_mem();

auto functor_wrapper = Experimental::Impl::make_sycl_function_wrapper(
m_functor_reducer.get_functor(), indirectKernelMem);
auto reducer_wrapper = Experimental::Impl::make_sycl_function_wrapper(
m_functor_reducer.get_reducer(), indirectReducerMem);

sycl::event event = sycl_direct_launch(
m_policy, functor_wrapper, reducer_wrapper,
{functor_wrapper.get_copy_event(), reducer_wrapper.get_copy_event()});
functor_wrapper.register_event(event);
reducer_wrapper.register_event(event);
IndirectKernelMem& indirectKernelMem = instance.get_indirect_kernel_mem();

auto functor_reducer_wrapper =
Experimental::Impl::make_sycl_function_wrapper(m_functor_reducer,
indirectKernelMem);

sycl::event event =
sycl_direct_launch(m_policy, functor_reducer_wrapper,
functor_reducer_wrapper.get_copy_event());
functor_reducer_wrapper.register_event(event);
}

private:
Expand Down Expand Up @@ -536,12 +536,11 @@ class ParallelReduce<CombinedFunctorReducerType,
m_space.impl_internal_space_instance()->m_mutexScratchSpace) {}

private:
template <typename PolicyType, typename FunctorWrapper,
typename ReducerWrapper>
template <typename PolicyType, typename CombinedFunctorReducerWrapper>
sycl::event sycl_direct_launch(
const PolicyType& policy, const FunctorWrapper& functor_wrapper,
const ReducerWrapper& reducer_wrapper,
const std::vector<sycl::event>& memcpy_events) const {
const PolicyType& policy,
const CombinedFunctorReducerWrapper& functor_reducer_wrapper,
const sycl::event& memcpy_event) const {
// Convenience references
Kokkos::Experimental::Impl::SYCLInternal& instance =
*m_space.impl_internal_space_instance();
Expand Down Expand Up @@ -577,10 +576,12 @@ class ParallelReduce<CombinedFunctorReducerType,
// m_result_ptr yet.
if (size <= 1) {
auto parallel_reduce_event = q.submit([&](sycl::handler& cgh) {
cgh.depends_on(memcpy_events);
cgh.depends_on(memcpy_event);
cgh.single_task([=]() {
const FunctorType& functor = functor_wrapper.get_functor();
const ReducerType& reducer = reducer_wrapper.get_functor();
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();

reference_type update = reducer.init(results_ptr);
if (size == 1) {
Expand Down Expand Up @@ -613,12 +614,14 @@ class ParallelReduce<CombinedFunctorReducerType,

const BarePolicy bare_policy = m_policy;

cgh.depends_on(memcpy_events);
cgh.depends_on(memcpy_event);

cgh.parallel_for(range, [=](sycl::nd_item<1> item) {
const auto local_id = item.get_local_linear_id();
const FunctorType& functor = functor_wrapper.get_functor();
const ReducerType& reducer = reducer_wrapper.get_functor();
const auto local_id = item.get_local_linear_id();
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();

// In the first iteration, we call functor to initialize the local
// memory. Otherwise, the local memory is initialized with the
Expand Down Expand Up @@ -751,19 +754,16 @@ class ParallelReduce<CombinedFunctorReducerType,
*m_space.impl_internal_space_instance();
using IndirectKernelMem =
Kokkos::Experimental::Impl::SYCLInternal::IndirectKernelMem;
IndirectKernelMem& indirectKernelMem = instance.get_indirect_kernel_mem();
IndirectKernelMem& indirectReducerMem = instance.get_indirect_kernel_mem();

auto functor_wrapper = Experimental::Impl::make_sycl_function_wrapper(
m_functor_reducer.get_functor(), indirectKernelMem);
auto reducer_wrapper = Experimental::Impl::make_sycl_function_wrapper(
m_functor_reducer.get_reducer(), indirectReducerMem);

sycl::event event = sycl_direct_launch(
m_policy, functor_wrapper, reducer_wrapper,
{functor_wrapper.get_copy_event(), reducer_wrapper.get_copy_event()});
functor_wrapper.register_event(event);
reducer_wrapper.register_event(event);
IndirectKernelMem& indirectKernelMem = instance.get_indirect_kernel_mem();

auto functor_reducer_wrapper =
Experimental::Impl::make_sycl_function_wrapper(m_functor_reducer,
indirectKernelMem);

sycl::event event =
sycl_direct_launch(m_policy, functor_reducer_wrapper,
functor_reducer_wrapper.get_copy_event());
functor_reducer_wrapper.register_event(event);
}

private:
Expand Down
51 changes: 25 additions & 26 deletions core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ class ParallelFor<FunctorType, Kokkos::TeamPolicy<Properties...>,
template <typename FunctorWrapper>
sycl::event sycl_direct_launch(const Policy& policy,
const FunctorWrapper& functor_wrapper,
const sycl::event& memcpy_events) const {
const sycl::event& memcpy_event) const {
// Convenience references
const Kokkos::Experimental::SYCL& space = policy.space();
sycl::queue& q = space.sycl_queue();
Expand Down Expand Up @@ -431,7 +431,7 @@ class ParallelFor<FunctorType, Kokkos::TeamPolicy<Properties...>,
// be used gives a runtime error.
// cgh.use_kernel_bundle(kernel_bundle);

cgh.depends_on(memcpy_events);
cgh.depends_on(memcpy_event);
cgh.parallel_for(
sycl::nd_range<2>(
sycl::range<2>(m_team_size, m_league_size * final_vector_size),
Expand Down Expand Up @@ -551,12 +551,11 @@ class ParallelReduce<CombinedFunctorReducerType,
std::scoped_lock<std::mutex> m_scratch_lock;
int m_scratch_pool_id = -1;

template <typename PolicyType, typename FunctorWrapper,
typename ReducerWrapper>
template <typename PolicyType, typename CombinedFunctorReducerWrapper>
sycl::event sycl_direct_launch(
const PolicyType& policy, const FunctorWrapper& functor_wrapper,
const ReducerWrapper& reducer_wrapper,
const std::vector<sycl::event>& memcpy_events) const {
const PolicyType& policy,
const CombinedFunctorReducerWrapper& functor_reducer_wrapper,
const sycl::event& memcpy_event) const {
// Convenience references
const Kokkos::Experimental::SYCL& space = policy.space();
Kokkos::Experimental::Impl::SYCLInternal& instance =
Expand Down Expand Up @@ -593,12 +592,14 @@ class ParallelReduce<CombinedFunctorReducerType,
const size_t scratch_size[2] = {m_scratch_size[0], m_scratch_size[1]};
sycl::device_ptr<char> const global_scratch_ptr = m_global_scratch_ptr;

cgh.depends_on(memcpy_events);
cgh.depends_on(memcpy_event);
cgh.parallel_for(
sycl::nd_range<2>(sycl::range<2>(1, 1), sycl::range<2>(1, 1)),
[=](sycl::nd_item<2> item) {
const FunctorType& functor = functor_wrapper.get_functor();
const ReducerType& reducer = reducer_wrapper.get_functor();
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();

reference_type update = reducer.init(results_ptr);
if (size == 1) {
Expand Down Expand Up @@ -653,9 +654,11 @@ class ParallelReduce<CombinedFunctorReducerType,

auto& num_teams_done = reinterpret_cast<unsigned int&>(
local_mem[wgroup_size * std::max(value_count, 1u)]);
const auto local_id = item.get_local_linear_id();
const FunctorType& functor = functor_wrapper.get_functor();
const ReducerType& reducer = reducer_wrapper.get_functor();
const auto local_id = item.get_local_linear_id();
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();

if constexpr (ReducerType::static_value_size() == 0) {
reference_type update =
Expand Down Expand Up @@ -791,7 +794,7 @@ class ParallelReduce<CombinedFunctorReducerType,

auto reduction_lambda = team_reduction_factory(local_mem, results_ptr);

cgh.depends_on(memcpy_events);
cgh.depends_on(memcpy_event);

cgh.parallel_for(
sycl::nd_range<2>(
Expand Down Expand Up @@ -822,20 +825,16 @@ class ParallelReduce<CombinedFunctorReducerType,
*m_policy.space().impl_internal_space_instance();
using IndirectKernelMem =
Kokkos::Experimental::Impl::SYCLInternal::IndirectKernelMem;
IndirectKernelMem& indirectKernelMem = instance.get_indirect_kernel_mem();
IndirectKernelMem& indirectReducerMem = instance.get_indirect_kernel_mem();
IndirectKernelMem& indirectKernelMem = instance.get_indirect_kernel_mem();

auto functor_wrapper = Experimental::Impl::make_sycl_function_wrapper(
m_functor_reducer.get_functor(), indirectKernelMem);
auto reducer_wrapper = Experimental::Impl::make_sycl_function_wrapper(
m_functor_reducer.get_reducer(), indirectReducerMem);

sycl::event event = sycl_direct_launch(
m_policy, functor_wrapper, reducer_wrapper,
{functor_wrapper.get_copy_event(), reducer_wrapper.get_copy_event()});
functor_wrapper.register_event(event);
reducer_wrapper.register_event(event);
auto functor_reducer_wrapper =
Experimental::Impl::make_sycl_function_wrapper(m_functor_reducer,
indirectKernelMem);

sycl::event event =
sycl_direct_launch(m_policy, functor_reducer_wrapper,
functor_reducer_wrapper.get_copy_event());
functor_reducer_wrapper.register_event(event);
instance.register_team_scratch_event(m_scratch_pool_id, event);
}

Expand Down

0 comments on commit a92e091

Please sign in to comment.