Skip to content

Commit

Permalink
Explicitly cast to CombinedFunctorReducerType
Browse files Browse the repository at this point in the history
  • Loading branch information
masterleinad committed Apr 24, 2023
1 parent a92e091 commit 2e51c67
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
34 changes: 17 additions & 17 deletions core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,11 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
const auto begin = policy.begin();
cgh.depends_on(memcpy_event);
cgh.single_task([=]() {
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();
reference_type update = reducer.init(results_ptr);
const CombinedFunctorReducerType& functor_reducer =
functor_reducer_wrapper.get_functor();
const FunctorType& functor = functor_reducer.get_functor();
const ReducerType& reducer = functor_reducer.get_reducer();
reference_type update = reducer.init(results_ptr);
if (size == 1) {
if constexpr (std::is_void_v<WorkTag>)
functor(begin, update);
Expand Down Expand Up @@ -285,10 +285,10 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
const auto global_id =
wgroup_size * item.get_group_linear_id() * values_per_thread +
local_id;
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();
const CombinedFunctorReducerType& functor_reducer =
functor_reducer_wrapper.get_functor();
const FunctorType& functor = functor_reducer.get_functor();
const ReducerType& reducer = functor_reducer.get_reducer();

using index_type = typename Policy::index_type;
const auto upper_bound = std::min<index_type>(
Expand Down Expand Up @@ -578,10 +578,10 @@ class ParallelReduce<CombinedFunctorReducerType,
auto parallel_reduce_event = q.submit([&](sycl::handler& cgh) {
cgh.depends_on(memcpy_event);
cgh.single_task([=]() {
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();
const CombinedFunctorReducerType& functor_reducer =
functor_reducer_wrapper.get_functor();
const FunctorType& functor = functor_reducer.get_functor();
const ReducerType& reducer = functor_reducer.get_reducer();

reference_type update = reducer.init(results_ptr);
if (size == 1) {
Expand Down Expand Up @@ -618,10 +618,10 @@ class ParallelReduce<CombinedFunctorReducerType,

cgh.parallel_for(range, [=](sycl::nd_item<1> item) {
const auto local_id = item.get_local_linear_id();
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();
const CombinedFunctorReducerType& functor_reducer =
functor_reducer_wrapper.get_functor();
const FunctorType& functor = functor_reducer.get_functor();
const ReducerType& reducer = functor_reducer.get_reducer();

// In the first iteration, we call functor to initialize the local
// memory. Otherwise, the local memory is initialized with the
Expand Down
16 changes: 8 additions & 8 deletions core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,10 @@ class ParallelReduce<CombinedFunctorReducerType,
cgh.parallel_for(
sycl::nd_range<2>(sycl::range<2>(1, 1), sycl::range<2>(1, 1)),
[=](sycl::nd_item<2> item) {
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();
const CombinedFunctorReducerType& functor_reducer =
functor_reducer_wrapper.get_functor();
const FunctorType& functor = functor_reducer.get_functor();
const ReducerType& reducer = functor_reducer.get_reducer();

reference_type update = reducer.init(results_ptr);
if (size == 1) {
Expand Down Expand Up @@ -655,10 +655,10 @@ class ParallelReduce<CombinedFunctorReducerType,
auto& num_teams_done = reinterpret_cast<unsigned int&>(
local_mem[wgroup_size * std::max(value_count, 1u)]);
const auto local_id = item.get_local_linear_id();
const FunctorType& functor =
functor_reducer_wrapper.get_functor().get_functor();
const ReducerType& reducer =
functor_reducer_wrapper.get_functor().get_reducer();
const CombinedFunctorReducerType& functor_reducer =
functor_reducer_wrapper.get_functor();
const FunctorType& functor = functor_reducer.get_functor();
const ReducerType& reducer = functor_reducer.get_reducer();

if constexpr (ReducerType::static_value_size() == 0) {
reference_type update =
Expand Down

0 comments on commit 2e51c67

Please sign in to comment.