Compiling with auto deduction of workgroup sizes
masterleinad committed Apr 28, 2023
1 parent 3cc9915 commit bdaa12c
Showing 2 changed files with 87 additions and 53 deletions.
8 changes: 3 additions & 5 deletions core/src/SYCL/Kokkos_SYCL_Instance.cpp
@@ -280,11 +280,9 @@ sycl::device_ptr<void> SYCLInternal::scratch_flags(const std::size_t size) {

     m_scratchFlags = reinterpret_cast<size_type*>(r->data());
   }
-  m_queue->memset(m_scratchFlags, 0, m_scratchFlagsCount * sizeScratchGrain);
-  fence(*m_queue,
-        "Kokkos::Experimental::SYCLInternal::scratch_flags fence after "
-        "initializing m_scratchFlags",
-        m_instance_id);
+  auto memset_event = m_queue->memset(m_scratchFlags, 0,
+                                      m_scratchFlagsCount * sizeScratchGrain);
+  m_queue->ext_oneapi_submit_barrier(std::vector{memset_event});
 
   return m_scratchFlags;
 }
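The change above replaces a blocking host-side fence after the memset with an event dependency: later submissions on the queue wait for the memset event via a barrier, so the host thread no longer stalls. A minimal standalone sketch of the pattern, with illustrative names (zero_flags, flags, count are not from the Kokkos sources; ext_oneapi_submit_barrier is a oneAPI/DPC++ extension):

#include <sycl/sycl.hpp>
#include <cstddef>

// Sketch: order later queue submissions after an asynchronous memset
// without blocking the host thread.
void zero_flags(sycl::queue& q, unsigned int* flags, std::size_t count) {
  sycl::event memset_event =
      q.memset(flags, 0, count * sizeof(unsigned int));
  // Every command submitted to q after this barrier waits for
  // memset_event before it may begin executing.
  q.ext_oneapi_submit_barrier(std::vector{memset_event});
}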
132 changes: 84 additions & 48 deletions core/src/SYCL/Kokkos_SYCL_Parallel_Scan.hpp
@@ -124,7 +124,7 @@ class ParallelScanSYCLBase {
  private:
   template <typename FunctorWrapper>
   sycl::event sycl_direct_launch(const FunctorWrapper& functor_wrapper,
-                                 sycl::event memcpy_event) const {
+                                 sycl::event memcpy_event) {
     // Convenience references
     const Kokkos::Experimental::SYCL& space = m_policy.space();
     Kokkos::Experimental::Impl::SYCLInternal& instance =
@@ -133,43 +133,32 @@
 
     const auto size = m_policy.end() - m_policy.begin();
 
-    // FIXME_SYCL optimize
-    constexpr size_t wgroup_size = 128;
-    auto n_wgroups = (size + wgroup_size - 1) / wgroup_size;
-    auto global_mem = m_scratch_space;
-    pointer_type group_results = global_mem + n_wgroups * wgroup_size;
-
     auto scratch_flags = static_cast<sycl::device_ptr<unsigned int>>(
         instance.scratch_flags(sizeof(unsigned int)));
 
-    // Initialize global memory
-    auto initialize_global_memory = q.submit([&](sycl::handler& cgh) {
-      auto begin = m_policy.begin();
+    const auto begin = m_policy.begin();
 
-      // Store subgroup totals
-      const auto min_subgroup_size =
-          q.get_device()
-              .template get_info<sycl::info::device::sub_group_sizes>()
-              .front();
-      sycl::local_accessor<value_type> local_mem(
-          sycl::range<1>((wgroup_size + min_subgroup_size - 1) /
-                         min_subgroup_size),
-          cgh);
-      sycl::local_accessor<unsigned int> num_teams_done(1, cgh);
-
-      cgh.depends_on(memcpy_event);
-
-      cgh.parallel_for(
-          sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size),
-          [=](sycl::nd_item<1> item) {
+    // Initialize global memory
+    auto scan_lambda_factory =
+        [&](sycl::local_accessor<value_type> local_mem,
+            sycl::local_accessor<unsigned int> num_teams_done,
+            sycl::device_ptr<value_type> global_mem_,
+            sycl::device_ptr<value_type> group_results_) {
+          auto lambda = [=](sycl::nd_item<1> item) {
+            auto global_mem = global_mem_;
+            auto group_results = group_results_;
+
             const CombinedFunctorReducer<
                 FunctorType, typename Analysis::Reducer>& functor_reducer =
                 functor_wrapper.get_functor();
             const FunctorType& functor = functor_reducer.get_functor();
             const typename Analysis::Reducer& reducer =
                 functor_reducer.get_reducer();
 
-            const index_type local_id = item.get_local_linear_id();
+            const auto n_wgroups = item.get_group_range()[0];
+            const int wgroup_size = item.get_local_range()[0];
+
+            const int local_id = item.get_local_linear_id();
             const index_type global_id = item.get_global_linear_id();
 
             // Initialize local memory
@@ -210,8 +199,9 @@
               local_value = group_results[id];
             else
               reducer.init(&local_value);
-            workgroup_scan<>(item, reducer, local_mem, local_value,
-                             std::min(n_wgroups - offset, wgroup_size));
+            workgroup_scan<>(
+                item, reducer, local_mem, local_value,
+                std::min<index_type>(n_wgroups - offset, wgroup_size));
             if (id < static_cast<index_type>(n_wgroups)) {
               reducer.join(&local_value, &total);
               group_results[id] = local_value;
@@ -223,23 +213,83 @@
               item.barrier(sycl::access::fence_space::global_space);
             }
           }
-          });
+          };
+          return lambda;
+        };
+
+    size_t wgroup_size;
+    size_t n_wgroups;
+    sycl::device_ptr<value_type> global_mem;
+    sycl::device_ptr<value_type> group_results;
+
+    auto perform_work_group_scans = q.submit([&](sycl::handler& cgh) {
+      sycl::local_accessor<unsigned int> num_teams_done(1, cgh);
+
+      auto dummy_scan_lambda =
+          scan_lambda_factory({1, cgh}, num_teams_done, nullptr, nullptr);
+
+      static sycl::kernel kernel = [&] {
+        sycl::kernel_id functor_kernel_id =
+            sycl::get_kernel_id<decltype(dummy_scan_lambda)>();
+        auto kernel_bundle =
+            sycl::get_kernel_bundle<sycl::bundle_state::executable>(
+                q.get_context(), std::vector{functor_kernel_id});
+        return kernel_bundle.get_kernel(functor_kernel_id);
+      }();
+      auto multiple = kernel.get_info<sycl::info::kernel_device_specific::
+                                          preferred_work_group_size_multiple>(
+          q.get_device());
+      auto max =
+          kernel.get_info<sycl::info::kernel_device_specific::work_group_size>(
+              q.get_device());
+
+      wgroup_size = static_cast<size_t>(max / multiple) * multiple;
+      n_wgroups = (size + wgroup_size - 1) / wgroup_size;
+
+      // Compute the total amount of memory we will need.
+      // We need to allocate memory for the whole range (rounded towards the
+      // next multiple of the workgroup size) and for one element per workgroup
+      // that will contain the sum of the previous workgroups totals.
+      // FIXME_SYCL consider only storing one value per block and recreate
+      // initial results in the end before doing the final pass
+      global_mem =
+          static_cast<sycl::device_ptr<value_type>>(instance.scratch_space(
+              n_wgroups * (wgroup_size + 1) * sizeof(value_type)));
+      m_scratch_space = global_mem;
+
+      group_results = global_mem + n_wgroups * wgroup_size;
+
+      // Store subgroup totals in local space
+      const auto min_subgroup_size =
+          q.get_device()
+              .template get_info<sycl::info::device::sub_group_sizes>()
+              .front();
+      sycl::local_accessor<value_type> local_mem(
+          sycl::range<1>((wgroup_size + min_subgroup_size - 1) /
+                         min_subgroup_size),
+          cgh);
+
+      cgh.depends_on(memcpy_event);
+
+      auto scan_lambda = scan_lambda_factory(local_mem, num_teams_done,
+                                             global_mem, group_results);
+      cgh.parallel_for(sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size),
+                       scan_lambda);
     });
 
     // Write results to global memory
     auto update_global_results = q.submit([&](sycl::handler& cgh) {
-      auto global_mem = m_scratch_space;
       auto result_ptr_device_accessible = m_result_ptr_device_accessible;
       // The compiler failed with CL_INVALID_ARG_VALUE if using m_result_ptr
       // directly.
       auto result_ptr = m_result_ptr_device_accessible ? m_result_ptr : nullptr;
       auto begin = m_policy.begin();
 
-      cgh.depends_on(initialize_global_memory);
+      cgh.depends_on(perform_work_group_scans);
 
       cgh.parallel_for(
           sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size),
           [=](sycl::nd_item<1> item) {
+            auto global_mem_copy = global_mem;
             const index_type global_id = item.get_global_linear_id();
             const CombinedFunctorReducer<
                 FunctorType, typename Analysis::Reducer>& functor_reducer =
@@ -258,7 +308,7 @@
             else
               functor(WorkTag(), global_id + begin, update, true);
 
-            global_mem[global_id] = update;
+            global_mem_copy[global_id] = update;
             if (global_id == size - 1 && result_ptr_device_accessible)
               *result_ptr = update;
           }
@@ -274,21 +324,7 @@
   void impl_execute(const PostFunctor& post_functor) {
     if (m_policy.begin() == m_policy.end()) return;
 
-    auto& instance = *m_policy.space().impl_internal_space_instance();
-    const std::size_t len = m_policy.end() - m_policy.begin();
-
-    // Compute the total amount of memory we will need.
-    // We need to allocate memory for the whole range (rounded towards the next
-    // multiple of the wqorkgroup size) and for one element per workgroup that
-    // will contain the sum of the previous workgroups totals.
-    size_t wgroup_size = 128;
-    size_t n_wgroups = (len + wgroup_size - 1) / wgroup_size;
-    size_t total_memory = n_wgroups * (wgroup_size + 1) * sizeof(value_type);
-
-    // FIXME_SYCL consider only storing one value per block and recreate initial
-    // results in the end before doing the final pass
-    m_scratch_space = static_cast<sycl::device_ptr<value_type>>(
-        instance.scratch_space(total_memory));
+    auto& instance = *m_policy.space().impl_internal_space_instance();
 
     Kokkos::Experimental::Impl::SYCLInternal::IndirectKernelMem&
         indirectKernelMem = instance.get_indirect_kernel_mem();
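The heart of the commit is the workgroup-size deduction in perform_work_group_scans above: the scan kernel is retrieved from an executable kernel bundle, queried for its preferred_work_group_size_multiple and maximum work_group_size on the target device, and the launch configuration is derived from those two values instead of the previous hard-coded 128. A standalone sketch of that query pattern, under the assumption of a named kernel ScanKernel and a made-up problem size (both illustrative, not from the commit):

#include <sycl/sycl.hpp>
#include <cstddef>
#include <iostream>

class ScanKernel;  // kernel name used for both the query and the launch

int main() {
  sycl::queue q;

  // Look up the compiled kernel so we can ask device-specific questions
  // about it before choosing launch parameters.
  auto kernel_id = sycl::get_kernel_id<ScanKernel>();
  auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
      q.get_context(), std::vector{kernel_id});
  sycl::kernel kernel = bundle.get_kernel(kernel_id);

  auto multiple =
      kernel.get_info<sycl::info::kernel_device_specific::
                          preferred_work_group_size_multiple>(q.get_device());
  auto max = kernel.get_info<
      sycl::info::kernel_device_specific::work_group_size>(q.get_device());

  // Largest multiple of `multiple` not exceeding the kernel's limit,
  // mirroring wgroup_size = max / multiple * multiple in the commit.
  std::size_t wgroup_size = max / multiple * multiple;
  std::size_t size = 1000;  // problem size (illustrative)
  std::size_t n_wgroups = (size + wgroup_size - 1) / wgroup_size;

  q.parallel_for<ScanKernel>(
      sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size),
      [](sycl::nd_item<1>) { /* kernel body elided */ });
  q.wait();

  std::cout << "wgroup_size=" << wgroup_size << " n_wgroups=" << n_wgroups
            << '\n';
}

Rounding max down to a multiple of the preferred multiple keeps subgroups fully occupied while respecting the kernel's per-device limit; the static local in the commit additionally caches the bundle lookup across calls.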

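For reference, the comment in perform_work_group_scans spells out the scratch layout that the deduced sizes feed into: one slot per element, rounded up to whole workgroups, followed by one slot per workgroup holding that group's total, i.e. n_wgroups * (wgroup_size + 1) * sizeof(value_type) elements in all. A small sketch of the layout arithmetic, with a template parameter standing in for value_type and illustrative names throughout:

#include <cstddef>

// [ n_wgroups * wgroup_size element slots | n_wgroups group totals ]
template <class T>
struct ScanScratch {
  T* global_mem;     // per-element partial results
  T* group_results;  // one running total per workgroup
};

template <class T>
ScanScratch<T> carve_scratch(T* base, std::size_t size,
                             std::size_t wgroup_size) {
  // Round the range up to whole workgroups, as the commit does.
  const std::size_t n_wgroups = (size + wgroup_size - 1) / wgroup_size;
  // base must point at n_wgroups * (wgroup_size + 1) elements.
  return {base, base + n_wgroups * wgroup_size};
}

For example, carve_scratch(base, 1000, 256) rounds the range up to 4 workgroups, so the 4 group totals live immediately after the 1024 element slots.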