Improve SYCL parallel_scan
masterleinad committed Apr 28, 2023
1 parent d30b04d commit 3cc9915
Showing 1 changed file with 116 additions and 118 deletions.
core/src/SYCL/Kokkos_SYCL_Parallel_Scan.hpp (234 changes: 116 additions & 118 deletions)
@@ -32,21 +32,22 @@ namespace Impl {
template <int dim, typename ValueType, typename FunctorType>
void workgroup_scan(sycl::nd_item<dim> item, const FunctorType& final_reducer,
sycl::local_accessor<ValueType> local_mem,
ValueType& local_value, unsigned int global_range) {
ValueType& local_value, int global_range) {
// subgroup scans
auto sg = item.get_sub_group();
const auto sg_group_id = sg.get_group_id()[0];
const auto id_in_sg = sg.get_local_id()[0];
for (unsigned int stride = 1; stride < global_range; stride <<= 1) {
auto sg = item.get_sub_group();
const int sg_group_id = sg.get_group_id()[0];
const int id_in_sg = sg.get_local_id()[0];

for (int stride = 1; stride < global_range; stride <<= 1) {
auto tmp = sg.shuffle_up(local_value, stride);
if (id_in_sg >= stride) final_reducer.join(&local_value, &tmp);
}

const auto max_subgroup_size = sg.get_max_local_range()[0];
const auto n_active_subgroups =
const int max_subgroup_size = sg.get_max_local_range()[0];
const int n_active_subgroups =
(global_range + max_subgroup_size - 1) / max_subgroup_size;

const auto local_range = sg.get_local_range()[0];
const int local_range = sg.get_local_range()[0];
if (id_in_sg == local_range - 1 && sg_group_id < n_active_subgroups)
local_mem[sg_group_id] = local_value;
local_value = sg.shuffle_up(local_value, 1);
@@ -56,14 +57,13 @@ void workgroup_scan(sycl::nd_item<dim> item, const FunctorType& final_reducer,
// scan subgroup results using the first subgroup
if (n_active_subgroups > 1) {
if (sg_group_id == 0) {
const auto n_rounds =
(n_active_subgroups + local_range - 1) / local_range;
for (unsigned int round = 0; round < n_rounds; ++round) {
const unsigned int idx = id_in_sg + round * local_range;
const int n_rounds = (n_active_subgroups + local_range - 1) / local_range;
for (int round = 0; round < n_rounds; ++round) {
const int idx = id_in_sg + round * local_range;
const auto upper_bound =
std::min(local_range, n_active_subgroups - round * local_range);
auto local_sg_value = local_mem[idx < n_active_subgroups ? idx : 0];
for (unsigned int stride = 1; stride < upper_bound; stride <<= 1) {
for (int stride = 1; stride < upper_bound; stride <<= 1) {
auto tmp = sg.shuffle_up(local_sg_value, stride);
if (id_in_sg >= stride) {
if (idx < n_active_subgroups)
@@ -123,14 +123,29 @@ class ParallelScanSYCLBase {

private:
template <typename FunctorWrapper>
void scan_internal(sycl::queue& q, const FunctorWrapper& functor_wrapper,
pointer_type global_mem, std::size_t size) const {
sycl::event sycl_direct_launch(const FunctorWrapper& functor_wrapper,
sycl::event memcpy_event) const {
// Convenience references
const Kokkos::Experimental::SYCL& space = m_policy.space();
Kokkos::Experimental::Impl::SYCLInternal& instance =
*space.impl_internal_space_instance();
sycl::queue& q = space.sycl_queue();

const auto size = m_policy.end() - m_policy.begin();

// FIXME_SYCL optimize
constexpr size_t wgroup_size = 128;
auto n_wgroups = (size + wgroup_size - 1) / wgroup_size;
auto global_mem = m_scratch_space;
pointer_type group_results = global_mem + n_wgroups * wgroup_size;

auto local_scans = q.submit([&](sycl::handler& cgh) {
auto scratch_flags = static_cast<sycl::device_ptr<unsigned int>>(
instance.scratch_flags(sizeof(unsigned int)));

// Initialize global memory
auto initialize_global_memory = q.submit([&](sycl::handler& cgh) {
auto begin = m_policy.begin();

// Store subgroup totals
const auto min_subgroup_size =
q.get_device()
@@ -140,99 +155,76 @@ class ParallelScanSYCLBase {
sycl::range<1>((wgroup_size + min_subgroup_size - 1) /
min_subgroup_size),
cgh);
sycl::local_accessor<unsigned int> num_teams_done(1, cgh);

cgh.depends_on(memcpy_event);

cgh.parallel_for(
sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size),
[=](sycl::nd_item<1> item) {
const CombinedFunctorReducer<
FunctorType, typename Analysis::Reducer>& functor_reducer =
functor_wrapper.get_functor();
const FunctorType& functor = functor_reducer.get_functor();
const typename Analysis::Reducer& reducer =
functor_reducer.get_reducer();

const auto local_id = item.get_local_linear_id();
const auto global_id = item.get_global_linear_id();
const index_type local_id = item.get_local_linear_id();
const index_type global_id = item.get_global_linear_id();

// Initialize local memory
value_type local_value;
if (global_id < size)
local_value = global_mem[global_id];
else
reducer.init(&local_value);
reducer.init(&local_value);
if (global_id < size) {
if constexpr (std::is_void<WorkTag>::value)
functor(global_id + begin, local_value, false);
else
functor(WorkTag(), global_id + begin, local_value, false);
}

workgroup_scan<>(item, reducer, local_mem, local_value,
wgroup_size);

if (n_wgroups > 1 && local_id == wgroup_size - 1)
group_results[item.get_group_linear_id()] =
local_mem[item.get_sub_group().get_group_range()[0] - 1];

// Write results to global memory
if (global_id < size) global_mem[global_id] = local_value;
});
});
q.ext_oneapi_submit_barrier(std::vector<sycl::event>{local_scans});

if (n_wgroups > 1) {
scan_internal(q, functor_wrapper, group_results, n_wgroups);
auto update_with_group_results = q.submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size),
[=](sycl::nd_item<1> item) {
const auto global_id = item.get_global_linear_id();
const CombinedFunctorReducer<FunctorType,
typename Analysis::Reducer>
functor_reducer = functor_wrapper.get_functor();
const typename Analysis::Reducer& reducer =
functor_reducer.get_reducer();
if (global_id < size)
reducer.join(&global_mem[global_id],
&group_results[item.get_group_linear_id()]);
});
});
q.ext_oneapi_submit_barrier(
std::vector<sycl::event>{update_with_group_results});
}
}

template <typename FunctorWrapper>
sycl::event sycl_direct_launch(const FunctorWrapper& functor_wrapper,
sycl::event memcpy_event) const {
// Convenience references
const Kokkos::Experimental::SYCL& space = m_policy.space();
sycl::queue& q = space.sycl_queue();

const std::size_t len = m_policy.end() - m_policy.begin();

// Initialize global memory
auto initialize_global_memory = q.submit([&](sycl::handler& cgh) {
auto global_mem = m_scratch_space;
auto begin = m_policy.begin();
if (local_id == wgroup_size - 1) {
group_results[item.get_group_linear_id()] =
local_mem[item.get_sub_group().get_group_range()[0] - 1];

cgh.depends_on(memcpy_event);
cgh.parallel_for(sycl::range<1>(len), [=](sycl::item<1> item) {
const typename Policy::index_type id =
static_cast<typename Policy::index_type>(item.get_id()) + begin;
const CombinedFunctorReducer<FunctorType, typename Analysis::Reducer>&
functor_reducer = functor_wrapper.get_functor();
const typename Analysis::Reducer& reducer =
functor_reducer.get_reducer();

value_type update{};
reducer.init(&update);
const FunctorType& functor = functor_reducer.get_functor();
if constexpr (std::is_void<WorkTag>::value)
functor(id, update, false);
else
functor(WorkTag(), id, update, false);
global_mem[id] = update;
});
sycl::atomic_ref<unsigned, sycl::memory_order::relaxed,
sycl::memory_scope::device,
sycl::access::address_space::global_space>
scratch_flags_ref(*scratch_flags);
num_teams_done[0] = ++scratch_flags_ref;
}
item.barrier(sycl::access::fence_space::global_space);
if (num_teams_done[0] == n_wgroups) {
value_type total;
reducer.init(&total);

for (unsigned int offset = 0; offset < n_wgroups;
offset += wgroup_size) {
index_type id = local_id + offset;
if (id < static_cast<index_type>(n_wgroups))
local_value = group_results[id];
else
reducer.init(&local_value);
workgroup_scan<>(item, reducer, local_mem, local_value,
std::min(n_wgroups - offset, wgroup_size));
if (id < static_cast<index_type>(n_wgroups)) {
reducer.join(&local_value, &total);
group_results[id] = local_value;
}
reducer.join(
&total,
&local_mem[item.get_sub_group().get_group_range()[0] - 1]);
if (offset + wgroup_size < n_wgroups)
item.barrier(sycl::access::fence_space::global_space);
}
}
});
});
q.ext_oneapi_submit_barrier(
std::vector<sycl::event>{initialize_global_memory});

// Perform the actual exclusive scan
scan_internal(q, functor_wrapper, m_scratch_space, len);

// Write results to global memory
auto update_global_results = q.submit([&](sycl::handler& cgh) {
@@ -241,21 +233,36 @@ class ParallelScanSYCLBase {
// The compiler failed with CL_INVALID_ARG_VALUE if using m_result_ptr
// directly.
auto result_ptr = m_result_ptr_device_accessible ? m_result_ptr : nullptr;
cgh.parallel_for(sycl::range<1>(len), [=](sycl::item<1> item) {
auto global_id = item.get_id(0);

value_type update = global_mem[global_id];
const CombinedFunctorReducer<FunctorType, typename Analysis::Reducer>&
functor_reducer = functor_wrapper.get_functor();
const FunctorType& functor = functor_reducer.get_functor();
if constexpr (std::is_void<WorkTag>::value)
functor(global_id, update, true);
else
functor(WorkTag(), global_id, update, true);
global_mem[global_id] = update;
if (global_id == len - 1 && result_ptr_device_accessible)
*result_ptr = update;
});
auto begin = m_policy.begin();

cgh.depends_on(initialize_global_memory);

cgh.parallel_for(
sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size),
[=](sycl::nd_item<1> item) {
const index_type global_id = item.get_global_linear_id();
const CombinedFunctorReducer<
FunctorType, typename Analysis::Reducer>& functor_reducer =
functor_wrapper.get_functor();
const FunctorType& functor = functor_reducer.get_functor();
const typename Analysis::Reducer& reducer =
functor_reducer.get_reducer();

if (global_id < size) {
value_type update = global_mem[global_id];

reducer.join(&update, &group_results[item.get_group_linear_id()]);

if constexpr (std::is_void<WorkTag>::value)
functor(global_id + begin, update, true);
else
functor(WorkTag(), global_id + begin, update, true);

global_mem[global_id] = update;
if (global_id == size - 1 && result_ptr_device_accessible)
*result_ptr = update;
}
});
});
q.ext_oneapi_submit_barrier(
std::vector<sycl::event>{update_global_results});
@@ -270,22 +277,13 @@ class ParallelScanSYCLBase {
auto& instance = *m_policy.space().impl_internal_space_instance();
const std::size_t len = m_policy.end() - m_policy.begin();

// Compute the total amount of memory we will need. We emulate the recursive
// structure that is used to do the actual scan. Essentially, we need to
// allocate memory for the whole range and then recursively for the reduced
// group results until only one group is left.
std::size_t total_memory = 0;
{
size_t wgroup_size = 128;
size_t n_nested_size = len;
size_t n_nested_wgroups;
do {
n_nested_wgroups = (n_nested_size + wgroup_size - 1) / wgroup_size;
n_nested_size = n_nested_wgroups;
total_memory += sizeof(value_type) * n_nested_wgroups * wgroup_size;
} while (n_nested_wgroups > 1);
total_memory += sizeof(value_type) * wgroup_size;
}
// Compute the total amount of memory we will need.
// We need to allocate memory for the whole range (rounded up to the next
// multiple of the workgroup size) and for one element per workgroup that
// will contain the sum of the previous workgroups' totals.
size_t wgroup_size = 128;
size_t n_wgroups = (len + wgroup_size - 1) / wgroup_size;
size_t total_memory = n_wgroups * (wgroup_size + 1) * sizeof(value_type);

// FIXME_SYCL consider only storing one value per block and recreate initial
// results in the end before doing the final pass
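
The sizing in the hunk above replaces the old recursive estimate with a single closed form: the scanned range rounded up to a multiple of the workgroup size, plus one element per workgroup for that group's total. A standalone sketch of the arithmetic, not part of the commit, assuming a hypothetical range length of 1000 and an 8-byte value_type:

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical inputs: len stands in for m_policy.end() - m_policy.begin()
  // and double stands in for the functor's value_type.
  const std::size_t len         = 1000;
  const std::size_t wgroup_size = 128;

  // Same formulas as in the commit: round the range up to a multiple of the
  // workgroup size and add one slot per workgroup for the group totals.
  const std::size_t n_wgroups    = (len + wgroup_size - 1) / wgroup_size;
  const std::size_t total_memory = n_wgroups * (wgroup_size + 1) * sizeof(double);

  std::printf("n_wgroups = %zu, total_memory = %zu bytes\n", n_wgroups,
              total_memory);  // prints: n_wgroups = 8, total_memory = 8256 bytes
}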

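Stepping back, the commit replaces the recursive scan_internal with a single kernel in which each workgroup scans its chunk and records a group total, the last workgroup to finish (tracked via the scratch_flags atomic counter) scans the group totals, and a final kernel joins each element with its group's offset. The following sequential emulation of that three-phase structure is a sketch only, assuming plain int values and summation as the reduction; it is not SYCL code and is not taken from the commit.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int wgroup_size = 4;  // stand-in for the 128 used in the commit
  const std::vector<int> data{3, 1, 4, 1, 5, 9, 2, 6, 5, 3};
  const int n = static_cast<int>(data.size());
  const int n_wgroups = (n + wgroup_size - 1) / wgroup_size;

  std::vector<int> scanned(data.size());
  std::vector<int> group_results(n_wgroups, 0);

  // Phase 1 (first kernel): per-workgroup exclusive scan; record each group's
  // total in group_results.
  for (int g = 0; g < n_wgroups; ++g) {
    int running = 0;
    const int end = std::min((g + 1) * wgroup_size, n);
    for (int i = g * wgroup_size; i < end; ++i) {
      scanned[i] = running;  // exclusive: value before joining data[i]
      running += data[i];    // corresponds to reducer.join
    }
    group_results[g] = running;
  }

  // Phase 2 (last workgroup to finish, same kernel): exclusive scan of the
  // group totals so each entry holds the sum of all previous groups.
  int running = 0;
  for (int g = 0; g < n_wgroups; ++g) {
    const int total = group_results[g];
    group_results[g] = running;
    running += total;
  }

  // Phase 3 (update_global_results kernel): join each element with its
  // group's offset.
  for (int i = 0; i < n; ++i) scanned[i] += group_results[i / wgroup_size];

  for (int v : scanned) std::printf("%d ", v);  // 0 3 4 8 9 14 23 25 31 36
  std::printf("\n");
}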