From e625bf013767577392c3b3e28fa1d8e46112f3fb Mon Sep 17 00:00:00 2001
From: Adam Fidel <110841220+adamfidel@users.noreply.github.com>
Date: Wed, 24 Apr 2024 08:06:06 -0500
Subject: [PATCH] Single-pass scan kernel template (#1320)

Introduce a kernel template for inclusive_scan that performs a single pass
over the input data.

Co-authored-by: Sergey Kopienko
Co-authored-by: Dmitriy Sobolev
---
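Usage sketch for reviewers (editorial addition, not part of the patch): a minimal
call of the new entry point on device USM, assuming the
kernel_param<data_per_workitem, workgroup_size> template from kt/kernel_param.h.
Queue selection, input initialization, and error handling are omitted.

    #include <oneapi/dpl/experimental/kernel_templates>
    #include <sycl/sycl.hpp>
    #include <cstddef>
    #include <functional>

    namespace kt = oneapi::dpl::experimental::kt;

    int main()
    {
        sycl::queue q;
        constexpr std::size_t n = 1 << 20;
        int* data = sycl::malloc_device<int>(n, q);
        int* result = sycl::malloc_device<int>(n, q);
        // ... fill `data` on the device ...

        // Each work-item processes 8 elements; 256 work-items per work-group.
        kt::gpu::inclusive_scan(q, data, data + n, result, std::plus<int>{},
                                kt::kernel_param<8, 256>{}).wait();

        sycl::free(result, q);
        sycl::free(data, q);
        return 0;
    }

std::plus satisfies the known-identity requirement that the implementation
static_asserts below; operators without a known identity are rejected at compile
time.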
 .../oneapi/dpl/experimental/kernel_templates  |   4 +-
 .../dpl/experimental/kt/single_pass_scan.h    | 436 ++++++++++++++++++
 .../pstl/hetero/dpcpp/parallel_backend_sycl.h |  41 +-
 test/kt/CMakeLists.txt                        |  49 ++
 test/kt/single_pass_scan.cpp                  | 230 +++++++++
 test/support/test_config.h                    |   5 +
 6 files changed, 747 insertions(+), 18 deletions(-)
 create mode 100644 include/oneapi/dpl/experimental/kt/single_pass_scan.h
 create mode 100644 test/kt/single_pass_scan.cpp

diff --git a/include/oneapi/dpl/experimental/kernel_templates b/include/oneapi/dpl/experimental/kernel_templates
index 6498adf6dd..6c7cb63428 100644
--- a/include/oneapi/dpl/experimental/kernel_templates
+++ b/include/oneapi/dpl/experimental/kernel_templates
@@ -15,7 +15,9 @@
 #include "kt/kernel_param.h"
 
 #if __has_include(<sycl/ext/intel/esimd.hpp>)
-#   include "kt/esimd_radix_sort.h"
+#   include "kt/esimd_radix_sort.h"
 #endif
 
+#include "kt/single_pass_scan.h"
+
 #endif // _ONEDPL_KERNEL_TEMPLATES

diff --git a/include/oneapi/dpl/experimental/kt/single_pass_scan.h b/include/oneapi/dpl/experimental/kt/single_pass_scan.h
new file mode 100644
index 0000000000..cafffd6493
--- /dev/null
+++ b/include/oneapi/dpl/experimental/kt/single_pass_scan.h
@@ -0,0 +1,436 @@
+// -*- C++ -*-
+//===----------------------------------------------------------------------===//
+//
+// Copyright (C) Intel Corporation
+//
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// This file incorporates work covered by the following copyright and permission
+// notice:
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _ONEDPL_KT_SINGLE_PASS_SCAN_H
+#define _ONEDPL_KT_SINGLE_PASS_SCAN_H
+
+#include "../../pstl/hetero/dpcpp/sycl_defs.h"
+#include "../../pstl/hetero/dpcpp/unseq_backend_sycl.h"
+#include "../../pstl/hetero/dpcpp/parallel_backend_sycl.h"
+#include "../../pstl/hetero/dpcpp/execution_sycl_defs.h"
+#include "../../pstl/utils.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <new>
+
+namespace oneapi::dpl::experimental::kt
+{
+
+namespace gpu
+{
+
+namespace __impl
+{
+
+template <typename... _KernelNames>
+class __lookback_init_kernel;
+
+template <typename... _KernelNames>
+class __lookback_kernel;
+
+static constexpr int SUBGROUP_SIZE = 32;
+
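+// Status-flag protocol used by the decoupled lookback (editorial summary, not in the
+// original patch): each tile owns one slot in the flag array and one slot in each of the
+// "full" and "partial" value arrays. A tile first publishes its local reduction under
+// __partial_status; once the lookback has produced the tile's exclusive prefix, it
+// publishes prefix-plus-reduction under __full_status. The first SUBGROUP_SIZE slots are
+// __oob_status padding so that a sub-group reading a window of SUBGROUP_SIZE flags that
+// extends below tile 0 never touches uninitialized memory.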
__tile_val + : oneapi::dpl::unseq_backend::__known_identity<_BinaryOp, _T>; + + // Running reduction of all of the partial results from the tiles found, as well as the full contribution from the closest tile (if any) + __running = __binary_op(__running, sycl::reduce_over_group(__subgroup, __contribution, __binary_op)); + + // If we found a full value, we can stop looking at previous tiles. Otherwise, + // keep going through tiles until we either find a full tile or we've completely + // recomputed the prefix using partial values + if (__is_full_ballot_bits) + break; + } + return __running; + } + + const uint32_t __tile_id; + _FlagStorageType* __flags_begin; + _T* __full_vals_begin; + _T* __partial_vals_begin; + _AtomicFlagT __atomic_flag; + _AtomicValueT __atomic_partial_value; + _AtomicValueT __atomic_full_value; +}; + +template +struct __lookback_init_submitter; + +template +struct __lookback_init_submitter<_FlagType, _Type, _BinaryOp, + oneapi::dpl::__par_backend_hetero::__internal::__optional_kernel_name<_Name...>> +{ + template + sycl::event + operator()(sycl::queue __q, _StatusFlags&& __status_flags, _PartialValues&& __partial_values, + std::size_t __status_flags_size, std::uint16_t __status_flag_padding) const + { + return __q.submit([&](sycl::handler& __hdl) { + __hdl.parallel_for<_Name...>(sycl::range<1>{__status_flags_size}, [=](const sycl::item<1>& __item) { + auto __id = __item.get_linear_id(); + __status_flags[__id] = + __id < __status_flag_padding ? _FlagType::__oob_status : _FlagType::__initialized_status; + __partial_values[__id] = oneapi::dpl::unseq_backend::__known_identity<_BinaryOp, _Type>; + }); + }); + } +}; + +template +struct __lookback_submitter; + +template +struct __lookback_kernel_func +{ + using _FlagStorageType = typename _FlagType::_FlagStorageType; + static constexpr std::uint32_t __elems_in_tile = __workgroup_size * __data_per_workitem; + + _InRng __in_rng; + _OutRng __out_rng; + _BinaryOp __binary_op; + std::size_t __n; + _StatusFlags __status_flags; + std::size_t __status_flags_size; + _StatusValues __status_vals_full; + _StatusValues __status_vals_partial; + std::size_t __current_num_items; + _TileVals __tile_vals; + + [[sycl::reqd_sub_group_size(SUBGROUP_SIZE)]] void + operator()(const sycl::nd_item<1>& __item) const + { + auto __group = __item.get_group(); + auto __subgroup = __item.get_sub_group(); + auto __local_id = __item.get_local_id(0); + + std::uint32_t __tile_id = 0; + + // Obtain unique ID for this work-group that will be used in decoupled lookback + if (__group.leader()) + { + sycl::atomic_ref<_FlagStorageType, sycl::memory_order::relaxed, sycl::memory_scope::device, + sycl::access::address_space::global_space> + __idx_atomic(__status_flags[__status_flags_size - 1]); + __tile_id = __idx_atomic.fetch_add(1); + } + + __tile_id = sycl::group_broadcast(__group, __tile_id, 0); + + std::size_t __current_offset = static_cast(__tile_id) * __elems_in_tile; + auto __out_begin = __out_rng.begin() + __current_offset; + + if (__current_offset >= __n) + return; + + // Global load into local + auto __wg_current_offset = (__tile_id * __elems_in_tile); + auto __wg_next_offset = ((__tile_id + 1) * __elems_in_tile); + auto __wg_local_memory_size = __elems_in_tile; + + if (__wg_next_offset > __n) + __wg_local_memory_size = __n - __wg_current_offset; + + if (__wg_next_offset <= __n) + { + _ONEDPL_PRAGMA_UNROLL + for (std::uint32_t __i = 0; __i < __data_per_workitem; ++__i) + { + __tile_vals[__local_id + __workgroup_size * __i] = + __in_rng[__wg_current_offset + 
+template <std::uint16_t __data_per_workitem, std::uint16_t __workgroup_size, typename _Type, typename _FlagType,
+          typename _InRng, typename _OutRng, typename _BinaryOp, typename _StatusFlags, typename _StatusValues,
+          typename _TileVals>
+struct __lookback_kernel_func
+{
+    using _FlagStorageType = typename _FlagType::_FlagStorageType;
+    static constexpr std::uint32_t __elems_in_tile = __workgroup_size * __data_per_workitem;
+
+    _InRng __in_rng;
+    _OutRng __out_rng;
+    _BinaryOp __binary_op;
+    std::size_t __n;
+    _StatusFlags __status_flags;
+    std::size_t __status_flags_size;
+    _StatusValues __status_vals_full;
+    _StatusValues __status_vals_partial;
+    std::size_t __current_num_items;
+    _TileVals __tile_vals;
+
+    [[sycl::reqd_sub_group_size(SUBGROUP_SIZE)]] void
+    operator()(const sycl::nd_item<1>& __item) const
+    {
+        auto __group = __item.get_group();
+        auto __subgroup = __item.get_sub_group();
+        auto __local_id = __item.get_local_id(0);
+
+        std::uint32_t __tile_id = 0;
+
+        // Obtain unique ID for this work-group that will be used in decoupled lookback
+        if (__group.leader())
+        {
+            sycl::atomic_ref<_FlagStorageType, sycl::memory_order::relaxed, sycl::memory_scope::device,
+                             sycl::access::address_space::global_space>
+                __idx_atomic(__status_flags[__status_flags_size - 1]);
+            __tile_id = __idx_atomic.fetch_add(1);
+        }
+
+        __tile_id = sycl::group_broadcast(__group, __tile_id, 0);
+
+        std::size_t __current_offset = static_cast<std::size_t>(__tile_id) * __elems_in_tile;
+        auto __out_begin = __out_rng.begin() + __current_offset;
+
+        if (__current_offset >= __n)
+            return;
+
+        // Global load into local
+        auto __wg_current_offset = (__tile_id * __elems_in_tile);
+        auto __wg_next_offset = ((__tile_id + 1) * __elems_in_tile);
+        auto __wg_local_memory_size = __elems_in_tile;
+
+        if (__wg_next_offset > __n)
+            __wg_local_memory_size = __n - __wg_current_offset;
+
+        if (__wg_next_offset <= __n)
+        {
+            _ONEDPL_PRAGMA_UNROLL
+            for (std::uint32_t __i = 0; __i < __data_per_workitem; ++__i)
+            {
+                __tile_vals[__local_id + __workgroup_size * __i] =
+                    __in_rng[__wg_current_offset + __local_id + __workgroup_size * __i];
+            }
+        }
+        else
+        {
+            _ONEDPL_PRAGMA_UNROLL
+            for (std::uint32_t __i = 0; __i < __data_per_workitem; ++__i)
+            {
+                if (__wg_current_offset + __local_id + __workgroup_size * __i < __n)
+                {
+                    __tile_vals[__local_id + __workgroup_size * __i] =
+                        __in_rng[__wg_current_offset + __local_id + __workgroup_size * __i];
+                }
+            }
+        }
+
+        auto __tile_vals_ptr = __dpl_sycl::__get_accessor_ptr(__tile_vals);
+        _Type __local_reduction =
+            sycl::joint_reduce(__group, __tile_vals_ptr, __tile_vals_ptr + __wg_local_memory_size, __binary_op);
+        _Type __prev_tile_reduction{};
+
+        // The first sub-group will query the previous tiles to find a prefix
+        if (__subgroup.get_group_id() == 0)
+        {
+            _FlagType __flag(__status_flags, __status_vals_full, __status_vals_partial, __tile_id);
+
+            if (__subgroup.get_local_id() == 0)
+            {
+                __flag.set_partial(__local_reduction);
+            }
+
+            __prev_tile_reduction = __flag.cooperative_lookback(__subgroup, __binary_op);
+
+            if (__subgroup.get_local_id() == 0)
+            {
+                __flag.set_full(__binary_op(__prev_tile_reduction, __local_reduction));
+            }
+        }
+
+        __prev_tile_reduction = sycl::group_broadcast(__group, __prev_tile_reduction, 0);
+
+        sycl::joint_inclusive_scan(__group, __tile_vals_ptr, __tile_vals_ptr + __wg_local_memory_size, __out_begin,
+                                   __binary_op, __prev_tile_reduction);
+    }
+};
+
+template <std::uint16_t __data_per_workitem, std::uint16_t __workgroup_size, typename _Type, typename _FlagType,
+          typename... _Name>
+struct __lookback_submitter<__data_per_workitem, __workgroup_size, _Type, _FlagType,
+                            oneapi::dpl::__par_backend_hetero::__internal::__optional_kernel_name<_Name...>>
+{
+    template <typename _InRng, typename _OutRng, typename _BinaryOp, typename _StatusFlags, typename _StatusValues>
+    sycl::event
+    operator()(sycl::queue __q, sycl::event __prev_event, _InRng&& __in_rng, _OutRng&& __out_rng,
+               _BinaryOp __binary_op, std::size_t __n, _StatusFlags&& __status_flags, std::size_t __status_flags_size,
+               _StatusValues&& __status_vals_full, _StatusValues&& __status_vals_partial,
+               std::size_t __current_num_items) const
+    {
+        using _LocalAccessorType = sycl::local_accessor<_Type, 1>;
+        using _KernelFunc =
+            __lookback_kernel_func<__data_per_workitem, __workgroup_size, _Type, _FlagType, std::decay_t<_InRng>,
+                                   std::decay_t<_OutRng>, std::decay_t<_BinaryOp>, std::decay_t<_StatusFlags>,
+                                   std::decay_t<_StatusValues>, std::decay_t<_LocalAccessorType>>;
+
+        static constexpr std::uint32_t __elems_in_tile = __workgroup_size * __data_per_workitem;
+
+        return __q.submit([&](sycl::handler& __hdl) {
+            auto __tile_vals = _LocalAccessorType(sycl::range<1>{__elems_in_tile}, __hdl);
+            __hdl.depends_on(__prev_event);
+
+            oneapi::dpl::__ranges::__require_access(__hdl, __in_rng, __out_rng);
+            __hdl.parallel_for<_Name...>(sycl::nd_range<1>(__current_num_items, __workgroup_size),
+                                         _KernelFunc{__in_rng, __out_rng, __binary_op, __n, __status_flags,
+                                                     __status_flags_size, __status_vals_full, __status_vals_partial,
+                                                     __current_num_items, __tile_vals});
+        });
+    }
+};
+
+template <bool _Inclusive, typename _InRange, typename _OutRange, typename _BinaryOp, typename _KernelParam>
+sycl::event
+__single_pass_scan(sycl::queue __queue, _InRange&& __in_rng, _OutRange&& __out_rng, _BinaryOp __binary_op,
+                   _KernelParam)
+{
+    using _Type = oneapi::dpl::__internal::__value_t<_InRange>;
+    using _FlagType = __scan_status_flag<_Type>;
+    using _FlagStorageType = typename _FlagType::_FlagStorageType;
+
+    using _KernelName = typename _KernelParam::kernel_name;
+    using _LookbackInitKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
+        __lookback_init_kernel<_KernelName, _Type, _BinaryOp>>;
+    using _LookbackKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
+        __lookback_kernel<_KernelName, _Type, _BinaryOp>>;
+
+    const std::size_t __n = __in_rng.size();
+
+    if (__n == 0)
+        return sycl::event{};
+
+    static_assert(_Inclusive, "Single-pass scan is only available for inclusive scan");
+    static_assert(oneapi::dpl::unseq_backend::__has_known_identity<_BinaryOp, _Type>::value,
+                  "Only binary operators with known identity values are supported");
+
+    assert("This device does not support 64-bit atomics" &&
+           (sizeof(_Type) < 8 || __queue.get_device().has(sycl::aspect::atomic64)));
+
+    // Next power of 2 greater than or equal to __n
+    auto __n_uniform = ::oneapi::dpl::__internal::__dpl_bit_ceil(__n);
+
+    // Perform a single work-group scan if the input is small
+    if (oneapi::dpl::__par_backend_hetero::__group_scan_fits_in_slm<_Type>(__queue, __n, __n_uniform))
+    {
+        return oneapi::dpl::__par_backend_hetero::__parallel_transform_scan_single_group(
+            oneapi::dpl::__internal::__device_backend_tag{},
+            oneapi::dpl::execution::__dpl::make_device_policy<_KernelName>(__queue), std::forward<_InRange>(__in_rng),
+            std::forward<_OutRange>(__out_rng), __n, oneapi::dpl::__internal::__no_op{},
+            unseq_backend::__no_init_value<_Type>{}, __binary_op, std::true_type{});
+    }
+
+    constexpr std::size_t __workgroup_size = _KernelParam::workgroup_size;
+    constexpr std::size_t __data_per_workitem = _KernelParam::data_per_workitem;
+
+    // Avoid a non-uniform n by padding up to a multiple of workgroup_size
+    std::size_t __elems_in_tile = __workgroup_size * __data_per_workitem;
+    std::size_t __num_wgs = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __elems_in_tile);
+
+    constexpr int __status_flag_padding = SUBGROUP_SIZE;
+    std::size_t __status_flags_size = __num_wgs + 1 + __status_flag_padding;
+
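+    // Scratch layout in one device allocation (editorial note): __status_flags_size flag
+    // slots (the final slot doubles as the global tile-id counter), then, re-aligned for
+    // _Type via std::align, __status_flags_size "full" value slots followed by
+    // __status_flags_size "partial" value slots; sizeof(_Type) extra bytes cover the
+    // worst-case re-alignment.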
+    std::size_t __mem_align_pad = sizeof(_Type);
+    std::size_t __status_flags_bytes = __status_flags_size * sizeof(_FlagStorageType);
+    std::size_t __status_vals_full_offset_bytes = __status_flags_size * sizeof(_Type);
+    std::size_t __status_vals_partial_offset_bytes = __status_flags_size * sizeof(_Type);
+    std::size_t __mem_bytes =
+        __status_flags_bytes + __status_vals_full_offset_bytes + __status_vals_partial_offset_bytes + __mem_align_pad;
+
+    std::byte* __device_mem = reinterpret_cast<std::byte*>(sycl::malloc_device(__mem_bytes, __queue));
+    if (!__device_mem)
+        throw std::bad_alloc();
+
+    _FlagStorageType* __status_flags = reinterpret_cast<_FlagStorageType*>(__device_mem);
+    std::size_t __remainder = __mem_bytes - __status_flags_bytes;
+    void* __vals_base_ptr = reinterpret_cast<void*>(__device_mem + __status_flags_bytes);
+    void* __vals_aligned_ptr =
+        std::align(std::alignment_of_v<_Type>, __status_vals_full_offset_bytes, __vals_base_ptr, __remainder);
+    _Type* __status_vals_full = reinterpret_cast<_Type*>(__vals_aligned_ptr);
+    _Type* __status_vals_partial =
+        reinterpret_cast<_Type*>(__status_vals_full + __status_vals_full_offset_bytes / sizeof(_Type));
+
+    auto __fill_event = __lookback_init_submitter<_FlagType, _Type, _BinaryOp, _LookbackInitKernel>{}(
+        __queue, __status_flags, __status_vals_partial, __status_flags_size, __status_flag_padding);
+
+    std::size_t __current_num_wgs = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __elems_in_tile);
+    std::size_t __current_num_items = __current_num_wgs * __workgroup_size;
+
+    auto __prev_event =
+        __lookback_submitter<__data_per_workitem, __workgroup_size, _Type, _FlagType, _LookbackKernel>{}(
+            __queue, __fill_event, __in_rng, __out_rng, __binary_op, __n, __status_flags, __status_flags_size,
+            __status_vals_full, __status_vals_partial, __current_num_items);
+
+    // TODO: Currently, the following portion of code makes this entire function synchronous.
+    // Ideally, we should be able to use the asynchronous free below, but we have found that doing
+    // so introduces a large, as-yet-unexplained slowdown. Once this slowdown has been identified
+    // and corrected, we should replace this code with the asynchronous version below.
+    if (0)
+    {
+        return __queue.submit([=](sycl::handler& __hdl) {
+            __hdl.depends_on(__prev_event);
+            __hdl.host_task([=]() { sycl::free(__device_mem, __queue); });
+        });
+    }
+    else
+    {
+        __prev_event.wait();
+        sycl::free(__device_mem, __queue);
+        return __prev_event;
+    }
+}
+
+} // namespace __impl
+
+template <typename _InRng, typename _OutRng, typename _BinaryOp, typename _KernelParam>
+sycl::event
+inclusive_scan(sycl::queue __queue, _InRng&& __in_rng, _OutRng&& __out_rng, _BinaryOp __binary_op,
+               _KernelParam __param = {})
+{
+    auto __in_view = oneapi::dpl::__ranges::views::all(std::forward<_InRng>(__in_rng));
+    auto __out_view = oneapi::dpl::__ranges::views::all(std::forward<_OutRng>(__out_rng));
+
+    return __impl::__single_pass_scan</*_Inclusive=*/true>(__queue, std::move(__in_view), std::move(__out_view),
+                                                           __binary_op, __param);
+}
+
+template <typename _InIterator, typename _OutIterator, typename _BinaryOp, typename _KernelParam>
+sycl::event
+inclusive_scan(sycl::queue __queue, _InIterator __in_begin, _InIterator __in_end, _OutIterator __out_begin,
+               _BinaryOp __binary_op, _KernelParam __param = {})
+{
+    auto __n = __in_end - __in_begin;
+
+    auto __keep1 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, _InIterator>();
+    auto __buf1 = __keep1(__in_begin, __in_end);
+    auto __keep2 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write, _OutIterator>();
+    auto __buf2 = __keep2(__out_begin, __out_begin + __n);
+
+    return __impl::__single_pass_scan</*_Inclusive=*/true>(__queue, __buf1.all_view(), __buf2.all_view(), __binary_op,
+                                                           __param);
+}
+
+} // namespace gpu
+
+} // namespace oneapi::dpl::experimental::kt
+
+#endif /* _ONEDPL_KT_SINGLE_PASS_SCAN_H */

diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
index 60f49d4e58..b006eae051 100644
--- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
+++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
@@ -691,20 +691,20 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
         return __parallel_transform_scan_static_single_group_submitter<
             _Inclusive::value, __num_elems_per_item, __wg_size,
             /* _IsFullGroup= */ true,
-            oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
-                __scan_single_wg_kernel<::std::integral_constant<::std::uint16_t, __wg_size>,
-                                        ::std::integral_constant<::std::uint16_t, __num_elems_per_item>,
-                                        /* _IsFullGroup= */ std::true_type, _Inclusive, _CustomName>>>()(
+            oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__scan_single_wg_kernel<
+                ::std::integral_constant<::std::uint16_t, __wg_size>,
+                ::std::integral_constant<::std::uint16_t, __num_elems_per_item>, _BinaryOperation,
+                /* _IsFullGroup= */ std::true_type, _Inclusive, _CustomName>>>()(
             ::std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng), std::forward<_OutRng>(__out_rng),
             __n, __init, __binary_op, __unary_op);
     else
         return __parallel_transform_scan_static_single_group_submitter<
             _Inclusive::value, __num_elems_per_item, __wg_size,
             /* _IsFullGroup= */ false,
-            oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
-                __scan_single_wg_kernel<::std::integral_constant<::std::uint16_t, __wg_size>,
-                                        ::std::integral_constant<::std::uint16_t, __num_elems_per_item>,
-                                        /* _IsFullGroup= */ ::std::false_type, _Inclusive, _CustomName>>>()(
+            oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__scan_single_wg_kernel<
+                ::std::integral_constant<::std::uint16_t, __wg_size>,
+                ::std::integral_constant<::std::uint16_t, __num_elems_per_item>, _BinaryOperation,
+                /* _IsFullGroup= */ ::std::false_type, _Inclusive, _CustomName>>>()(
             ::std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng), std::forward<_OutRng>(__out_rng),
             __n, __init, __binary_op, __unary_op);
     };
@@ -734,7 +734,7 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
     else
     {
         using _DynamicGroupScanKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
-            __par_backend_hetero::__scan_single_wg_dynamic_kernel<_CustomName>>;
+            __par_backend_hetero::__scan_single_wg_dynamic_kernel<_BinaryOperation, _CustomName>>;
 
         return __parallel_transform_scan_dynamic_single_group_submitter<_Inclusive::value, _DynamicGroupScanKernel>()(
             ::std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng), std::forward<_OutRng>(__out_rng),
@@ -759,6 +759,20 @@ __parallel_transform_scan_base(oneapi::dpl::__internal::__device_backend_tag, _E
         __binary_op, __init, __local_scan, __group_scan, __global_scan);
 }
 
+template <typename _Type>
+bool
+__group_scan_fits_in_slm(const sycl::queue& __queue, ::std::size_t __n, ::std::size_t __n_uniform)
+{
+    constexpr int __single_group_upper_limit = 16384;
+
+    // Pessimistically only use half of the memory to take into account memory used by the compiled kernel
+    const ::std::size_t __max_slm_size =
+        __queue.get_device().template get_info<sycl::info::device::local_mem_size>() / 2;
+    const auto __req_slm_size = sizeof(_Type) * __n_uniform;
+
+    return (__n <= __single_group_upper_limit && __max_slm_size >= __req_slm_size);
+}
+
 template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _UnaryOperation, typename _InitType,
           typename _BinaryOperation, typename _Inclusive>
 auto
@@ -773,17 +787,10 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backen
     if ((__n_uniform & (__n_uniform - 1)) != 0)
         __n_uniform = oneapi::dpl::__internal::__dpl_bit_floor(__n) << 1;
 
-    // Pessimistically only use half of the memory to take into account memory used by compiled kernel
-    const ::std::size_t __max_slm_size =
-        __exec.queue().get_device().template get_info<sycl::info::device::local_mem_size>() / 2;
-    const auto __req_slm_size = sizeof(_Type) * __n_uniform;
-
-    constexpr int __single_group_upper_limit = 16384;
-
     constexpr bool __can_use_group_scan = unseq_backend::__has_known_identity<_BinaryOperation, _Type>::value;
     if constexpr (__can_use_group_scan)
     {
-        if (__n <= __single_group_upper_limit && __max_slm_size >= __req_slm_size)
+        if (__group_scan_fits_in_slm<_Type>(__exec.queue(), __n, __n_uniform))
         {
             return __parallel_transform_scan_single_group(
                 __backend_tag, std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__in_rng),
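Editorial note on the extracted heuristic: __group_scan_fits_in_slm<_Type> keeps the
existing cutoffs — at most 16384 elements, and the power-of-two-padded input must fit in
half of the device's reported local memory. Illustrative arithmetic (device numbers
assumed, not from the patch): scanning n = 10,000 floats pads to __n_uniform = 16384,
requiring 16384 * 4 B = 64 KiB of local memory, so the single-work-group path is taken
only on devices reporting at least 128 KiB.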
diff --git a/test/kt/CMakeLists.txt b/test/kt/CMakeLists.txt
index 76571b2fdd..3e31210881 100644
--- a/test/kt/CMakeLists.txt
+++ b/test/kt/CMakeLists.txt
@@ -13,6 +13,7 @@
 ##===----------------------------------------------------------------------===##
 
 option(ONEDPL_TEST_ENABLE_KT_ESIMD "Enable ESIMD-based kernel template tests")
+option(ONEDPL_TEST_ENABLE_KT_SYCL "Enable SYCL-based kernel template tests")
 
 function(_generate_test _target_name _test_path)
     add_executable(${_target_name} EXCLUDE_FROM_ALL ${_test_path})
@@ -128,3 +129,51 @@ if (ONEDPL_TEST_ENABLE_KT_ESIMD)
     # Pin some cases to track them, e.g. because they fail
     _generate_esimd_sort_test("esimd_radix_sort" "256" "32" "double" "" 1000) # segfault
 endif()
+
+function(_generate_gpu_scan_test _data_per_work_item _work_group_size _type)
+
+    if ((NOT TARGET build-scan-kt-tests) AND (NOT TARGET run-scan-kt-tests))
+        add_custom_target(build-scan-kt-tests COMMENT "Build all scan kernel template tests")
+        add_custom_target(run-scan-kt-tests
+                          COMMAND "${CMAKE_CTEST_COMMAND}" -R "^run-scan-kt-tests$" --output-on-failure --no-label-summary
+                          DEPENDS build-scan-kt-tests
+                          COMMENT "Build and run all scan kernel template tests")
+    endif()
+
+    string(REPLACE "_t" "" _type_short ${_type})
+    set(_target_name "single_pass_scan_dpwi${_data_per_work_item}_wgs${_work_group_size}_${_type_short}")
+    set(_test_path "single_pass_scan.cpp")
+
+    #_generate_test_randomly(${_target_name} ${_test_path} ${_probability_permille})
+    _generate_test(${_target_name} ${_test_path})
+    if (TARGET ${_target_name})
+        add_dependencies(build-scan-kt-tests ${_target_name})
+        add_dependencies(run-scan-kt-tests ${_target_name})
+
+        target_compile_definitions(${_target_name} PRIVATE TEST_DATA_PER_WORK_ITEM=${_data_per_work_item})
+        target_compile_definitions(${_target_name} PRIVATE TEST_WORK_GROUP_SIZE=${_work_group_size})
+        target_compile_definitions(${_target_name} PRIVATE TEST_TYPE=${_type})
+    endif()
+
+endfunction()
+
+function(_generate_gpu_scan_tests)
+    set(_data_per_work_item_all "1" "2" "4" "8" "16" "32")
+    set(_work_group_size_all "64" "128" "256" "512" "1024")
+    set(_type_all "uint32_t" "int32_t" "float" "int64_t" "uint64_t" "double")
+
+    foreach (_data_per_work_item ${_data_per_work_item_all})
+        foreach (_work_group_size ${_work_group_size_all})
+            foreach (_type ${_type_all})
+                _generate_gpu_scan_test(${_data_per_work_item} ${_work_group_size} ${_type})
+            endforeach()
+        endforeach()
+    endforeach()
+
+    _generate_test("single_pass_scan" "single_pass_scan.cpp")
+    target_compile_definitions("single_pass_scan" PRIVATE TEST_DATA_PER_WORK_ITEM=8 TEST_WORK_GROUP_SIZE=256 TEST_TYPE=uint32_t)
+endfunction()
+
+if (ONEDPL_TEST_ENABLE_KT_SYCL)
+    _generate_gpu_scan_tests()
+endif()

diff --git a/test/kt/single_pass_scan.cpp b/test/kt/single_pass_scan.cpp
new file mode 100644
index 0000000000..eace470c32
--- /dev/null
+++ b/test/kt/single_pass_scan.cpp
@@ -0,0 +1,230 @@
+// -*- C++ -*-
+//===-- single_pass_scan.cpp ----------------------------------------------===//
+//
+// Copyright (C) Intel Corporation
+//
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// This file incorporates work covered by the following copyright and permission
+// notice:
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../support/test_config.h"
+
+#include <oneapi/dpl/experimental/kernel_templates>
+
+#if LOG_TEST_INFO
+#    include <iostream>
+#endif
+
+#if _ENABLE_RANGES_TESTING
+#    include <oneapi/dpl/ranges>
+#endif
+
+#include "../support/utils.h"
+#include "../support/sycl_alloc_utils.h"
+
+#include "esimd_radix_sort_utils.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <numeric>
+#include <random>
+#include <string>
+#include <vector>
+
+inline const std::vector<std::size_t> scan_sizes = {
+    1,       6,         16,      43,        256,           316,           2048,
+    5072,    8192,      14001,   1 << 14,   (1 << 14) + 1, 50000,         67543,
+    100'000, 1 << 17,   179'581, 250'000,   1 << 18,       (1 << 18) + 1, 500'000,
+    888'235, 1'000'000, 1 << 20, 10'000'000};
+
+template <typename T, typename BinOp>
+auto
+generate_scan_data(T* input, std::size_t size, std::uint32_t seed)
+{
+    // Integer numbers are generated even for floating point types in order to avoid rounding errors
+    // and to simplify the final check
+    using substitute_t = std::conditional_t<std::is_signed_v<T>, std::int64_t, std::uint64_t>;
+
+    const substitute_t start = std::is_signed_v<T> ? -10 : 0;
+    const substitute_t end = 10;
+
+    std::default_random_engine gen{seed};
+    std::uniform_int_distribution<substitute_t> dist(start, end);
+    std::generate(input, input + size, [&] { return dist(gen); });
+
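+    // Editorial note: for multiplication the input is reshaped so results stay
+    // representable: all but the first (up to) five elements become 1, zeros among those
+    // few become 2 to keep the running product non-zero, and a shuffle then spreads the
+    // non-unit values across the sequence.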
+    if constexpr (std::is_same_v<std::multiplies<T>, BinOp>)
+    {
+        std::size_t custom_item_count = size < 5 ? size : 5;
+        std::fill(input + custom_item_count, input + size, 1);
+        std::replace(input, input + custom_item_count, 0, 2);
+        std::shuffle(input, input + size, gen);
+    }
+}
+
+#if _ENABLE_RANGES_TESTING
+template <typename T, typename BinOp, typename KernelParam>
+void
+test_all_view(sycl::queue q, std::size_t size, BinOp bin_op, KernelParam param)
+{
+#    if LOG_TEST_INFO
+    std::cout << "\ttest_all_view(" << size << ") : " << TypeInfo().name<T>() << std::endl;
+#    endif
+    std::vector<T> input(size);
+    generate_scan_data<T, BinOp>(input.data(), size, 42);
+    std::vector<T> ref(input);
+    sycl::buffer<T> buf_out(input.size());
+
+    std::inclusive_scan(std::begin(ref), std::end(ref), std::begin(ref), bin_op);
+    {
+        sycl::buffer<T> buf(input.data(), input.size());
+        oneapi::dpl::experimental::ranges::all_view<T, sycl::access::mode::read> view(buf);
+        oneapi::dpl::experimental::ranges::all_view<T, sycl::access::mode::write> view_out(buf_out);
+        oneapi::dpl::experimental::kt::gpu::inclusive_scan(q, view, view_out, bin_op, param).wait();
+    }
+
+    auto acc = buf_out.get_host_access();
+
+    std::string msg = "wrong results with all_view, n: " + std::to_string(size);
+    EXPECT_EQ_RANGES(ref, acc, msg.c_str());
+}
+
+template <typename T, typename BinOp, typename KernelParam>
+void
+test_buffer(sycl::queue q, std::size_t size, BinOp bin_op, KernelParam param)
+{
+#    if LOG_TEST_INFO
+    std::cout << "\ttest_buffer(" << size << ") : " << TypeInfo().name<T>() << std::endl;
+#    endif
+    std::vector<T> input(size);
+    generate_scan_data<T, BinOp>(input.data(), size, 42);
+    std::vector<T> ref(input);
+    sycl::buffer<T> buf_out(input.size());
+
+    std::inclusive_scan(std::begin(ref), std::end(ref), std::begin(ref), bin_op);
+    {
+        sycl::buffer<T> buf(input.data(), input.size());
+        oneapi::dpl::experimental::kt::gpu::inclusive_scan(q, buf, buf_out, bin_op, param).wait();
+    }
+
+    auto acc = buf_out.get_host_access();
+
+    std::string msg = "wrong results with buffer, n: " + std::to_string(size);
+    EXPECT_EQ_RANGES(ref, acc, msg.c_str());
+}
+#endif // _ENABLE_RANGES_TESTING
+
+template <sycl::usm::alloc _alloc_type, typename T, typename BinOp, typename KernelParam>
+void
+test_usm(sycl::queue q, std::size_t size, BinOp bin_op, KernelParam param)
+{
+#if LOG_TEST_INFO
+    std::cout << "\t\ttest_usm<" << TypeInfo().name<T>() << ", " << USMAllocPresentation().name<_alloc_type>() << ">("
+              << size << ");" << std::endl;
+#endif
+    std::vector<T> expected(size);
+    generate_scan_data<T, BinOp>(expected.data(), size, 42);
+
+    TestUtils::usm_data_transfer<_alloc_type, T> dt_input(q, expected.begin(), expected.end());
+    TestUtils::usm_data_transfer<_alloc_type, T> dt_output(q, size);
+
+    std::inclusive_scan(expected.begin(), expected.end(), expected.begin(), bin_op);
+
+    oneapi::dpl::experimental::kt::gpu::inclusive_scan(q, dt_input.get_data(), dt_input.get_data() + size,
+                                                       dt_output.get_data(), bin_op, param)
+        .wait();
+
+    std::vector<T> actual(size);
+    dt_output.retrieve_data(actual.begin());
+
+    std::string msg = "wrong results with USM, n: " + std::to_string(size);
+    EXPECT_EQ_N(expected.begin(), actual.begin(), size, msg.c_str());
+}
+
+template <typename T, typename BinOp, typename KernelParam>
+void
+test_sycl_iterators(sycl::queue q, std::size_t size, BinOp bin_op, KernelParam param)
+{
+#if LOG_TEST_INFO
+    std::cout << "\t\ttest_sycl_iterators<" << TypeInfo().name<T>() << ">(" << size << ");" << std::endl;
+#endif
+    std::vector<T> input(size);
+    std::vector<T> output(size);
+    generate_scan_data<T, BinOp>(input.data(), size, 42);
+    std::vector<T> ref(input);
+    std::inclusive_scan(std::begin(ref), std::end(ref), std::begin(ref), bin_op);
+    {
+        sycl::buffer<T> buf(input.data(), input.size());
+        sycl::buffer<T> buf_out(output.data(), output.size());
+        oneapi::dpl::experimental::kt::gpu::inclusive_scan(q, oneapi::dpl::begin(buf), oneapi::dpl::end(buf),
+                                                           oneapi::dpl::begin(buf_out), bin_op, param)
+            .wait();
+    }
+
+    std::string msg = "wrong results with oneapi::dpl::begin/end, n: " + std::to_string(size);
+    EXPECT_EQ_RANGES(ref, output, msg.c_str());
+}
+
+template <typename T, typename BinOp, typename KernelParam>
+void
+test_general_cases(sycl::queue q, std::size_t size, BinOp bin_op, KernelParam param)
+{
+    test_usm<sycl::usm::alloc::shared, T>(q, size, bin_op, TestUtils::get_new_kernel_params<0>(param));
+    test_usm<sycl::usm::alloc::device, T>(q, size, bin_op, TestUtils::get_new_kernel_params<1>(param));
+    test_sycl_iterators<T>(q, size, bin_op, TestUtils::get_new_kernel_params<2>(param));
+#if _ENABLE_RANGES_TESTING
+    test_all_view<T>(q, size, bin_op, TestUtils::get_new_kernel_params<3>(param));
+    test_buffer<T>(q, size, bin_op, TestUtils::get_new_kernel_params<4>(param));
+#endif
+}
+
+template <typename T, typename KernelParam>
+void
+test_all_cases(sycl::queue q, std::size_t size, KernelParam param)
+{
+    test_general_cases<T>(q, size, std::plus<T>{}, TestUtils::get_new_kernel_params<0>(param));
+#if _PSTL_GROUP_REDUCTION_MULT_INT64_BROKEN
+    static constexpr bool int64_mult_broken = std::is_integral_v<T> && (sizeof(T) == 8);
+#else
+    static constexpr bool int64_mult_broken = false;
+#endif
+    if constexpr (!int64_mult_broken)
+    {
+        test_general_cases<T>(q, size, std::multiplies<T>{}, TestUtils::get_new_kernel_params<1>(param));
+    }
+}
+
+int
+main()
+{
+#if LOG_TEST_INFO
+    std::cout << "TEST_DATA_PER_WORK_ITEM : " << TEST_DATA_PER_WORK_ITEM << "\n"
+              << "TEST_WORK_GROUP_SIZE    : " << TEST_WORK_GROUP_SIZE << "\n"
+              << "TEST_TYPE               : " << TypeInfo().name<TEST_TYPE>() << std::endl;
+#endif
+
+    constexpr oneapi::dpl::experimental::kt::kernel_param<TEST_DATA_PER_WORK_ITEM, TEST_WORK_GROUP_SIZE> params;
+    auto q = TestUtils::get_test_queue();
+    bool run_test = can_run_test<decltype(params), TEST_TYPE>(q, params);
+
+    if (run_test)
+    {
+        try
+        {
+            for (auto size : scan_sizes)
+                test_all_cases<TEST_TYPE>(q, size, params);
+        }
+        catch (const std::exception& exc)
+        {
+            std::cerr << "Exception: " << exc.what() << std::endl;
+            return EXIT_FAILURE;
+        }
+    }
+
+    return TestUtils::done(run_test);
+}

diff --git a/test/support/test_config.h b/test/support/test_config.h
index 37c6d43340..4d211ad0b8 100644
--- a/test/support/test_config.h
+++ b/test/support/test_config.h
@@ -164,4 +164,9 @@
 #    define _PSTL_ICPX_TEST_RED_BY_SEG_BROKEN_64BIT_TYPES 1
 #endif
 
+// Group reduction produces wrong results with multiplication of 64-bit types for certain
+// driver versions.
+// TODO: When a driver fix is provided to resolve this issue, consider altering this macro
+// or checking the driver version of the underlying sycl::device at runtime to determine
+// whether to include or exclude 64-bit type tests.
+#define _PSTL_GROUP_REDUCTION_MULT_INT64_BROKEN 1
+
 #endif // _TEST_CONFIG_H
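Reviewer note (editorial addition): the new tests are gated behind the
ONEDPL_TEST_ENABLE_KT_SYCL option added above. A typical flow, with the exact
invocation varying by environment, would be:

    cmake -DONEDPL_TEST_ENABLE_KT_SYCL=ON <source-dir>
    cmake --build . --target build-scan-kt-tests
    ctest -R single_pass_scan --output-on-failure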