Skip to content

Commit

Permalink
Pass local_accessor directly instead of local_mem.get_pointer()
Browse files (browse the repository at this point in the history)
  • Loading branch information
masterleinad committed Mar 24, 2023
1 parent b0cc5a0 commit a6f27bf
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 30 deletions.
41 changes: 19 additions & 22 deletions core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ inline constexpr bool use_shuffle_based_algorithm =
namespace SYCLReduction {
template <typename ValueType, typename ReducerType, int dim>
std::enable_if_t<!use_shuffle_based_algorithm<ReducerType>> workgroup_reduction(
sycl::nd_item<dim>& item, sycl::local_ptr<ValueType> local_mem,
sycl::nd_item<dim>& item, sycl::local_accessor<ValueType> local_mem,
sycl::device_ptr<ValueType> results_ptr,
sycl::global_ptr<ValueType> device_accessible_result_ptr,
const unsigned int value_count, const ReducerType& final_reducer,
Expand Down Expand Up @@ -109,7 +109,7 @@ std::enable_if_t<!use_shuffle_based_algorithm<ReducerType>> workgroup_reduction(

template <typename ValueType, typename ReducerType, int dim>
std::enable_if_t<use_shuffle_based_algorithm<ReducerType>> workgroup_reduction(
sycl::nd_item<dim>& item, sycl::local_ptr<ValueType> local_mem,
sycl::nd_item<dim>& item, sycl::local_accessor<ValueType> local_mem,
ValueType local_value, sycl::device_ptr<ValueType> results_ptr,
sycl::global_ptr<ValueType> device_accessible_result_ptr,
const ReducerType& final_reducer, bool final, unsigned int max_size) {
Expand Down Expand Up @@ -271,8 +271,8 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
instance.scratch_flags(sizeof(unsigned int)));

auto reduction_lambda_factory =
[&](sycl::local_accessor<value_type, 1> local_mem,
sycl::local_accessor<unsigned int, 1> num_teams_done,
[&](sycl::local_accessor<value_type> local_mem,
sycl::local_accessor<unsigned int> num_teams_done,
sycl::device_ptr<value_type> results_ptr) {
const auto begin = policy.begin();

Expand Down Expand Up @@ -304,9 +304,8 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
item.barrier(sycl::access::fence_space::local_space);

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), results_ptr,
device_accessible_result_ptr, value_count, reducer, false,
std::min(size, wgroup_size));
item, local_mem, results_ptr, device_accessible_result_ptr,
value_count, reducer, false, std::min(size, wgroup_size));

if (local_id == 0) {
sycl::atomic_ref<unsigned, sycl::memory_order::relaxed,
Expand All @@ -330,7 +329,7 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
}

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), results_ptr,
item, local_mem, results_ptr,
device_accessible_result_ptr, value_count, reducer, true,
std::min(n_wgroups, wgroup_size));
}
Expand All @@ -346,7 +345,7 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
}

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), local_value, results_ptr,
item, local_mem, local_value, results_ptr,
device_accessible_result_ptr, reducer, false,
std::min(size, wgroup_size));

Expand All @@ -370,7 +369,7 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
}

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), local_value, results_ptr,
item, local_mem, local_value, results_ptr,
device_accessible_result_ptr, reducer, true,
std::min(n_wgroups, wgroup_size));
}
Expand All @@ -380,7 +379,7 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
};

auto parallel_reduce_event = q.submit([&](sycl::handler& cgh) {
sycl::local_accessor<unsigned int, 1> num_teams_done(1, cgh);
sycl::local_accessor<unsigned int> num_teams_done(1, cgh);

auto dummy_reduction_lambda =
reduction_lambda_factory({1, cgh}, num_teams_done, nullptr);
Expand Down Expand Up @@ -421,7 +420,7 @@ class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
wgroup_size - 1) /
wgroup_size;

sycl::local_accessor<value_type, 1> local_mem(
sycl::local_accessor<value_type> local_mem(
sycl::range<1>(wgroup_size) * std::max(value_count, 1u), cgh);

cgh.depends_on(memcpy_events);
Expand Down Expand Up @@ -608,9 +607,9 @@ class ParallelReduce<CombinedFunctorReducerType,
if (size > 1) {
auto n_wgroups = (size + wgroup_size - 1) / wgroup_size;
auto parallel_reduce_event = q.submit([&](sycl::handler& cgh) {
sycl::local_accessor<value_type, 1> local_mem(
sycl::local_accessor<value_type> local_mem(
sycl::range<1>(wgroup_size) * std::max(value_count, 1u), cgh);
sycl::local_accessor<unsigned int, 1> num_teams_done(1, cgh);
sycl::local_accessor<unsigned int> num_teams_done(1, cgh);

const BarePolicy bare_policy = m_policy;

Expand Down Expand Up @@ -652,9 +651,8 @@ class ParallelReduce<CombinedFunctorReducerType,
item.barrier(sycl::access::fence_space::local_space);

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), results_ptr,
device_accessible_result_ptr, value_count, reducer, false,
std::min(size, wgroup_size));
item, local_mem, results_ptr, device_accessible_result_ptr,
value_count, reducer, false, std::min(size, wgroup_size));

if (local_id == 0) {
sycl::atomic_ref<unsigned, sycl::memory_order::relaxed,
Expand All @@ -678,9 +676,8 @@ class ParallelReduce<CombinedFunctorReducerType,
}

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), results_ptr,
device_accessible_result_ptr, value_count, reducer, true,
std::min(n_wgroups, wgroup_size));
item, local_mem, results_ptr, device_accessible_result_ptr,
value_count, reducer, true, std::min(n_wgroups, wgroup_size));
}
} else {
value_type local_value;
Expand All @@ -695,7 +692,7 @@ class ParallelReduce<CombinedFunctorReducerType,
.exec_range();

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), local_value, results_ptr,
item, local_mem, local_value, results_ptr,
device_accessible_result_ptr, reducer, false,
std::min(size, wgroup_size));

Expand All @@ -719,7 +716,7 @@ class ParallelReduce<CombinedFunctorReducerType,
}

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), local_value, results_ptr,
item, local_mem, local_value, results_ptr,
device_accessible_result_ptr, reducer, true,
std::min(n_wgroups, wgroup_size));
}
Expand Down
8 changes: 4 additions & 4 deletions core/src/SYCL/Kokkos_SYCL_Parallel_Scan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace Impl {
// total sum.
template <int dim, typename ValueType, typename FunctorType>
void workgroup_scan(sycl::nd_item<dim> item, const FunctorType& final_reducer,
sycl::local_ptr<ValueType> local_mem,
sycl::local_accessor<ValueType> local_mem,
ValueType& local_value, unsigned int global_range) {
// subgroup scans
auto sg = item.get_sub_group();
Expand Down Expand Up @@ -136,7 +136,7 @@ class ParallelScanSYCLBase {
q.get_device()
.template get_info<sycl::info::device::sub_group_sizes>()
.front();
sycl::local_accessor<value_type, 1> local_mem(
sycl::local_accessor<value_type> local_mem(
sycl::range<1>((wgroup_size + min_subgroup_size - 1) /
min_subgroup_size),
cgh);
Expand All @@ -160,8 +160,8 @@ class ParallelScanSYCLBase {
else
reducer.init(&local_value);

workgroup_scan<>(item, reducer, local_mem.get_pointer(),
local_value, wgroup_size);
workgroup_scan<>(item, reducer, local_mem, local_value,
wgroup_size);

if (n_wgroups > 1 && local_id == wgroup_size - 1)
group_results[item.get_group_linear_id()] =
Expand Down
8 changes: 4 additions & 4 deletions core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ class ParallelReduce<CombinedFunctorReducerType,
item.barrier(sycl::access::fence_space::local_space);

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), results_ptr,
item, local_mem, results_ptr,
device_accessible_result_ptr, value_count, reducer, false,
std::min<std::size_t>(size,
item.get_local_range()[0] *
Expand Down Expand Up @@ -696,7 +696,7 @@ class ParallelReduce<CombinedFunctorReducerType,
}

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), results_ptr,
item, local_mem, results_ptr,
device_accessible_result_ptr, value_count, reducer,
true,
std::min(n_wgroups, item.get_local_range()[0] *
Expand All @@ -716,7 +716,7 @@ class ParallelReduce<CombinedFunctorReducerType,
functor(WorkTag(), team_member, update);

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), local_value, results_ptr,
item, local_mem, local_value, results_ptr,
device_accessible_result_ptr, reducer, false,
std::min<std::size_t>(size,
item.get_local_range()[0] *
Expand All @@ -742,7 +742,7 @@ class ParallelReduce<CombinedFunctorReducerType,
}

SYCLReduction::workgroup_reduction<>(
item, local_mem.get_pointer(), local_value, results_ptr,
item, local_mem, local_value, results_ptr,
device_accessible_result_ptr, reducer, true,
std::min(n_wgroups, item.get_local_range()[0] *
item.get_local_range()[1]));
Expand Down

0 comments on commit a6f27bf

Please sign in to comment.