From cef2ad0a0925fed13c12b6c49ea7752fd42aef1d Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Fri, 8 Mar 2024 10:55:07 +0000 Subject: [PATCH 01/22] subgroup level batch_interleaved and packed bluestein algorithm working --- .clang-tidy | 1 + src/portfft/committed_descriptor_impl.hpp | 96 ++- src/portfft/common/bluestein.hpp | 78 +++ src/portfft/common/global.hpp | 9 +- src/portfft/common/host_dft.hpp | 110 ++++ src/portfft/common/subgroup.hpp | 293 +++++++++ src/portfft/common/transpose.hpp | 23 + .../dispatcher/subgroup_dispatcher.hpp | 576 ++++++++---------- .../dispatcher/workgroup_dispatcher.hpp | 15 +- src/portfft/enums.hpp | 3 + src/portfft/specialization_constant.hpp | 3 + src/portfft/utils.hpp | 33 +- test/unit_test/fft_test_utils.hpp | 9 + test/unit_test/instantiate_fft_tests.hpp | 6 + 14 files changed, 883 insertions(+), 372 deletions(-) create mode 100644 src/portfft/common/bluestein.hpp create mode 100644 src/portfft/common/host_dft.hpp diff --git a/.clang-tidy b/.clang-tidy index 0b3225d5..1fe9e338 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -14,6 +14,7 @@ Checks: > performance-*, -performance-avoid-endl, readability-*, + -readability-magic-numbers, -readability-function-cognitive-complexity, -readability-identifier-length, -readability-named-parameter, diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp index d985e816..8182e08d 100644 --- a/src/portfft/committed_descriptor_impl.hpp +++ b/src/portfft/committed_descriptor_impl.hpp @@ -148,18 +148,26 @@ class committed_descriptor_impl { std::vector transpose_kernels; std::shared_ptr factors_and_scan; detail::level level; + // The size of DFT transform which will be computed for the given dimension std::size_t length; + // The committed length for the particular dimension, will be different from length in the case of bluestein and + // radar fft algorithms + std::size_t committed_length; Idx used_sg_size; Idx num_batches_in_l2; Idx num_factors; + bool 
is_prime; dimension_struct(std::vector forward_kernels, std::vector backward_kernels, - detail::level level, std::size_t length, Idx used_sg_size) + detail::level level, std::size_t length, std::size_t committed_length, Idx used_sg_size, + bool is_prime) : forward_kernels(std::move(forward_kernels)), backward_kernels(std::move(backward_kernels)), level(level), length(length), - used_sg_size(used_sg_size) {} + committed_length(committed_length), + used_sg_size(used_sg_size), + is_prime(is_prime) {} }; std::vector dimensions; @@ -203,12 +211,12 @@ class committed_descriptor_impl { * set of kernels that need to be JIT compiled. * * @tparam SubgroupSize size of the subgroup - * @param kernel_num the consecutive number of the kernel to prepare + * @param fft_size The size of the dft transform * @return implementation to use for the dimension and a vector of tuples of: implementation to use for a kernel, * vector of kernel ids, factors */ template - std::tuple prepare_implementation(std::size_t kernel_num) { + std::tuple prepare_implementation(IdxGlobal fft_size) { PORTFFT_LOG_FUNCTION_ENTRY(); // TODO: check and support all the parameter values if constexpr (Domain != domain::COMPLEX) { @@ -217,11 +225,12 @@ class committed_descriptor_impl { std::vector ids; std::vector factors; - IdxGlobal fft_size = static_cast(params.lengths[kernel_num]); if (detail::fits_in_wi(fft_size)) { ids = detail::get_ids(); PORTFFT_LOG_TRACE("Prepared workitem impl for size: ", fft_size); - return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, factors}}}; + return {detail::level::WORKITEM, + static_cast(fft_size), + {{detail::level::WORKITEM, ids, {static_cast(fft_size)}}}}; } if (detail::fits_in_sg(fft_size, SubgroupSize)) { Idx factor_sg = detail::factorize_sg(static_cast(fft_size), SubgroupSize); @@ -232,14 +241,11 @@ class committed_descriptor_impl { factors.push_back(factor_sg); ids = detail::get_ids(); PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, 
"and factor_sg:", factor_sg); - return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}}; + return {detail::level::SUBGROUP, static_cast(fft_size), {{detail::level::SUBGROUP, ids, factors}}}; } IdxGlobal n_idx_global = detail::factorize(fft_size); if (detail::can_cast_safely(n_idx_global) && detail::can_cast_safely(fft_size / n_idx_global)) { - if (n_idx_global == 1) { - throw unsupported_configuration("FFT size ", fft_size, " : Large Prime sized FFT currently is unsupported"); - } Idx n = static_cast(n_idx_global); Idx m = static_cast(fft_size / n_idx_global); Idx factor_sg_n = detail::factorize_sg(n, SubgroupSize); @@ -265,7 +271,8 @@ class committed_descriptor_impl { ids = detail::get_ids(); PORTFFT_LOG_TRACE("Prepared workgroup impl with factor_wi_n:", factor_wi_n, " factor_sg_n:", factor_sg_n, " factor_wi_m:", factor_wi_m, " factor_sg_m:", factor_sg_m); - return {detail::level::WORKGROUP, {{detail::level::WORKGROUP, ids, factors}}}; + return { + detail::level::WORKGROUP, static_cast(fft_size), {{detail::level::WORKGROUP, ids, factors}}}; } } PORTFFT_LOG_TRACE("Preparing global impl"); @@ -288,10 +295,8 @@ class committed_descriptor_impl { num_scalars_in_local_mem(detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg, batch_interleaved_layout ? layout::BATCH_INTERLEAVED : layout::PACKED); - std::size_t store_modifiers = batch_interleaved_layout ? 
input_scalars : 0; std::size_t twiddle_scalars = 2 * static_cast(factor_size); - return (sizeof(Scalar) * (input_scalars + store_modifiers + twiddle_scalars)) < - static_cast(local_memory_size); + return (sizeof(Scalar) * (input_scalars + twiddle_scalars)) < static_cast(local_memory_size); } return false; }(); @@ -308,8 +313,15 @@ class committed_descriptor_impl { } return false; }; - detail::factorize_input(fft_size, check_and_select_target_level); - return {detail::level::GLOBAL, param_vec}; + bool encountered_large_prime = detail::factorize_input(fft_size, check_and_select_target_level); + std::cout << "encountered_large_prime = " << encountered_large_prime << std::endl; + if (encountered_large_prime) { + std::cout << "I HAVE ENCOUNTERED A LARGE PRIME, fft size = " << fft_size << std::endl; + IdxGlobal padded_size = detail::get_bluestein_padded_size(fft_size); + std::cout << "PADDED FFT SIZE = " << padded_size << std::endl; + return prepare_implementation(padded_size); + } + return {detail::level::GLOBAL, static_cast(fft_size), param_vec}; } /** @@ -463,6 +475,7 @@ class committed_descriptor_impl { const bool is_global = top_level == detail::level::GLOBAL; const bool is_final_factor = counter == (prepared_vec.size() - 1); const bool is_final_dim = dimension_num == (params.lengths.size() - 1); + const Idx factor_size = std::accumulate(factors.begin(), factors.end(), 1, std::multiplies()); const bool is_backward = compute_direction == direction::BACKWARD; if (is_multi_dim && is_global) { throw unsupported_configuration("multidimensional global transforms are not supported."); @@ -484,22 +497,20 @@ class committed_descriptor_impl { IdxGlobal backward_distance{}; if (is_global) { - length = std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies()); - - remaining_factors_prod /= length; + remaining_factors_prod /= factor_size; forward_stride = remaining_factors_prod; backward_stride = remaining_factors_prod; - forward_distance = is_final_factor ? 
length : 1; - backward_distance = is_final_factor ? length : 1; + forward_distance = is_final_factor ? factor_size : 1; + backward_distance = is_final_factor ? factor_size : 1; } else { - length = static_cast(params.lengths[dimension_num]); + Idx committed_length = static_cast(params.lengths[dimension_num]); forward_stride = static_cast(params.forward_strides[dimension_num]); backward_stride = static_cast(params.backward_strides[dimension_num]); if (is_multi_dim) { if (is_final_dim) { - forward_distance = length; - backward_distance = length; + forward_distance = committed_length; + backward_distance = committed_length; } else { forward_distance = 1; backward_distance = 1; @@ -517,14 +528,28 @@ class committed_descriptor_impl { auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), ids); - set_spec_constants(top_level, in_bundle, length, factors, detail::elementwise_multiply::NOT_APPLIED, + if (factor_size != static_cast(params.lengths[dimension_num])) { + in_bundle.template set_specialization_constant(detail::fft_algorithm::BLUESTEIN); + in_bundle.template set_specialization_constant( + static_cast(params.lengths[dimension_num])); + } else { + // TODO: This needs to change in the case of global + in_bundle.template set_specialization_constant( + detail::fft_algorithm::COOLEY_TUKEY); + in_bundle.template set_specialization_constant( + static_cast(params.lengths[dimension_num])); + } + + set_spec_constants(top_level, in_bundle, factor_size, factors, detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale, level, conjugate_on_load, conjugate_on_store, scale_factor, input_stride, output_stride, input_distance, output_distance, static_cast(counter), static_cast(prepared_vec.size())); try { PORTFFT_LOG_TRACE("Building kernel bundle with subgroup size", SubgroupSize); - result.emplace_back(sycl::build(in_bundle), factors, params.lengths[dimension_num], SubgroupSize, - PORTFFT_SGS_IN_WG, std::shared_ptr(), level); + result.emplace_back( + 
sycl::build(in_bundle), factors, + static_cast(std::accumulate(factors.begin(), factors.end(), 1, std::multiplies())), + SubgroupSize, PORTFFT_SGS_IN_WG, std::shared_ptr(), level); PORTFFT_LOG_TRACE("Kernel bundle build complete."); } catch (std::exception& e) { PORTFFT_LOG_WARNING("Build for subgroup size", SubgroupSize, "failed with message:\n", e.what()); @@ -549,7 +574,8 @@ class committed_descriptor_impl { dimension_struct build_w_spec_const(std::size_t dimension_num) { PORTFFT_LOG_FUNCTION_ENTRY(); if (std::count(supported_sg_sizes.begin(), supported_sg_sizes.end(), SubgroupSize)) { - auto [top_level, prepared_vec] = prepare_implementation(dimension_num); + auto [top_level, fft_size, prepared_vec] = + prepare_implementation(static_cast(params.lengths[dimension_num])); bool is_compatible = true; for (auto [level, ids, factors] : prepared_vec) { is_compatible = is_compatible && sycl::is_compatible(ids, dev); @@ -558,14 +584,24 @@ class committed_descriptor_impl { } } + std::cout << "FFT SIZE = " << fft_size << std::endl; + if (top_level == detail::level::SUBGROUP) { + std::cout << "PADDED FFT SIZE LEVEL is SUBGROUP " << std::endl; + } + // exit(-10); if (is_compatible) { auto forward_kernels = set_spec_constants_driver(top_level, prepared_vec, direction::FORWARD, dimension_num); auto backward_kernels = set_spec_constants_driver(top_level, prepared_vec, direction::BACKWARD, dimension_num); if (forward_kernels.has_value() && backward_kernels.has_value()) { - return {forward_kernels.value(), backward_kernels.value(), top_level, params.lengths[dimension_num], - SubgroupSize}; + return {forward_kernels.value(), + backward_kernels.value(), + top_level, + fft_size, + params.lengths[dimension_num], + SubgroupSize, + fft_size != params.lengths[dimension_num]}; } } } diff --git a/src/portfft/common/bluestein.hpp b/src/portfft/common/bluestein.hpp new file mode 100644 index 00000000..eda5c786 --- /dev/null +++ b/src/portfft/common/bluestein.hpp @@ -0,0 +1,78 @@ 
+/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Codeplay's portFFT + * + **************************************************************************/ + +#ifndef PORTFFT_COMMON_BLUESTEIN_HPP +#define PORTFFT_COMMON_BLUESTEIN_HPP + +#include "portfft/common/host_dft.hpp" +#include "portfft/defines.hpp" + +#include +#include + +namespace portfft { +namespace detail { +/** + * Utility function to get the dft transform of the chirp signal + * @tparam T Scalar Type + * @param ptr Host Pointer containing the load/store modifiers. 
+ * @param committed_size original problem size + * @param dimension_size padded size + */ +template +void populate_fft_chirp_signal(T* ptr, std::size_t committed_size, std::size_t dimension_size) { + std::cout << "committed_size = " << committed_size << " padded size = " << dimension_size << std::endl; + using complex_t = std::complex; + std::vector chirp_signal(dimension_size, 0); + std::vector chirp_fft(dimension_size, 0); + for (std::size_t i = 0; i < committed_size; i++) { + double theta = M_PI * static_cast(i * i) / static_cast(committed_size); + chirp_signal[i] = complex_t(static_cast(std::cos(theta)), static_cast(std::sin(theta))); + } + std::size_t num_zeros = dimension_size - 2 * committed_size + 1; + for (std::size_t i = 1; i < committed_size; i++) { + chirp_signal[committed_size + num_zeros + i - 1] = chirp_signal[committed_size - i]; + } + host_cooley_tukey(chirp_signal.data(), chirp_fft.data(), dimension_size); + std::memcpy(ptr, reinterpret_cast(chirp_fft.data()), 2 * dimension_size * sizeof(T)); +} + +/** + * Populates input modifiers required for bluestein + * @tparam T Scalar Type + * @param ptr Host Pointer containing the load/store modifiers. 
+ * @param committed_size committed problem length + * @param dimension_size padded dft length + */ +template +void populate_bluestein_input_modifiers(T* ptr, std::size_t committed_size, std::size_t dimension_size) { + std::cout << "committed_size = " << committed_size << " padded size = " << dimension_size << std::endl; + using complex_t = std::complex; + std::vector scratch(dimension_size, 0); + for (std::size_t i = 0; i < committed_size; i++) { + double theta = -M_PI * static_cast(i * i) / static_cast(committed_size); + scratch[i] = complex_t(static_cast(std::cos(theta)), static_cast(std::sin(theta))); + } + std::memcpy(ptr, reinterpret_cast(scratch.data()), 2 * dimension_size * sizeof(T)); +} +} // namespace detail +} // namespace portfft + +#endif diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index 727d0bed..bb225a66 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -152,13 +152,12 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc workitem_impl(input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc, batch_size, global_data, kh, static_cast(nullptr), - store_modifier_data, static_cast(nullptr), store_modifier_loc); + store_modifier_data); } else if (level == detail::level::SUBGROUP) { subgroup_impl(input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data, kh, - static_cast(nullptr), store_modifier_data, - static_cast(nullptr), store_modifier_loc); + static_cast(nullptr), store_modifier_data); } else if (level == detail::level::WORKGROUP) { workgroup_impl(input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc, @@ -317,10 +316,10 @@ std::vector compute_level( std::size_t 
local_mem_for_store_modifier = [&]() -> std::size_t { if (factor_id < total_factors - 1) { if (kd_struct.level == detail::level::WORKITEM || kd_struct.level == detail::level::WORKGROUP) { - return 1; + return 0; } if (kd_struct.level == detail::level::SUBGROUP) { - return kd_struct.local_mem_required; + return 0; } } return std::size_t(1); diff --git a/src/portfft/common/host_dft.hpp b/src/portfft/common/host_dft.hpp new file mode 100644 index 00000000..a61e361c --- /dev/null +++ b/src/portfft/common/host_dft.hpp @@ -0,0 +1,110 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Codeplay's portFFT + * + **************************************************************************/ + +#ifndef PORTFFT_COMMON_HOST_DFT_HPP +#define PORTFFT_COMMON_HOST_DFT_HPP + +#include "portfft/common/helpers.hpp" +#include "portfft/defines.hpp" +#include + +namespace portfft { +namespace detail { + +/** + * Host Naive DFT. 
Works OOP only + * @tparam T Scalar Type + * @param input input pointer + * @param output output pointer + * @param fft_size fft size + */ +template +void host_naive_dft(std::complex* input, std::complex* output, std::size_t fft_size) { + using complex_t = std::complex; + for (std::size_t i = 0; i < fft_size; i++) { + complex_t temp = complex_t(0, 0); + for (std::size_t j = 0; j < fft_size; j++) { + complex_t multiplier = + complex_t(static_cast(std::cos((-2 * M_PI * static_cast(i * j)) / static_cast(fft_size))), + static_cast(std::sin((-2 * M_PI * static_cast(i * j)) / static_cast(fft_size)))); + temp += input[j] * multiplier; + } + output[i] = temp; + } +} + +/** + * Host implementation of the cooley tukey algorithm. Handles prime values using the naive implementation + * @tparam T Scalar type for std::complex + * @param input pointer of type std::complex containing the input values + * @param output otuput pointer of type std::complex containing the output values + * @param fft_size DFT size + */ +template +void host_cooley_tukey(std::complex* input, std::complex* output, std::size_t fft_size) { + std::size_t n = detail::factorize(fft_size); + if (n == 1 || fft_size <= 8) { + host_naive_dft(input, output, fft_size); + return; + } + + std::size_t m = fft_size / n; + std::size_t scratch_size = n > m ? 
n : m; + std::vector> scratch_space(scratch_size); + std::vector> scratch_space2(scratch_size); + std::vector> output_buffer(fft_size); + + for (std::size_t i = 0; i < m; i++) { + for (std::size_t j = 0; j < n; j++) { + scratch_space[j] = input[j * m + i]; + } + host_cooley_tukey(scratch_space.data(), scratch_space2.data(), n); + for (std::size_t j = 0; j < n; j++) { + output[j * m + i] = scratch_space2[j]; + } + } + + for (std::size_t i = 0; i < n; i++) { + for (std::size_t j = 0; j < m; j++) { + double theta = -2 * M_PI * static_cast(i * j) / static_cast(n * m); + output[i * m + j] *= std::complex(static_cast(std::cos(theta)), static_cast(std::sin(theta))); + } + } + + for (std::size_t i = 0; i < n; i++) { + for (std::size_t j = 0; j < m; j++) { + scratch_space[j] = output[i * m + j]; + } + host_cooley_tukey(scratch_space.data(), scratch_space2.data(), m); + for (std::size_t j = 0; j < m; j++) { + output_buffer[i * m + j] = scratch_space2[j]; + } + } + + for (std::size_t i = 0; i < fft_size; i++) { + std::size_t j = i / n; + std::size_t k = i % n; + output[i] = output_buffer[k * m + j]; + } +} +} // namespace detail +} // namespace portfft + +#endif diff --git a/src/portfft/common/subgroup.hpp b/src/portfft/common/subgroup.hpp index 0dbdd019..825b4f9c 100644 --- a/src/portfft/common/subgroup.hpp +++ b/src/portfft/common/subgroup.hpp @@ -24,6 +24,10 @@ #include #include "helpers.hpp" +#include "portfft/common/logging.hpp" +#include "portfft/common/memory_views.hpp" +#include "portfft/common/transfers.hpp" +#include "portfft/common/transpose.hpp" #include "portfft/defines.hpp" #include "portfft/enums.hpp" #include "twiddle.hpp" @@ -307,6 +311,295 @@ void sg_calc_twiddles(Idx factor_sg, Idx factor_wi, Idx n, Idx k, T* sg_twiddles sg_twiddles[(k + factor_wi) * factor_sg + n] = twiddle.imag(); } +template +PORTFFT_INLINE void subgroup_impl_local2global_strided_copy(T* global_ptr, LocView& loc_view, + std::array strides_global, + std::array strides_local, + IdxGlobal 
offset_global, Idx offset_local, + std::array copy_strides, + detail::global_data_struct<1> global_data, + detail::transfer_direction direction) { + detail::md_view global_md_view{global_ptr, strides_global, offset_global}; + detail::md_view local_md_view{loc_view, strides_local, offset_local}; + if (direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + copy_group(global_data, global_md_view, local_md_view, copy_strides); + } else if (direction == detail::transfer_direction::LOCAL_TO_GLOBAL) { + copy_group(global_data, local_md_view, global_md_view, copy_strides); + } +} + +template +PORTFFT_INLINE void subgroup_impl_local2global_strided_copy( + T* global_ptr, T* global_imag_ptr, LocView& loc_view, std::array strides_global, + std::array strides_local, IdxGlobal offset_global, Idx local_offset, Idx local_imag_offset, + std::array copy_strides, detail::global_data_struct<1> global_data, + detail::transfer_direction direction) { + detail::md_view global_md_real_view{global_ptr, strides_global, offset_global}; + detail::md_view global_md_imag_view{global_imag_ptr, strides_global, offset_global}; + detail::md_view local_md_real_view{loc_view, strides_local, local_offset}; + detail::md_view local_md_imag_view{loc_view, strides_local, local_offset + local_imag_offset}; + if (direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + copy_group(global_data, global_md_real_view, local_md_real_view, copy_strides); + copy_group(global_data, global_md_imag_view, local_md_imag_view, copy_strides); + } else if (direction == detail::transfer_direction::LOCAL_TO_GLOBAL) { + copy_group(global_data, local_md_real_view, global_md_real_view, copy_strides); + copy_group(global_data, local_md_imag_view, global_md_imag_view, copy_strides); + } +} + +template +PORTFFT_INLINE void subgroup_impl_local_private_copy( + PtrView& ptr_view, PtrView& ptr_imag_view, T* priv, + std::array, 2> ptr_view_strides_offsets, + std::array, 2> priv_view_strides_offsets, + std::array, 2> 
ptr_imag_view_strides_offsets, + std::array, 2> priv_imag_view_strides_offsets, Idx num_elements_to_copy, + detail::global_data_struct<1> global_data, detail::transfer_direction direction) { + detail::strided_view ptr_strided_real_view{ptr_view, std::get<0>(ptr_view_strides_offsets), + std::get<1>(ptr_view_strides_offsets)}; + detail::strided_view ptr_strided_imag_view{ptr_imag_view, std::get<0>(ptr_imag_view_strides_offsets), + std::get<1>(ptr_imag_view_strides_offsets)}; + detail::strided_view priv_strided_real_view{priv, std::get<0>(priv_view_strides_offsets), + std::get<1>(priv_view_strides_offsets)}; + detail::strided_view priv_strided_imag_view{priv, std::get<0>(priv_imag_view_strides_offsets), + std::get<1>(priv_imag_view_strides_offsets)}; + if (direction == detail::transfer_direction::LOCAL_TO_PRIVATE) { + copy_wi(global_data, ptr_strided_real_view, priv_strided_real_view, num_elements_to_copy); + copy_wi(global_data, ptr_strided_imag_view, priv_strided_imag_view, num_elements_to_copy); + } else if (direction == detail::transfer_direction::PRIVATE_TO_LOCAL || + direction == detail::transfer_direction::PRIVATE_TO_GLOBAL) { + copy_wi(global_data, priv_strided_real_view, ptr_strided_real_view, num_elements_to_copy); + copy_wi(global_data, priv_strided_imag_view, ptr_strided_imag_view, num_elements_to_copy); + } +} + +template +PORTFFT_INLINE void subgroup_impl_local_private_copy( + PtrView& ptr_view, T* priv, std::array, 2> ptr_view_strides_offsets, + Idx num_elements_to_copy, detail::global_data_struct<1> global_data, detail::transfer_direction direction) { + detail::strided_view ptr_strided_view{ptr_view, std::get<0>(ptr_view_strides_offsets), + std::get<1>(ptr_view_strides_offsets)}; + if (direction == detail::transfer_direction::LOCAL_TO_PRIVATE) { + copy_wi<2>(global_data, ptr_strided_view, priv, num_elements_to_copy); + } else if (direction == detail::transfer_direction::PRIVATE_TO_LOCAL || + direction == detail::transfer_direction::PRIVATE_TO_GLOBAL) { 
+ copy_wi<2>(global_data, priv, ptr_strided_view, num_elements_to_copy); + } +} + +template +PORTFFT_INLINE void subgroup_impl_bluestein_localglobal_packed_copy( + TIn* global_ptr, TIn* global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, + IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, sycl::sub_group& sg, + complex_storage storage, detail::transfer_direction direction, detail::global_data_struct<1>& global_data) { + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + PORTFFT_UNROLL + for (Idx i = 0; i < n_ffts_in_sg; i++) { + if (direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + global2local( + global_data, global_ptr, loc_view, 2 * committed_size, + static_cast(2 * i * committed_size) + global_ptr_offset, 2 * i * fft_size + loc_offset); + } else if (direction == detail::transfer_direction::LOCAL_TO_GLOBAL) { + local2global(global_data, loc_view, global_ptr, 2 * committed_size, + 2 * i * fft_size + loc_offset, + global_ptr_offset + 2 * i * committed_size); + } + } + } else { + PORTFFT_UNROLL + for (Idx i = 0; i < n_ffts_in_sg; i++) { + if (direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + global2local( + global_data, global_ptr, loc_view, committed_size, + static_cast(i * committed_size) + global_ptr_offset, i * fft_size + loc_offset); + global2local( + global_data, global_imag_ptr, loc_view, committed_size, + static_cast(i * committed_size) + global_ptr_offset, + i * fft_size + loc_offset + local_imag_offset); + } else if (direction == detail::transfer_direction::LOCAL_TO_GLOBAL) { + local2global(global_data, loc_view, global_ptr, committed_size, + i * fft_size + loc_offset, + global_ptr_offset + i * committed_size); + local2global(global_data, loc_view, global_imag_ptr, committed_size, + i * fft_size + loc_offset + local_imag_offset, + global_ptr_offset + i * committed_size); + } + } + } + + sycl::group_barrier(sg); +} + +template +PORTFFT_INLINE void sg_dft_compute(T* priv, T* 
private_scratch, detail::elementwise_multiply apply_load_modifier, + detail::elementwise_multiply apply_store_modifier, + detail::complex_conjugate conjugate_on_load, + detail::complex_conjugate conjugate_on_store, + detail::apply_scale_factor scale_factor_applied, const T* load_modifier_data, + const T* store_modifier_data, LocView& twiddles_loc_view, T scale_factor, + IdxGlobal modifier_start_offset, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, + sycl::sub_group& sg) { + using vec2_t = sycl::vec; + vec2_t modifier_vec; + if (conjugate_on_load == detail::complex_conjugate::APPLIED) { + detail::conjugate_inplace(priv, factor_wi); + } + if (apply_load_modifier == detail::elementwise_multiply::APPLIED) { + PORTFFT_UNROLL + for (Idx j = 0; j < factor_wi; j++) { + modifier_vec = *reinterpret_cast( + &load_modifier_data[modifier_start_offset + 2 * factor_wi * id_of_wi_in_fft + 2 * j]); + detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], + priv[2 * j + 1]); + } + } + sg_dft(priv, sg, factor_wi, factor_sg, twiddles_loc_view, private_scratch); + + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + detail::conjugate_inplace(priv, factor_wi); + } + + if (apply_store_modifier == detail::elementwise_multiply::APPLIED) { + PORTFFT_UNROLL + for (Idx j = 0; j < factor_wi; j++) { + modifier_vec = *reinterpret_cast( + &store_modifier_data[modifier_start_offset + 2 * j * factor_sg + 2 * id_of_wi_in_fft]); + detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], + priv[2 * j + 1]); + } + } + + if (scale_factor_applied == detail::apply_scale_factor::APPLIED) { + PORTFFT_UNROLL + for (Idx j = 0; j < factor_wi; j++) { + priv[2 * j] *= scale_factor; + priv[2 * j + 1] *= scale_factor; + } + } +} + +template +PORTFFT_INLINE void sg_bluestein_batch_interleaved(T* priv, T* priv_scratch, LocView& loc_view, const T* load_modifier, + const T* store_modifier, LocTwiddlesView& 
twiddles_loc, + detail::complex_conjugate conjugate_on_load, + detail::complex_conjugate conjugate_on_store, + detail::apply_scale_factor scale_applied, T scale_factor, + Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, + complex_storage storage, bool wi_working, Idx local_imag_offset, + Idx max_num_batches_local_mem, Idx fft_idx_in_local, + sycl::sub_group& sg, detail::global_data_struct<1>& global_data) { + sg_dft_compute( + priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, + conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, load_modifier, + store_modifier, twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); + + PORTFFT_UNROLL + for (Idx i = 0; i < 2 * factor_wi; i++) { + priv[i] = (priv[i] / (static_cast(factor_sg * factor_wi))); + } + + if (wi_working) { + // Store back to local memory only + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + subgroup_impl_local_private_copy<2, Idx>( + loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}}, + factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + } else { + subgroup_impl_local_private_copy<2, 1, Idx>( + loc_view, loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}}, + {{{2}, {0}}}, + {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}}, + {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + } + } + + sycl::group_barrier(sg); + if (wi_working) { + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + const Idx fft_element = 2 * id_of_wi_in_fft * factor_wi; + subgroup_impl_local_private_copy<1, Idx>( + loc_view, priv, + {{{max_num_batches_local_mem}, {fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local}}}, factor_wi, + global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + } 
else { + subgroup_impl_local_private_copy<2, 1, Idx>( + loc_view, loc_view, priv, {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}}, + {{{2}, {0}}}, + {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}}, + {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + } + } + + auto conjugate_on_output = conjugate_on_store == detail::complex_conjugate::APPLIED + ? detail::complex_conjugate::NOT_APPLIED + : detail::complex_conjugate::APPLIED; + + sg_dft_compute(priv, priv_scratch, detail::elementwise_multiply::NOT_APPLIED, + detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, + conjugate_on_output, scale_applied, static_cast(nullptr), load_modifier, + twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); +} + +template +void sg_bluestein(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddlesView& loc_twiddles, const T* load_modifier, + const T* store_modifier, detail::complex_conjugate conjugate_on_load, + detail::complex_conjugate conjugate_on_store, detail::apply_scale_factor scale_applied, + T scale_factor, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, complex_storage storage, + bool wi_working, Idx loc_offset_store_view, Idx loc_offset_load_view, Idx local_imag_offset, + sycl::sub_group sg, detail::global_data_struct<1>& global_data) { + // for (Idx i = 0; i < 2 * factor_wi; i++) { + // priv[i] = 2; + // } + sg_dft_compute( + priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, + conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, load_modifier, + store_modifier, loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); + + PORTFFT_UNROLL + for (Idx i = 0; i < 2 * factor_wi; i++) { + priv[i] = (priv[i] / (static_cast(factor_sg * factor_wi))); + } + + if (wi_working) { + if (storage == 
complex_storage::INTERLEAVED_COMPLEX) { + subgroup_impl_local_private_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_offset_store_view}}}, factor_wi, + global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + } else { + detail::strided_view priv_real_view{priv, 2}; + detail::strided_view priv_imag_view{priv, 2, 1}; + detail::strided_view local_real_view{loc_view, factor_sg, loc_offset_store_view}; + detail::strided_view local_imag_view{loc_view, factor_sg, loc_offset_store_view + local_imag_offset}; + copy_wi(global_data, priv_real_view, local_real_view, factor_wi); + copy_wi(global_data, priv_imag_view, local_imag_view, factor_wi); + } + } + + sycl::group_barrier(sg); + + if (wi_working) { + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + subgroup_impl_local_private_copy<1, Idx>(loc_view, priv, {{{1}, {loc_offset_load_view}}}, factor_wi, global_data, + detail::transfer_direction::LOCAL_TO_PRIVATE); + } else { + subgroup_impl_local_private_copy<1, 1, Idx>(loc_view, loc_view, priv, {{{1}, {loc_offset_load_view}}}, + {{{2}, {0}}}, {{{1}, {loc_offset_load_view + local_imag_offset}}}, + {{{2}, {1}}}, factor_wi, global_data, + detail::transfer_direction::LOCAL_TO_PRIVATE); + } + } + + auto conjugate_on_output = conjugate_on_store == detail::complex_conjugate::APPLIED + ? 
detail::complex_conjugate::NOT_APPLIED + : detail::complex_conjugate::APPLIED; + + sg_dft_compute(priv, priv_scratch, detail::elementwise_multiply::NOT_APPLIED, + detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, + conjugate_on_output, scale_applied, static_cast(nullptr), load_modifier, + loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); +} + }; // namespace portfft #endif diff --git a/src/portfft/common/transpose.hpp b/src/portfft/common/transpose.hpp index 75775112..b496a739 100644 --- a/src/portfft/common/transpose.hpp +++ b/src/portfft/common/transpose.hpp @@ -98,6 +98,29 @@ PORTFFT_INLINE inline void generic_transpose(IdxGlobal N, IdxGlobal M, Idx tile_ } } } + +template +PORTFFT_INLINE void shuffle_transpose(T* priv, T* output, Idx lda, Idx ldb, sycl::sub_group sg) { + Idx sg_local_linear_id = static_cast(sg.get_local_linear_id()); + Idx id_of_thread_in_fft = sg_local_linear_id % lda; + Idx matrix_start_lane_id = (sg_local_linear_id - id_of_thread_in_fft) & (SubgroupSize - 1); + Idx lane_id_relative_to_start = id_of_thread_in_fft & (lda - 1); + + PORTFFT_UNROLL + for (Idx id_of_element_in_wi = 0; id_of_element_in_wi < ldb; id_of_element_in_wi++) { + Idx relative_target_lane_id = ((lane_id_relative_to_start + id_of_element_in_wi) & (ldb - 1)) * (lda / ldb) + + (lane_id_relative_to_start / ldb); + Idx target_lane_id = matrix_start_lane_id + relative_target_lane_id; + Idx store_address = (sg_local_linear_id + id_of_element_in_wi) & (ldb - 1); + Idx target_address = ((ldb - id_of_element_in_wi) + (sg_local_linear_id / (lda / ldb))) & (ldb - 1); + T& real_value = priv[2 * target_address]; + T& complex_value = priv[2 * target_address + 1]; + output[2 * store_address] = sycl::select_from_group(sg, real_value, static_cast(target_lane_id)); + output[2 * store_address + 1] = + sycl::select_from_group(sg, complex_value, static_cast(target_lane_id)); + } +} + } // namespace detail } // namespace portfft #endif diff --git 
a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index f3485de8..c02908ff 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -21,6 +21,7 @@ #ifndef PORTFFT_DISPATCHER_SUBGROUP_DISPATCHER_HPP #define PORTFFT_DISPATCHER_SUBGROUP_DISPATCHER_HPP +#include "portfft/common/bluestein.hpp" #include "portfft/common/helpers.hpp" #include "portfft/common/logging.hpp" #include "portfft/common/memory_views.hpp" @@ -79,15 +80,12 @@ IdxGlobal get_global_size_subgroup(IdxGlobal n_transforms, Idx factor_sg, Idx su * @param twiddles pointer containing twiddles * @param load_modifier_data Pointer to the load modifier data in global Memory * @param store_modifier_data Pointer to the store modifier data in global Memory - * @param loc_load_modifier Pointer to load modifier data in local memory - * @param loc_store_modifier Pointer to store modifier data in local memory */ template PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag, T* output_imag, T* loc, T* loc_twiddles, IdxGlobal n_transforms, const T* twiddles, global_data_struct<1> global_data, sycl::kernel_handler& kh, - const T* load_modifier_data = nullptr, const T* store_modifier_data = nullptr, - T* loc_load_modifier = nullptr, T* loc_store_modifier = nullptr) { + const T* load_modifier_data = nullptr, const T* store_modifier_data = nullptr) { const complex_storage storage = kh.get_specialization_constant(); const detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant(); @@ -107,6 +105,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag const IdxGlobal output_stride = kh.get_specialization_constant(); const IdxGlobal input_distance = kh.get_specialization_constant(); const IdxGlobal output_distance = kh.get_specialization_constant(); + const Idx committed_length = kh.get_specialization_constant(); + detail::fft_algorithm 
algorithm = kh.get_specialization_constant(); global_data.log_message_global(__func__, "entered", "FactorWI", factor_wi, "FactorSG", factor_sg, "n_transforms", n_transforms); @@ -143,8 +143,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag const bool is_input_batch_interleaved = input_stride == n_transforms && input_distance == 1; const bool is_output_batch_interleaved = output_stride == n_transforms && output_distance == 1; - const bool is_input_packed = input_stride == 1 && input_distance == fft_size; - const bool is_output_packed = output_stride == 1 && output_distance == fft_size; + const bool is_input_packed = input_stride == 1 && input_distance == committed_length; + const bool is_output_packed = output_stride == 1 && output_distance == committed_length; IdxGlobal id_of_fft_in_kernel; IdxGlobal n_ffts_in_kernel; @@ -158,8 +158,6 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag constexpr Idx BankLinesPerPad = 1; auto loc_view = detail::padded_view(loc, BankLinesPerPad); - auto loc_load_modifier_view = detail::padded_view(loc_load_modifier, BankLinesPerPad); - auto loc_store_modifier_view = detail::padded_view(loc_store_modifier, BankLinesPerPad); global_data.log_message_global(__func__, "loading sg twiddles from global to local memory"); global2local(global_data, twiddles, loc_twiddles, n_reals_per_fft); @@ -193,43 +191,23 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag Idx rounded_up_ffts_in_local = detail::round_up_to_multiple(num_batches_in_local_mem, n_ffts_per_sg); Idx local_imag_offset = factor_wi * factor_sg * max_num_batches_local_mem; - const bool store_directly_from_private = SubgroupSize == factor_sg && is_output_packed; - - if (multiply_on_load == detail::elementwise_multiply::APPLIED) { - global_data.log_message_global(__func__, "loading load multipliers from global to local memory"); - global2local(global_data, load_modifier_data, 
loc_load_modifier_view, - n_reals_per_fft * num_batches_in_local_mem, - i * n_reals_per_fft); - } - // TODO: Replace this with Async DMA where the hardware supports it. - if (multiply_on_store == detail::elementwise_multiply::APPLIED) { - global_data.log_message_global(__func__, "loading store multipliers from global to local memory"); - global2local(global_data, store_modifier_data, loc_store_modifier_view, - n_reals_per_fft * num_batches_in_local_mem, - i * n_reals_per_fft); - } + const bool store_directly_from_private = + SubgroupSize == factor_sg && is_output_packed && algorithm == detail::fft_algorithm::COOLEY_TUKEY; global_data.log_message_global(__func__, "loading transposed data from global to local memory"); // load / store in a transposed manner if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::md_view input_view{input, std::array{2 * n_transforms, static_cast(1)}, 2 * i}; - detail::md_view local_md_view{loc_view, std::array{2 * max_num_batches_local_mem, 1}}; - copy_group(global_data, input_view, local_md_view, - std::array{fft_size, 2 * num_batches_in_local_mem}); + subgroup_impl_local2global_strided_copy( + const_cast(input), loc_view, {2 * n_transforms, static_cast(1)}, + {2 * max_num_batches_local_mem, 1}, 2 * i, 0, {committed_length, 2 * num_batches_in_local_mem}, global_data, + detail::transfer_direction::GLOBAL_TO_LOCAL); } else { - detail::md_view input_real_view{input, std::array{n_transforms, static_cast(1)}, i}; - detail::md_view input_imag_view{input_imag, std::array{n_transforms, static_cast(1)}, i}; - detail::md_view local_real_view{loc_view, std::array{max_num_batches_local_mem, 1}}; - detail::md_view local_imag_view{loc_view, std::array{max_num_batches_local_mem, 1}, local_imag_offset}; - global_data.log_message_global(__func__, "params", max_num_batches_local_mem, fft_size, - num_batches_in_local_mem); - global_data.log_message_global(__func__, "loading transposed real data from global to local memory"); - 
copy_group(global_data, input_real_view, local_real_view, - std::array{fft_size, num_batches_in_local_mem}); - global_data.log_message_global(__func__, "loading transposed imag data from global to local memory"); - copy_group(global_data, input_imag_view, local_imag_view, - std::array{fft_size, num_batches_in_local_mem}); + subgroup_impl_local2global_strided_copy( + const_cast(input), const_cast(input_imag), loc_view, {n_transforms, static_cast(1)}, + {max_num_batches_local_mem, 1}, i, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, + global_data, detail::transfer_direction::GLOBAL_TO_LOCAL); } + sycl::group_barrier(global_data.it.get_group()); global_data.log_dump_local("data loaded to local memory:", loc_view, n_reals_per_wi * factor_sg * max_num_batches_local_mem); @@ -242,74 +220,33 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (working_inner) { global_data.log_message_global(__func__, "loading batch_interleaved data from local to private memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - Idx local_stride = max_num_batches_local_mem; const Idx fft_element = 2 * id_of_wi_in_fft * factor_wi; - const Idx local_offset = fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local; - detail::strided_view strided_local_view{loc_view, local_stride, local_offset}; - copy_wi<2>(global_data, strided_local_view, priv, factor_wi); + subgroup_impl_local_private_copy<1, Idx>( + loc_view, priv, + {{{max_num_batches_local_mem}, {fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local}}}, + factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); } else { - detail::strided_view local_real_view{loc_view, std::array{1, max_num_batches_local_mem}, - std::array{id_of_wi_in_fft * factor_wi, fft_idx_in_local}}; - detail::strided_view local_imag_view{ - loc_view, std::array{1, max_num_batches_local_mem}, - std::array{id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}; - 
detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - copy_wi(global_data, local_real_view, priv_real_view, factor_wi); - copy_wi(global_data, local_imag_view, priv_imag_view, factor_wi); + subgroup_impl_local_private_copy<2, 1, Idx>( + loc_view, loc_view, priv, + {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}}, {{{2}, {0}}}, + {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}}, + {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); } global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); } - if (multiply_on_load == detail::elementwise_multiply::APPLIED) { - // Note: if using load modifier, this data need to be stored in the transposed fashion per batch to ensure - // low latency reads from shared memory, as this will result in much lesser bank conflicts. - // Tensor shape for load modifier in local memory = num_batches_in_local_mem x FactorWI x FactorSG - // TODO: change the above mentioned layout to the following tenshor shape: num_batches_in_local_mem x - // n_ffts_in_sg x FactorWI x FactorSG - global_data.log_message_global(__func__, "multiplying load modifier data"); - if (working_inner) { - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - Idx base_offset = fft_idx_in_local * n_reals_per_fft + 2 * j * factor_sg + 2 * id_of_wi_in_fft; - multiply_complex(priv[2 * j], priv[2 * j + 1], loc_load_modifier_view[base_offset], - loc_load_modifier_view[base_offset + 1], priv[2 * j], priv[2 * j + 1]); - } - } - } - if (conjugate_on_load == detail::complex_conjugate::APPLIED) { - conjugate_inplace(priv, factor_wi); - } - sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); - if (conjugate_on_store == detail::complex_conjugate::APPLIED) { - conjugate_inplace(priv, factor_wi); - } - if (working_inner) { - global_data.log_dump_private("data in 
registers after computation:", priv, n_reals_per_wi); - } - if (multiply_on_store == detail::elementwise_multiply::APPLIED) { - // No need to store the store modifier data in a transposed fashion as data after sg_dft is already transposed - // Tensor Shape for store modifier is num_batches_in_local_memory x FactorSG x FactorWI - global_data.log_message_global(__func__, "multiplying store modifier data"); - if (working_inner) { - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - sycl::vec modifier_priv; - Idx base_offset = fft_idx_in_local * n_reals_per_fft + 2 * j * factor_sg + 2 * id_of_wi_in_fft; - // TODO: this leads to compilation error on AMD. Revert back to this once it is resolved - // modifier_priv.load(0, detail::get_local_multi_ptr(&loc_store_modifier_view[base_offset])); - modifier_priv[0] = loc_store_modifier_view[base_offset]; - modifier_priv[1] = loc_store_modifier_view[base_offset + 1]; - multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_priv[0], modifier_priv[1], priv[2 * j], - priv[2 * j + 1]); - } - } - } - if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { - PORTFFT_UNROLL - for (Idx idx = 0; idx < factor_wi; idx++) { - priv[2 * idx] *= scaling_factor; - priv[2 * idx + 1] *= scaling_factor; - } + IdxGlobal modifier_offset = + static_cast(n_reals_per_fft) * (i + static_cast(fft_idx_in_local + id_of_fft_in_sg)); + if (algorithm == detail::fft_algorithm::COOLEY_TUKEY) { + sg_dft_compute(priv, wi_private_scratch, multiply_on_load, multiply_on_store, conjugate_on_load, + conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, + loc_twiddles, scaling_factor, modifier_offset, id_of_wi_in_fft, factor_sg, + factor_wi, global_data.sg); + } else { + sg_bluestein_batch_interleaved( + priv, wi_private_scratch, loc_view, load_modifier_data, store_modifier_data, loc_twiddles, + conjugate_on_load, conjugate_on_store, apply_scale_factor, scaling_factor, id_of_wi_in_fft, factor_sg, + factor_wi, storage, 
working_inner, local_imag_offset, max_num_batches_local_mem, fft_idx_in_local, + global_data.sg, global_data); } // Async DMA can start here for the next set of load/store modifiers. if (working_inner) { @@ -321,24 +258,24 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag __func__, "storing transposed data from private to packed global memory (SubgroupSize == FactorSG)"); // Store directly from registers for fully coalesced accesses if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::strided_view output_view{ - output, static_cast(factor_sg), - (i + static_cast(fft_idx_in_local)) * static_cast(n_reals_per_fft) + - static_cast(2 * id_of_wi_in_fft)}; - copy_wi<2>(global_data, priv, output_view, factor_wi); + subgroup_impl_local_private_copy<1, IdxGlobal>( + output, priv, + {{{static_cast(factor_sg)}, + {static_cast(i + static_cast(fft_idx_in_local)) * + static_cast(2 * fft_size) + + static_cast(2 * id_of_wi_in_fft)}}}, + factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_GLOBAL); } else { - detail::strided_view output_real_view{ - output, static_cast(factor_sg), - (i + static_cast(fft_idx_in_local)) * static_cast(fft_size) + - static_cast(id_of_wi_in_fft)}; - detail::strided_view output_imag_view{ - output_imag, static_cast(factor_sg), - (i + static_cast(fft_idx_in_local)) * static_cast(fft_size) + - static_cast(id_of_wi_in_fft)}; - detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - copy_wi(global_data, priv_real_view, output_real_view, factor_wi); - copy_wi(global_data, priv_imag_view, output_imag_view, factor_wi); + subgroup_impl_local_private_copy<1, 1, IdxGlobal>( + output, output_imag, priv, + {{{static_cast(factor_sg)}, + {(i + static_cast(fft_idx_in_local)) * static_cast(fft_size) + + static_cast(id_of_wi_in_fft)}}}, + {{{2}, {0}}}, + {{{static_cast(factor_sg)}, + {(i + static_cast(fft_idx_in_local)) * static_cast(fft_size) + + 
static_cast(id_of_wi_in_fft)}}}, + {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_GLOBAL); } } } else { @@ -349,18 +286,16 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag "FactorSG or not packed output layout)"); // Store back to local memory only if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::strided_view strided_local_view{loc_view, std::array{factor_sg, max_num_batches_local_mem}, - std::array{2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}; - copy_wi<2>(global_data, priv, strided_local_view, factor_wi); + subgroup_impl_local_private_copy<2, Idx>( + loc_view, priv, + {{{factor_sg, max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}}, factor_wi, + global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); } else { - detail::strided_view local_real_view{loc_view, std::array{factor_sg, max_num_batches_local_mem}, - std::array{id_of_wi_in_fft, fft_idx_in_local}}; - detail::strided_view local_imag_view{loc_view, std::array{factor_sg, max_num_batches_local_mem}, - std::array{id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}; - detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - copy_wi(global_data, priv_real_view, local_real_view, factor_wi); - copy_wi(global_data, priv_imag_view, local_imag_view, factor_wi); + subgroup_impl_local_private_copy<2, 1, Idx>( + loc_view, loc_view, priv, + {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}}, {{{2}, {0}}}, + {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}}, + {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); } } } @@ -375,41 +310,29 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag "storing data from batch interleaved local memory to not batch interleaved " "global memory (SubgroupSize != FactorSG)"); if (storage == 
complex_storage::INTERLEAVED_COMPLEX) { - const std::array local_strides{max_num_batches_local_mem * 2, 2, 1}; - const std::array global_strides{output_stride * 2, output_distance * 2, 1}; - const std::array copy_lengths{fft_size, num_batches_in_local_mem, 2}; - detail::md_view local_md_view2{loc_view, local_strides}; - detail::md_view output_view{output, global_strides, i * output_distance * 2}; - copy_group(global_data, local_md_view2, output_view, copy_lengths); + subgroup_impl_local2global_strided_copy( + output, loc_view, {output_stride * 2, output_distance * 2, 1}, {max_num_batches_local_mem * 2, 2, 1}, + i * output_distance * 2, 0, {committed_length, num_batches_in_local_mem, 2}, global_data, + detail::transfer_direction::LOCAL_TO_GLOBAL); } else { - const std::array local_strides{max_num_batches_local_mem, 1}; - const std::array global_strides{output_stride, output_distance}; - - detail::md_view local_real_view{loc_view, local_strides}; - detail::md_view local_imag_view{loc_view, local_strides, local_imag_offset}; - detail::md_view output_real_view{output, global_strides, i * output_distance}; - detail::md_view output_imag_view{output_imag, global_strides, i * output_distance}; - std::array copy_lengths{fft_size, num_batches_in_local_mem}; - copy_group(global_data, local_real_view, output_real_view, copy_lengths); - copy_group(global_data, local_imag_view, output_imag_view, copy_lengths); + subgroup_impl_local2global_strided_copy( + output, output_imag, loc_view, {output_stride, output_distance}, {max_num_batches_local_mem, 1}, + i * output_distance, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, global_data, + detail::transfer_direction::LOCAL_TO_GLOBAL); } } else { global_data.log_message_global( __func__, "storing data from batch interleaved local memory to batch interleaved global memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::md_view local_md_view2{loc_view, std::array{2 * max_num_batches_local_mem, 1}}; 
- detail::md_view output_view{output, std::array{2 * n_transforms, static_cast(1)}, 2 * i}; - copy_group(global_data, local_md_view2, output_view, - std::array{factor_wi * factor_sg, 2 * num_batches_in_local_mem}); + subgroup_impl_local2global_strided_copy( + output, loc_view, {2 * n_transforms, static_cast(1)}, {2 * max_num_batches_local_mem, 1}, + 2 * i, 0, {committed_length, 2 * num_batches_in_local_mem}, global_data, + detail::transfer_direction::LOCAL_TO_GLOBAL); } else { - detail::md_view local_real_view{loc_view, std::array{max_num_batches_local_mem, 1}}; - detail::md_view local_imag_view{loc_view, std::array{max_num_batches_local_mem, 1}, local_imag_offset}; - detail::md_view output_real_view{output, std::array{n_transforms, static_cast(1)}, i}; - detail::md_view output_imag_view{output_imag, std::array{n_transforms, static_cast(1)}, i}; - copy_group(global_data, local_real_view, output_real_view, - std::array{factor_wi * factor_sg, num_batches_in_local_mem}); - copy_group(global_data, local_imag_view, output_imag_view, - std::array{factor_wi * factor_sg, num_batches_in_local_mem}); + subgroup_impl_local2global_strided_copy( + output, output_imag, loc_view, {n_transforms, static_cast(1)}, + {max_num_batches_local_mem, 1}, i, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, + global_data, detail::transfer_direction::LOCAL_TO_GLOBAL); } } } @@ -422,130 +345,100 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag const Idx local_offset = subgroup_id * n_io_reals_per_sg; global_data.log_message_global(__func__, "loading non-transposed data from global to local memory"); - if (is_input_packed) { - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - global2local( - global_data, input, loc_view, n_ffts_worked_on_by_sg * n_reals_per_fft, - static_cast(n_reals_per_fft) * (i - static_cast(id_of_fft_in_sg)), - subgroup_id * n_reals_per_sg); + if (algorithm == detail::fft_algorithm::COOLEY_TUKEY) { + if 
(is_input_packed) { + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + global2local( + global_data, input, loc_view, n_ffts_worked_on_by_sg * n_reals_per_fft, + static_cast(n_reals_per_fft) * (i - static_cast(id_of_fft_in_sg)), + subgroup_id * n_reals_per_sg); + } else { + global2local( + global_data, input, loc_view, n_ffts_worked_on_by_sg * fft_size, + static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), + subgroup_id * n_cplx_per_sg); + global2local( + global_data, input_imag, loc_view, n_ffts_worked_on_by_sg * fft_size, + static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), + local_imag_offset + subgroup_id * n_cplx_per_sg); + } } else { - global2local( - global_data, input, loc_view, n_ffts_worked_on_by_sg * fft_size, - static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), - subgroup_id * n_cplx_per_sg); - global2local( - global_data, input_imag, loc_view, n_ffts_worked_on_by_sg * fft_size, - static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), - local_imag_offset + subgroup_id * n_cplx_per_sg); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + global_data.log_message_global(__func__, "storing data from unpacked global memory to local"); + subgroup_impl_local2global_strided_copy( + const_cast(input), loc_view, {input_distance * 2, input_stride * 2, 1}, + {committed_length * 2, 2, 1}, input_distance * 2 * (i - static_cast(id_of_fft_in_sg)), + local_offset, {n_ffts_worked_on_by_sg, committed_length, 2}, global_data, + detail::transfer_direction::GLOBAL_TO_LOCAL); + } else { + subgroup_impl_local2global_strided_copy( + const_cast(input), const_cast(input_imag), loc_view, {input_distance, input_stride}, + {committed_length, 1}, input_distance * (i - static_cast(id_of_fft_in_sg)), local_offset, + local_imag_offset, {n_ffts_worked_on_by_sg, committed_length}, global_data, + detail::transfer_direction::GLOBAL_TO_LOCAL); + } } } else { - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - const IdxGlobal 
global_input_offset = input_distance * 2 * (i - static_cast(id_of_fft_in_sg)); - std::array global_strides{input_distance * 2, input_stride * 2, 1}; - std::array local_strides{fft_size * 2, 2, 1}; - std::array copy_indices{n_ffts_worked_on_by_sg, fft_size, 2}; - detail::md_view global_input_view{input, global_strides, global_input_offset}; - detail::md_view local_input_view{loc_view, local_strides, local_offset}; - global_data.log_message_global(__func__, "storing data from unpacked global memory to local"); - copy_group(global_data, global_input_view, local_input_view, copy_indices); - } else { - const IdxGlobal global_input_offset = input_distance * (i - static_cast(id_of_fft_in_sg)); - std::array global_strides{input_distance, input_stride}; - std::array local_strides{fft_size, 1}; - std::array copy_indices{n_ffts_worked_on_by_sg, fft_size}; - - detail::md_view global_input_real_view{input, global_strides, global_input_offset}; - detail::md_view local_input_real_view{loc_view, local_strides, local_offset}; - detail::md_view global_input_imag_view{input_imag, global_strides, global_input_offset}; - detail::md_view local_input_imag_view{loc_view, local_strides, local_offset + local_imag_offset}; - global_data.log_message_global(__func__, "storing real data from unpacked global memory to local"); - copy_group(global_data, global_input_real_view, local_input_real_view, copy_indices); - global_data.log_message_global(__func__, "storing imaginary data from unpacked global memory to local"); - copy_group(global_data, global_input_imag_view, local_input_imag_view, copy_indices); + if (is_input_packed) { + auto global_ptr_offset = storage == complex_storage::INTERLEAVED_COMPLEX + ? 2 * committed_length * (i - static_cast(id_of_fft_in_sg)) + : committed_length * (i - static_cast(id_of_fft_in_sg)); + auto loc_view_offset = storage == complex_storage::INTERLEAVED_COMPLEX + ? 
2 * factor_sg * factor_wi * subgroup_id + : factor_sg * factor_wi * subgroup_id; + auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; + subgroup_impl_bluestein_localglobal_packed_copy( + const_cast(input), const_cast(input_imag), loc_view, committed_length, factor_sg * factor_wi, + global_ptr_offset, loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, + detail::transfer_direction::GLOBAL_TO_LOCAL, global_data); } } - if (multiply_on_load == detail::elementwise_multiply::APPLIED) { - global_data.log_message_global(__func__, "loading load multipliers from global to local memory"); - global2local( - global_data, load_modifier_data, loc_load_modifier_view, n_ffts_worked_on_by_sg * n_reals_per_fft, - n_reals_per_fft * (i - id_of_fft_in_sg), subgroup_id * n_reals_per_sg); - } - if (multiply_on_store == detail::elementwise_multiply::APPLIED) { - global_data.log_message_global(__func__, "loading store multipliers from global to local memory"); - global2local( - global_data, store_modifier_data, loc_store_modifier_view, n_ffts_worked_on_by_sg * n_reals_per_fft, - n_reals_per_fft * (i - id_of_fft_in_sg), subgroup_id * n_reals_per_sg); - } - sycl::group_barrier(global_data.sg); + + // sycl::group_barrier(global_data.sg); global_data.log_dump_local("data in local memory:", loc_view, n_reals_per_fft); if (working) { global_data.log_message_global(__func__, "loading non-transposed data from local to private memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::offset_view offset_local_view{loc_view, - subgroup_id * n_reals_per_sg + subgroup_local_id * n_reals_per_wi}; - copy_wi(global_data, offset_local_view, priv, n_reals_per_wi); + subgroup_impl_local_private_copy<1, Idx>( + loc_view, priv, {{{1}, {subgroup_id * n_reals_per_sg + subgroup_local_id * n_reals_per_wi}}}, factor_wi, + global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); } else { - detail::offset_view local_real_view{loc_view, subgroup_id 
* n_cplx_per_sg + subgroup_local_id * factor_wi}; - detail::offset_view local_imag_view{ - loc_view, subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi + local_imag_offset}; - detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - copy_wi(global_data, local_real_view, priv_real_view, factor_wi); - copy_wi(global_data, local_imag_view, priv_imag_view, factor_wi); + subgroup_impl_local_private_copy<1, 1, Idx>( + loc_view, loc_view, priv, {{{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi}}}, + {{{2}, {0}}}, {{{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi + local_imag_offset}}}, + {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); } global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); } sycl::group_barrier(global_data.sg); - if (multiply_on_load == detail::elementwise_multiply::APPLIED) { - if (working) { - global_data.log_message_global(__func__, "Multiplying load modifier before sg_dft"); - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - Idx base_offset = static_cast(global_data.sg.get_group_id()) * n_ffts_per_sg + - id_of_fft_in_sg * n_reals_per_fft + 2 * j * factor_sg + 2 * id_of_wi_in_fft; - multiply_complex(priv[2 * j], priv[2 * j + 1], loc_load_modifier_view[base_offset], - loc_load_modifier_view[base_offset + 1], priv[2 * j], priv[2 * j + 1]); - } - } - } - if (conjugate_on_load == detail::complex_conjugate::APPLIED) { - conjugate_inplace(priv, factor_wi); - } - sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); - if (conjugate_on_store == detail::complex_conjugate::APPLIED) { - conjugate_inplace(priv, factor_wi); - } - if (working) { - global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi); - } - if (multiply_on_store == detail::elementwise_multiply::APPLIED) { - if (working) { - global_data.log_message_global(__func__, 
"Multiplying store modifier before sg_dft"); - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - sycl::vec modifier_priv; - Idx base_offset = static_cast(global_data.it.get_sub_group().get_group_id()) * n_ffts_per_sg + - id_of_fft_in_sg * n_reals_per_fft + 2 * j * factor_sg + 2 * id_of_wi_in_fft; - // modifier_priv.load(0, detail::get_local_multi_ptr(&loc_store_modifier_view[base_offset])); - modifier_priv[0] = loc_store_modifier_view[base_offset]; - modifier_priv[1] = loc_store_modifier_view[base_offset + 1]; - multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_priv[0], modifier_priv[1], priv[2 * j], - priv[2 * j + 1]); - } - } - } - if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - priv[2 * j] *= scaling_factor; - priv[2 * j + 1] *= scaling_factor; - } + if (algorithm == detail::fft_algorithm::COOLEY_TUKEY) { + sg_dft_compute(priv, wi_private_scratch, multiply_on_load, multiply_on_store, conjugate_on_load, + conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, + loc_twiddles, scaling_factor, + static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), + id_of_wi_in_fft, factor_sg, factor_wi, global_data.sg); + } else { + // Idx loc_view_offset = subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; + // subgroup_id * n_reals_per_sg + id_of_fft_in_sg * n_reals_per_fft + 2 * id_of_wi_in_fft; + // subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; + auto loc_offset_store_view = + storage == complex_storage::INTERLEAVED_COMPLEX + ? subgroup_id * n_reals_per_sg + id_of_fft_in_sg * n_reals_per_fft + 2 * id_of_wi_in_fft + : subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; + auto loc_offset_load_view = storage == complex_storage::INTERLEAVED_COMPLEX + ? 
subgroup_id * n_reals_per_sg + subgroup_local_id * n_reals_per_wi + : subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi; + sg_bluestein(priv, wi_private_scratch, loc_view, loc_twiddles, load_modifier_data, + store_modifier_data, conjugate_on_load, conjugate_on_store, apply_scale_factor, + scaling_factor, id_of_wi_in_fft, factor_sg, factor_wi, storage, working, + loc_offset_store_view, loc_offset_load_view, local_imag_offset, global_data.sg, + global_data); } if (working) { global_data.log_dump_private("data in registers after scaling:", priv, n_reals_per_wi); } - if (factor_sg == SubgroupSize && is_output_packed) { + if (factor_sg == SubgroupSize && is_output_packed && algorithm == detail::fft_algorithm::COOLEY_TUKEY) { // in this case we get fully coalesced memory access even without going through local memory // TODO we may want to tune maximal `FactorSG` for which we use direct stores. if (working) { @@ -553,24 +446,20 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag "storing transposed data from private to global memory (FactorSG == " "SubgroupSize) and packed layout"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::strided_view output_view{output, static_cast(factor_sg), - i * static_cast(n_reals_per_sg) + - static_cast(id_of_fft_in_sg * n_reals_per_fft) + - static_cast(id_of_wi_in_fft * 2)}; - copy_wi<2>(global_data, priv, output_view, factor_wi); + IdxGlobal output_offset = i * static_cast(n_reals_per_sg) + + static_cast(id_of_fft_in_sg * n_reals_per_fft) + + static_cast(id_of_wi_in_fft * 2); + subgroup_impl_local_private_copy<1, IdxGlobal>( + output, priv, {{{static_cast(factor_sg)}, {output_offset}}}, factor_wi, global_data, + detail::transfer_direction::PRIVATE_TO_GLOBAL); } else { - detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - detail::strided_view output_real_view{output, static_cast(factor_sg), - i * static_cast(n_cplx_per_sg) + - 
static_cast(id_of_fft_in_sg * fft_size) + - static_cast(id_of_wi_in_fft)}; - detail::strided_view output_imag_view{output_imag, static_cast(factor_sg), - i * static_cast(n_cplx_per_sg) + - static_cast(id_of_fft_in_sg * fft_size) + - static_cast(id_of_wi_in_fft)}; - copy_wi(global_data, priv_real_view, output_real_view, factor_wi); - copy_wi(global_data, priv_imag_view, output_imag_view, factor_wi); + IdxGlobal output_offset = i * static_cast(n_cplx_per_sg) + + static_cast(id_of_fft_in_sg * fft_size) + + static_cast(id_of_wi_in_fft); + subgroup_impl_local_private_copy<1, 1, IdxGlobal>( + output, output_imag, priv, {{{static_cast(factor_sg)}, {output_offset}}}, {{{2}, {0}}}, + {{{static_cast(factor_sg)}, {output_offset}}}, {{{2}, {1}}}, factor_wi, global_data, + detail::transfer_direction::PRIVATE_TO_GLOBAL); } } } else if (is_output_batch_interleaved) { @@ -597,10 +486,10 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_message_global( __func__, "storing transposed data from private to local memory (FactorSG != SubgroupSize)"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::strided_view strided_local_view{ - loc_view, factor_sg, - subgroup_id * n_reals_per_sg + id_of_fft_in_sg * n_reals_per_fft + 2 * id_of_wi_in_fft}; - copy_wi<2>(global_data, priv, strided_local_view, factor_wi); + Idx loc_view_offset = + subgroup_id * n_reals_per_sg + id_of_fft_in_sg * n_reals_per_fft + 2 * id_of_wi_in_fft; + subgroup_impl_local_private_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_view_offset}}}, factor_wi, + global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); } else { detail::strided_view priv_real_view{priv, 2}; detail::strided_view priv_imag_view{priv, 2, 1}; @@ -611,49 +500,61 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft + local_imag_offset}; copy_wi(global_data, priv_real_view, 
local_real_view, factor_wi); copy_wi(global_data, priv_imag_view, local_imag_view, factor_wi); + // Idx loc_view_offset = subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; + // subgroup_impl_local_private_copy<1, 1, Idx>( + // loc_view, loc_view, priv, {{{factor_sg}, {local_offset}}}, {{{2}, {0}}}, + // {{{factor_sg}, {loc_view_offset + local_imag_offset}}}, {{{2}, {1}}}, factor_wi, global_data, + // detail::transfer_direction::PRIVATE_TO_LOCAL); } } sycl::group_barrier(global_data.sg); global_data.log_dump_local("computed data in local memory:", loc, n_reals_per_fft); global_data.log_message_global( __func__, "storing transposed data from local to global memory (FactorSG != SubgroupSize)"); - if (is_output_packed) { - const IdxGlobal global_output_offset = n_io_reals_per_fft * (i - static_cast(id_of_fft_in_sg)); - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - local2global(global_data, loc_view, output, - n_ffts_worked_on_by_sg * n_reals_per_fft, local_offset, - global_output_offset); + if (algorithm == detail::fft_algorithm::COOLEY_TUKEY) { + if (is_output_packed) { + const IdxGlobal global_output_offset = n_io_reals_per_fft * (i - static_cast(id_of_fft_in_sg)); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + local2global(global_data, loc_view, output, + n_ffts_worked_on_by_sg * n_reals_per_fft, local_offset, + global_output_offset); + } else { + local2global( + global_data, loc_view, output, n_ffts_worked_on_by_sg * fft_size, local_offset, global_output_offset); + local2global(global_data, loc_view, output_imag, + n_ffts_worked_on_by_sg * fft_size, + local_offset + local_imag_offset, global_output_offset); + } } else { - local2global( - global_data, loc_view, output, n_ffts_worked_on_by_sg * fft_size, local_offset, global_output_offset); - local2global(global_data, loc_view, output_imag, - n_ffts_worked_on_by_sg * fft_size, - local_offset + local_imag_offset, global_output_offset); + if (storage == 
complex_storage::INTERLEAVED_COMPLEX) { + const IdxGlobal global_output_offset = + 2 * output_distance * (i - static_cast(id_of_fft_in_sg)); + global_data.log_message_global(__func__, "storing data from local to unpacked global memory"); + subgroup_impl_local2global_strided_copy( + output, loc_view, {output_distance * 2, output_stride * 2, 1}, {committed_length * 2, 2, 1}, + global_output_offset, local_offset, {n_ffts_worked_on_by_sg, fft_size, 2}, global_data, + detail::transfer_direction::LOCAL_TO_GLOBAL); + } else { + const IdxGlobal global_output_offset = output_distance * (i - static_cast(id_of_fft_in_sg)); + subgroup_impl_local2global_strided_copy( + output, output_imag, loc_view, {output_distance, output_stride}, {committed_length, 1}, + global_output_offset, local_offset, local_imag_offset, {n_ffts_worked_on_by_sg, committed_length}, + global_data, detail::transfer_direction::LOCAL_TO_GLOBAL); + } } } else { - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - const IdxGlobal global_output_offset = 2 * output_distance * (i - static_cast(id_of_fft_in_sg)); - std::array global_strides{output_distance * 2, output_stride * 2, 1}; - std::array local_strides{fft_size * 2, 2, 1}; - std::array copy_indices{n_ffts_worked_on_by_sg, fft_size, 2}; - detail::md_view global_output_view{output, global_strides, global_output_offset}; - detail::md_view local_output_view{loc_view, local_strides, local_offset}; - global_data.log_message_global(__func__, "storing data from local to unpacked global memory"); - copy_group(global_data, local_output_view, global_output_view, copy_indices); - } else { - const IdxGlobal global_output_offset = output_distance * (i - static_cast(id_of_fft_in_sg)); - std::array global_strides{output_distance, output_stride}; - std::array local_strides{fft_size, 1}; - std::array copy_indices{n_ffts_worked_on_by_sg, fft_size}; - - detail::md_view global_output_real_view{output, global_strides, global_output_offset}; - detail::md_view 
local_output_real_view{loc_view, local_strides, local_offset}; - detail::md_view global_output_imag_view{output_imag, global_strides, global_output_offset}; - detail::md_view local_output_imag_view{loc_view, local_strides, local_offset + local_imag_offset}; - global_data.log_message_global(__func__, "storing real data from local to unpacked global memory"); - copy_group(global_data, local_output_real_view, global_output_real_view, copy_indices); - global_data.log_message_global(__func__, "storing imaginary data from local to unpacked global memory"); - copy_group(global_data, local_output_imag_view, global_output_imag_view, copy_indices); + if (is_output_packed) { + auto global_ptr_offset = storage == complex_storage::INTERLEAVED_COMPLEX + ? 2 * committed_length * (i - static_cast(id_of_fft_in_sg)) + : committed_length * (i - static_cast(id_of_fft_in_sg)); + auto loc_view_offset = storage == complex_storage::INTERLEAVED_COMPLEX + ? 2 * factor_sg * factor_wi * subgroup_id + : factor_sg * factor_wi * subgroup_id; + auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; + subgroup_impl_bluestein_localglobal_packed_copy( + output, output_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, + loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, + detail::transfer_direction::LOCAL_TO_GLOBAL, global_data); } } sycl::group_barrier(global_data.sg); @@ -666,16 +567,24 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag template template struct committed_descriptor_impl::calculate_twiddles_struct::inner { - static Scalar* execute(committed_descriptor_impl& desc, dimension_struct& /*dimension_data*/, + static Scalar* execute(committed_descriptor_impl& desc, dimension_struct& dimension_data, std::vector& kernels) { PORTFFT_LOG_FUNCTION_ENTRY(); const auto& kernel_data = kernels.at(0); Idx factor_wi = kernel_data.factors[0]; Idx factor_sg = kernel_data.factors[1]; + std::size_t 
twiddles_alloc_size = [&]() { + if (dimension_data.is_prime) { + std::cout << "DIMENSION IS INDEED PRIME " << std::endl; + // sg twiddles + load_modifiers + store_modifiers + return 6 * dimension_data.length; + } + return 2 * dimension_data.length; + }(); PORTFFT_LOG_TRACE("Allocating global memory for twiddles for subgroup implementation. Allocation size", kernel_data.length * 2); Scalar* res = sycl::aligned_alloc_device( - alignof(sycl::vec), kernel_data.length * 2, desc.queue); + alignof(sycl::vec), twiddles_alloc_size, desc.queue); sycl::range<2> kernel_range({static_cast(factor_sg), static_cast(factor_wi)}); desc.queue.submit([&](sycl::handler& cgh) { PORTFFT_LOG_TRACE("Launching twiddle calculation kernel for subgroup implementation with global size", factor_sg, @@ -686,6 +595,15 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn sg_calc_twiddles(factor_sg, factor_wi, n, k, res); }); }); + if (dimension_data.is_prime) { + std::vector bluestein_twiddles_host_ptr(4 * dimension_data.length, 0); + detail::populate_bluestein_input_modifiers(bluestein_twiddles_host_ptr.data(), dimension_data.committed_length, + dimension_data.length); + detail::populate_fft_chirp_signal(bluestein_twiddles_host_ptr.data() + 2 * dimension_data.length, + dimension_data.committed_length, dimension_data.length); + desc.queue.copy(bluestein_twiddles_host_ptr.data(), res + 2 * dimension_data.length, 4 * dimension_data.length) + .wait(); + } desc.queue.wait(); // waiting once here can be better than depending on the event // for all future calls to compute return res; @@ -723,6 +641,7 @@ struct committed_descriptor_impl::run_kernel_struct loc(local_elements, cgh); sycl::local_accessor loc_twiddles(twiddle_elements, cgh); + auto fft_size = dimension_data.length; #ifdef PORTFFT_KERNEL_LOG sycl::stream s{1024 * 16 * 16, 1024 * 8, cgh}; #endif @@ -743,10 +662,19 @@ struct committed_descriptor_impl::run_kernel_struct(&in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + 
output_offset, - &in_imag_acc_or_usm[0] + input_offset, - &out_imag_acc_or_usm[0] + output_offset, &loc[0], &loc_twiddles[0], - n_transforms, twiddles, global_data, kh); + detail::fft_algorithm algorithm = kh.get_specialization_constant(); + if (algorithm == detail::fft_algorithm::COOLEY_TUKEY) { + detail::subgroup_impl(&in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, + &in_imag_acc_or_usm[0] + input_offset, + &out_imag_acc_or_usm[0] + output_offset, &loc[0], &loc_twiddles[0], + n_transforms, twiddles, global_data, kh); + } else { + detail::subgroup_impl(&in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, + &in_imag_acc_or_usm[0] + input_offset, + &out_imag_acc_or_usm[0] + output_offset, &loc[0], &loc_twiddles[0], + n_transforms, twiddles, global_data, kh, twiddles + 2 * fft_size, + twiddles + 4 * fft_size); + } global_data.log_message_global("Exiting subgroup kernel"); }); }); diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index dbbca454..b4cdbfe4 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -21,6 +21,7 @@ #ifndef PORTFFT_DISPATCHER_WORKGROUP_DISPATCHER_HPP #define PORTFFT_DISPATCHER_WORKGROUP_DISPATCHER_HPP +#include "portfft/common/bluestein.hpp" #include "portfft/common/helpers.hpp" #include "portfft/common/logging.hpp" #include "portfft/common/memory_views.hpp" @@ -367,7 +368,7 @@ struct committed_descriptor_impl::num_scalars_in_local_mem_struc template template struct committed_descriptor_impl::calculate_twiddles_struct::inner { - static Scalar* execute(committed_descriptor_impl& desc, dimension_struct& /*dimension_data*/, + static Scalar* execute(committed_descriptor_impl& desc, dimension_struct& dimension_data, std::vector& kernels) { PORTFFT_LOG_FUNCTION_ENTRY(); const auto& kernel_data = kernels.at(0); @@ -375,14 +376,16 @@ struct 
committed_descriptor_impl::calculate_twiddles_struct::inn Idx factor_sg_n = kernel_data.factors[1]; Idx factor_wi_m = kernel_data.factors[2]; Idx factor_sg_m = kernel_data.factors[3]; - Idx fft_size = static_cast(kernel_data.length); + Idx fft_size = static_cast(dimension_data.length); Idx n = factor_wi_n * factor_sg_n; Idx m = factor_wi_m * factor_sg_m; - Idx res_size = 2 * (m + n + fft_size); + std::size_t res_size = 2 * static_cast((m + n + fft_size)); + if (dimension_data.is_prime) { + res_size += 4 * dimension_data.length; + } PORTFFT_LOG_TRACE("Allocating global memory for twiddles for workgroup implementation. Allocation size", res_size); - Scalar* res = - sycl::aligned_alloc_device(alignof(sycl::vec), - static_cast(res_size), desc.queue); + Scalar* res = sycl::aligned_alloc_device( + alignof(sycl::vec), res_size, desc.queue); desc.queue.submit([&](sycl::handler& cgh) { PORTFFT_LOG_TRACE( "Launching twiddle calculation kernel for factor 1 of workgroup implementation with global size", factor_sg_n, diff --git a/src/portfft/enums.hpp b/src/portfft/enums.hpp index 19dd2019..07c87bf0 100644 --- a/src/portfft/enums.hpp +++ b/src/portfft/enums.hpp @@ -83,6 +83,9 @@ enum class elementwise_multiply { NOT_APPLIED, APPLIED }; enum class apply_scale_factor { NOT_APPLIED, APPLIED }; enum class complex_conjugate { NOT_APPLIED, APPLIED }; + +enum class fft_algorithm { COOLEY_TUKEY, BLUESTEIN }; + } // namespace detail } // namespace portfft diff --git a/src/portfft/specialization_constant.hpp b/src/portfft/specialization_constant.hpp index 713ff358..d4452558 100644 --- a/src/portfft/specialization_constant.hpp +++ b/src/portfft/specialization_constant.hpp @@ -56,5 +56,8 @@ constexpr static sycl::specialization_id SpecConstCon constexpr static sycl::specialization_id SpecConstScaleFactorFloat{}; constexpr static sycl::specialization_id SpecConstScaleFactorDouble{}; +constexpr static sycl::specialization_id SpecConstFFTAlgorithm{}; +constexpr static sycl::specialization_id 
SpecConstCommittedLength{}; + } // namespace portfft::detail #endif diff --git a/src/portfft/utils.hpp b/src/portfft/utils.hpp index db837e3e..68f8ab4b 100644 --- a/src/portfft/utils.hpp +++ b/src/portfft/utils.hpp @@ -89,24 +89,27 @@ constexpr bool can_cast_safely(const InputType& x) { * The function should accept factor size and whether it would be have a BATCH_INTERLEAVED layout or not as an input, * and should return a boolean indicating whether or not the factor size can fit in any of the implementation. * @param transposed whether or not the factor will be computed in a BATCH_INTERLEAVED format - * @return + * @return Optionally returns the factor that fits in one of the existing implementations, std::nullopt otherwise */ template -IdxGlobal factorize_input_impl(IdxGlobal factor_size, F&& check_and_select_target_level, bool transposed) { +std::optional factorize_input_impl(IdxGlobal factor_size, F&& check_and_select_target_level, + bool transposed) { PORTFFT_LOG_FUNCTION_ENTRY(); IdxGlobal fact_1 = factor_size; if (check_and_select_target_level(fact_1, transposed)) { + std::cout << fact_1 << std::endl; return fact_1; } if ((detail::factorize(fact_1) == 1)) { - throw unsupported_configuration("Large prime sized factors are not supported at the moment"); + return std::nullopt; } do { fact_1 = detail::factorize(fact_1); if (fact_1 == 1) { - throw internal_error("Factorization Failed !"); + return std::nullopt; } } while (!check_and_select_target_level(fact_1)); + std::cout << fact_1 << std::endl; return fact_1; } @@ -118,17 +121,24 @@ IdxGlobal factorize_input_impl(IdxGlobal factor_size, F&& check_and_select_targe * implementations. The function should accept factor size and whether it would be have a BATCH_INTERLEAVED layout or * not as an input, and should return a boolean indicating whether or not the factor size can fit in any of the * implementation. 
+ * @return Whether or not a prime sized that does not fit in workitem implementation was encountered */ template -void factorize_input(IdxGlobal input_size, F&& check_and_select_target_level) { +bool factorize_input(IdxGlobal input_size, F&& check_and_select_target_level) { PORTFFT_LOG_FUNCTION_ENTRY(); if (detail::factorize(input_size) == 1) { - throw unsupported_configuration("Large Prime sized FFTs are currently not supported"); + return true; } IdxGlobal temp = 1; while (input_size / temp != 1) { - temp *= factorize_input_impl(input_size / temp, check_and_select_target_level, true); + auto factor_size = factorize_input_impl(input_size / temp, check_and_select_target_level, true); + if (factor_size.has_value()) { + temp *= factor_size.value(); + } else { + return true; + } } + return false; } /** @@ -245,6 +255,15 @@ detail::layout get_layout(const Descriptor& desc, direction dir) { return detail::layout::UNPACKED; } +/** + * Gets the appropriate padded size for the Bluestein algorithm + * @param input_size The committed length of the dft transform + * @return The padded input size for which the FFT transform will run + */ +inline IdxGlobal get_bluestein_padded_size(IdxGlobal input_size) { + return static_cast(std::pow(2, ceil(log(static_cast(2 * input_size)) / log(2.0)))); +} + } // namespace detail } // namespace portfft #endif diff --git a/test/unit_test/fft_test_utils.hpp b/test/unit_test/fft_test_utils.hpp index 90941d68..bc84694c 100644 --- a/test/unit_test/fft_test_utils.hpp +++ b/test/unit_test/fft_test_utils.hpp @@ -273,6 +273,7 @@ std::enable_if_t check_fft( const std::vector& host_reference_output, const std::vector& host_input_imag, std::vector& host_output_imag, const std::vector& host_reference_output_imag, double tolerance) { + std::cout << "I AM IN CHECK FFT USM " << std::endl; auto committed_descriptor = desc.commit(queue); const bool is_oop = desc.placement == placement::OUT_OF_PLACE; @@ -338,6 +339,14 @@ std::enable_if_t check_fft( 
host_output_imag.size(), {fft_event}); } queue.wait_and_throw(); + std::cout << "PRINTING REFERENCE DATA " << std::endl; + for (auto n : host_reference_output) { + std::cout << n << " "; + } + std::cout << std::endl; + for (auto n : host_output) { + std::cout << n << " "; + } if constexpr (Storage == complex_storage::SPLIT_COMPLEX) { verify_dft(desc, host_reference_output, host_output, tolerance, host_reference_output_imag, host_output_imag); diff --git a/test/unit_test/instantiate_fft_tests.hpp b/test/unit_test/instantiate_fft_tests.hpp index 026f27c1..5a1bef4f 100644 --- a/test/unit_test/instantiate_fft_tests.hpp +++ b/test/unit_test/instantiate_fft_tests.hpp @@ -156,6 +156,12 @@ INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobalRegressionTest, FFTTest, ::testing::Values(sizes_t{9800}, sizes_t{15360}, sizes_t{68640}))), test_params_print()); +INSTANTIATE_TEST_SUITE_P(PrimeSizedTest, FFTTest, + ::testing::ConvertGenerator(::testing::Combine( + all_valid_placement_layouts, fwd_only, complex_storages, ::testing::Values(1, 8), + ::testing::Values(sizes_t{29}, sizes_t{53}, sizes_t{89}))), + test_params_print()); + // Backward FFT test suite INSTANTIATE_TEST_SUITE_P(BackwardTest, FFTTest, ::testing::ConvertGenerator(::testing::Combine( From 8cc05db5c58a77dd1d250d2b96f395d5dde2e1a9 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Fri, 8 Mar 2024 17:52:50 +0000 Subject: [PATCH 02/22] resolve warnings and add backward and multi-dim subgroup level small prime tests --- src/portfft/committed_descriptor_impl.hpp | 5 +-- src/portfft/common/global.hpp | 40 ++++++------------- src/portfft/dispatcher/global_dispatcher.hpp | 6 +-- .../dispatcher/subgroup_dispatcher.hpp | 1 - test/unit_test/instantiate_fft_tests.hpp | 8 +++- 5 files changed, 24 insertions(+), 36 deletions(-) diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp index 8182e08d..bae2e847 100644 --- a/src/portfft/committed_descriptor_impl.hpp +++ 
b/src/portfft/committed_descriptor_impl.hpp @@ -89,8 +89,8 @@ class committed_descriptor_impl { Scalar1* output, const TIn& input_imag, Scalar1* output_imag, const Scalar1* twiddles_ptr, const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, - IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, - const std::vector& dependencies, sycl::queue& queue); + IdxGlobal batch_start, Idx total_factors, complex_storage storage, const std::vector& dependencies, + sycl::queue& queue); template friend sycl::event detail::transpose_level( @@ -490,7 +490,6 @@ class committed_descriptor_impl { const auto apply_scale = is_final_factor && is_final_dim ? detail::apply_scale_factor::APPLIED : detail::apply_scale_factor::NOT_APPLIED; - Idx length{}; IdxGlobal forward_stride{}; IdxGlobal backward_stride{}; IdxGlobal forward_distance{}; diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index bb225a66..c0cd07e9 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -124,7 +124,6 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors, * @param store_modifier store modifier data * @param input_loc pointer to local memory for storing the input * @param twiddles_loc pointer to local memory for storing the twiddles for sub-implementation - * @param store_modifier_loc pointer to local memory for store modifier data * @param factors pointer to global memory containing factors of the input * @param inner_batches pointer to global memory containing the inner batch for each factor * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors @@ -135,10 +134,10 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors, template PORTFFT_INLINE void dispatch_level(const Scalar* input, 
Scalar* output, const Scalar* input_imag, Scalar* output_imag, const Scalar* implementation_twiddles, const Scalar* store_modifier_data, - Scalar* input_loc, Scalar* twiddles_loc, Scalar* store_modifier_loc, - const IdxGlobal* factors, const IdxGlobal* inner_batches, - const IdxGlobal* inclusive_scan, IdxGlobal batch_size, - detail::global_data_struct<1> global_data, sycl::kernel_handler& kh) { + Scalar* input_loc, Scalar* twiddles_loc, const IdxGlobal* factors, + const IdxGlobal* inner_batches, const IdxGlobal* inclusive_scan, + IdxGlobal batch_size, detail::global_data_struct<1> global_data, + sycl::kernel_handler& kh) { complex_storage storage = kh.get_specialization_constant(); auto level = kh.get_specialization_constant(); Idx level_num = kh.get_specialization_constant(); @@ -292,7 +291,6 @@ sycl::event transpose_level(const typename committed_descriptor_impl compute_level( Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, - IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, - const std::vector& dependencies, sycl::queue& queue) { + IdxGlobal batch_start, Idx total_factors, complex_storage storage, const std::vector& dependencies, + sycl::queue& queue) { PORTFFT_LOG_FUNCTION_ENTRY(); constexpr detail::memory Mem = std::is_pointer_v ? 
detail::memory::USM : detail::memory::BUFFER; IdxGlobal local_range = kd_struct.local_range; IdxGlobal global_range = kd_struct.global_range; IdxGlobal batch_size = kd_struct.batch_size; std::size_t local_memory_for_input = kd_struct.local_mem_required; - std::size_t local_mem_for_store_modifier = [&]() -> std::size_t { - if (factor_id < total_factors - 1) { - if (kd_struct.level == detail::level::WORKITEM || kd_struct.level == detail::level::WORKGROUP) { - return 0; - } - if (kd_struct.level == detail::level::SUBGROUP) { - return 0; - } - } - return std::size_t(1); - }(); std::size_t loc_mem_for_twiddles = [&]() { if (kd_struct.level == detail::level::WORKITEM) { return std::size_t(1); @@ -340,15 +327,13 @@ std::vector compute_level( const IdxGlobal* inclusive_scan = factors_triple + 2 * total_factors; const Idx vec_size = storage == complex_storage::INTERLEAVED_COMPLEX ? 2 : 1; std::vector events; - PORTFFT_LOG_TRACE("Local mem requirement - input:", local_memory_for_input, "store modifiers", - local_mem_for_store_modifier, "twiddles", loc_mem_for_twiddles, "total", - local_memory_for_input + local_mem_for_store_modifier + loc_mem_for_twiddles); + PORTFFT_LOG_TRACE("Local mem requirement - input:", local_memory_for_input, "twiddles", loc_mem_for_twiddles, "total", + local_memory_for_input + loc_mem_for_twiddles); for (Idx batch_in_l2 = 0; batch_in_l2 < num_batches_in_l2 && batch_in_l2 + batch_start < n_transforms; batch_in_l2++) { events.push_back(queue.submit([&](sycl::handler& cgh) { sycl::local_accessor loc_for_input(local_memory_for_input, cgh); sycl::local_accessor loc_for_twiddles(loc_mem_for_twiddles, cgh); - sycl::local_accessor loc_for_modifier(local_mem_for_store_modifier, cgh); auto in_acc_or_usm = detail::get_access(input, cgh); auto in_imag_acc_or_usm = detail::get_access(input_imag, cgh); cgh.use_kernel_bundle(kd_struct.exec_bundle); @@ -388,11 +373,10 @@ std::vector compute_level( s, global_logging_config, #endif it}; - 
dispatch_level(&in_acc_or_usm[0] + input_batch_offset, offset_output, - &in_imag_acc_or_usm[0] + input_batch_offset, offset_output_imag, - subimpl_twiddles, multipliers_between_factors, &loc_for_input[0], - &loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, - inner_batches, inclusive_scan, batch_size, global_data, kh); + dispatch_level( + &in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset, + offset_output_imag, subimpl_twiddles, multipliers_between_factors, &loc_for_input[0], + &loc_for_twiddles[0], factors_triple, inner_batches, inclusive_scan, batch_size, global_data, kh); }); })); } diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 8b67e55a..3893682c 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -333,7 +333,7 @@ struct committed_descriptor_impl::run_kernel_struct(i) * committed_size + input_offset, committed_size, - static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), 0, + static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), dimension_data.num_factors, storage, {event}, desc.queue); detail::dump_device(desc.queue, "after factor 0:", desc.scratch_ptr_1.get(), desc.params.number_of_transforms * dimension_data.length * 2, l2_events); @@ -350,8 +350,8 @@ struct committed_descriptor_impl::run_kernel_struct(max_batches_in_l2), - static_cast(num_batches), static_cast(i), static_cast(factor_num), - dimension_data.num_factors, storage, l2_events, desc.queue); + static_cast(num_batches), static_cast(i), dimension_data.num_factors, storage, + l2_events, desc.queue); intermediate_twiddles_offset += 2 * current_kernel.batch_size * static_cast(current_kernel.length); impl_twiddle_offset += detail::increment_twiddle_offset(current_kernel.level, static_cast(current_kernel.length)); diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp 
b/src/portfft/dispatcher/subgroup_dispatcher.hpp index c02908ff..7cc5a2f4 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -394,7 +394,6 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } } - // sycl::group_barrier(global_data.sg); global_data.log_dump_local("data in local memory:", loc_view, n_reals_per_fft); if (working) { diff --git a/test/unit_test/instantiate_fft_tests.hpp b/test/unit_test/instantiate_fft_tests.hpp index 5a1bef4f..d8bfd843 100644 --- a/test/unit_test/instantiate_fft_tests.hpp +++ b/test/unit_test/instantiate_fft_tests.hpp @@ -158,10 +158,16 @@ INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobalRegressionTest, FFTTest, INSTANTIATE_TEST_SUITE_P(PrimeSizedTest, FFTTest, ::testing::ConvertGenerator(::testing::Combine( - all_valid_placement_layouts, fwd_only, complex_storages, ::testing::Values(1, 8), + all_valid_placement_layouts, both_directions, complex_storages, ::testing::Values(1, 8), ::testing::Values(sizes_t{29}, sizes_t{53}, sizes_t{89}))), test_params_print()); +INSTANTIATE_TEST_SUITE_P(PrimeSizedMultiDimensionalTest, FFTTest, + ::testing::ConvertGenerator(::testing::Combine( + all_valid_placement_layouts, both_directions, complex_storages, ::testing::Values(1, 8), + ::testing::Values(sizes_t{29, 53}, sizes_t{53, 89}, sizes_t{89, 89}))), + test_params_print()); + // Backward FFT test suite INSTANTIATE_TEST_SUITE_P(BackwardTest, FFTTest, ::testing::ConvertGenerator(::testing::Combine( From 5237a3746bb3c15458feb7d9c9c2a37bfa38d4df Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Sat, 9 Mar 2024 14:38:03 +0000 Subject: [PATCH 03/22] further changes --- src/portfft/committed_descriptor_impl.hpp | 41 ++++++---- src/portfft/common/bluestein.hpp | 2 - src/portfft/common/subgroup.hpp | 75 +++++++++++-------- src/portfft/dispatcher/global_dispatcher.hpp | 3 +- .../dispatcher/subgroup_dispatcher.hpp | 7 +- test/unit_test/fft_test_utils.hpp | 24 +++--- 
test/unit_test/instantiate_fft_tests.hpp | 15 ++-- 7 files changed, 102 insertions(+), 65 deletions(-) diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp index bae2e847..436c0886 100644 --- a/src/portfft/committed_descriptor_impl.hpp +++ b/src/portfft/committed_descriptor_impl.hpp @@ -167,7 +167,11 @@ class committed_descriptor_impl { length(length), committed_length(committed_length), used_sg_size(used_sg_size), - is_prime(is_prime) {} + is_prime(is_prime) { + if (is_prime && level != detail::level::SUBGROUP) { + throw unsupported_configuration("Prime sizes that not fit in the subgroup implementation are not supported"); + } + } }; std::vector dimensions; @@ -233,6 +237,8 @@ class committed_descriptor_impl { {{detail::level::WORKITEM, ids, {static_cast(fft_size)}}}}; } if (detail::fits_in_sg(fft_size, SubgroupSize)) { + std::cout << "I AM NOW SELECTING SG IMPLEMENTATION " << std::endl; + ; Idx factor_sg = detail::factorize_sg(static_cast(fft_size), SubgroupSize); Idx factor_wi = static_cast(fft_size) / factor_sg; // This factorization is duplicated in the dispatch logic on the device. @@ -279,6 +285,7 @@ class committed_descriptor_impl { std::vector, std::vector>> param_vec; auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool { if (detail::fits_in_wi(factor_size)) { + std::cout << "I AM SELECTING THE WI IMPLEMENTATION " << std::endl; // Throughout we have assumed there would always be enough local memory for the WI implementation. 
param_vec.emplace_back(detail::level::WORKITEM, detail::get_ids(), @@ -293,7 +300,7 @@ class committed_descriptor_impl { if (detail::can_cast_safely(factor_sg) && detail::can_cast_safely(factor_wi)) { std::size_t input_scalars = num_scalars_in_local_mem(detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, - {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg, + {static_cast(factor_wi), static_cast(factor_sg)}, temp_num_sgs_in_wg, batch_interleaved_layout ? layout::BATCH_INTERLEAVED : layout::PACKED); std::size_t twiddle_scalars = 2 * static_cast(factor_size); return (sizeof(Scalar) * (input_scalars + twiddle_scalars)) < static_cast(local_memory_size); @@ -302,6 +309,7 @@ class committed_descriptor_impl { }(); if (detail::fits_in_sg(factor_size, SubgroupSize) && fits_in_local_memory_subgroup && !PORTFFT_SLOW_SG_SHUFFLES) { + std::cout << "I AM SELECTING THE SG IMPLEMENTATION " << std::endl; Idx factor_sg = detail::factorize_sg(static_cast(factor_size), SubgroupSize); Idx factor_wi = static_cast(factor_size) / factor_sg; PORTFFT_LOG_TRACE("Subgroup kernel for factor:", factor_size, "with factor_wi:", factor_wi, @@ -314,11 +322,8 @@ class committed_descriptor_impl { return false; }; bool encountered_large_prime = detail::factorize_input(fft_size, check_and_select_target_level); - std::cout << "encountered_large_prime = " << encountered_large_prime << std::endl; if (encountered_large_prime) { - std::cout << "I HAVE ENCOUNTERED A LARGE PRIME, fft size = " << fft_size << std::endl; IdxGlobal padded_size = detail::get_bluestein_padded_size(fft_size); - std::cout << "PADDED FFT SIZE = " << padded_size << std::endl; return prepare_implementation(padded_size); } return {detail::level::GLOBAL, static_cast(fft_size), param_vec}; @@ -527,16 +532,22 @@ class committed_descriptor_impl { auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), ids); - if (factor_size != static_cast(params.lengths[dimension_num])) { + if (factor_size != 
static_cast(params.lengths[dimension_num]) && !is_global) { in_bundle.template set_specialization_constant(detail::fft_algorithm::BLUESTEIN); in_bundle.template set_specialization_constant( static_cast(params.lengths[dimension_num])); } else { // TODO: This needs to change in the case of global + if (is_global) { + in_bundle.template set_specialization_constant( + static_cast(factor_size)); + } else { + in_bundle.template set_specialization_constant( + static_cast(params.lengths[dimension_num])); + } + in_bundle.template set_specialization_constant( detail::fft_algorithm::COOLEY_TUKEY); - in_bundle.template set_specialization_constant( - static_cast(params.lengths[dimension_num])); } set_spec_constants(top_level, in_bundle, factor_size, factors, detail::elementwise_multiply::NOT_APPLIED, @@ -583,11 +594,6 @@ class committed_descriptor_impl { } } - std::cout << "FFT SIZE = " << fft_size << std::endl; - if (top_level == detail::level::SUBGROUP) { - std::cout << "PADDED FFT SIZE LEVEL is SUBGROUP " << std::endl; - } - // exit(-10); if (is_compatible) { auto forward_kernels = set_spec_constants_driver(top_level, prepared_vec, direction::FORWARD, dimension_num); @@ -947,6 +953,15 @@ class committed_descriptor_impl { const auto input_layout = detail::get_layout(params, compute_direction); const auto output_layout = detail::get_layout(params, inv(compute_direction)); + if (dimensions.back().is_prime) { + if (input_layout == detail::layout::UNPACKED || output_layout == detail::layout::UNPACKED) { + throw unsupported_configuration("Unsupported configuration for prime sized DFTs"); + } + if (input_layout == detail::layout::PACKED && output_layout != detail::layout::PACKED) { + throw unsupported_configuration("Unsupported configuration for prime sized DFTs"); + } + } + // currently multi-dimensional transforms are implemented just for default (PACKED) data layout const bool multi_dim_supported = input_layout == detail::layout::PACKED && output_layout == 
detail::layout::PACKED; if (n_dimensions != 1 && !multi_dim_supported) { diff --git a/src/portfft/common/bluestein.hpp b/src/portfft/common/bluestein.hpp index eda5c786..5be97cc2 100644 --- a/src/portfft/common/bluestein.hpp +++ b/src/portfft/common/bluestein.hpp @@ -38,7 +38,6 @@ namespace detail { */ template void populate_fft_chirp_signal(T* ptr, std::size_t committed_size, std::size_t dimension_size) { - std::cout << "committed_size = " << committed_size << " padded size = " << dimension_size << std::endl; using complex_t = std::complex; std::vector chirp_signal(dimension_size, 0); std::vector chirp_fft(dimension_size, 0); @@ -63,7 +62,6 @@ void populate_fft_chirp_signal(T* ptr, std::size_t committed_size, std::size_t d */ template void populate_bluestein_input_modifiers(T* ptr, std::size_t committed_size, std::size_t dimension_size) { - std::cout << "committed_size = " << committed_size << " padded size = " << dimension_size << std::endl; using complex_t = std::complex; std::vector scratch(dimension_size, 0); for (std::size_t i = 0; i < committed_size; i++) { diff --git a/src/portfft/common/subgroup.hpp b/src/portfft/common/subgroup.hpp index 825b4f9c..f27133bf 100644 --- a/src/portfft/common/subgroup.hpp +++ b/src/portfft/common/subgroup.hpp @@ -432,6 +432,32 @@ PORTFFT_INLINE void subgroup_impl_bluestein_localglobal_packed_copy( sycl::group_barrier(sg); } +/** + * Performs all the computations to be done in the private memory for the subgroup implementation + * + * @tparam SubgroupSize Subgroup Size + * @tparam T Scalar Type + * @tparam LocView View of the local memory + * @param priv private memory array on which the computations will be done + * @param private_scratch Scratch private memory to be passed to the wi_dft as a part of sg_dft + * @param apply_load_modifier Whether or not modifiers need to be applied before the fft computation + * @param apply_store_modifier Whether or not the modifiers need to be applied after the fft computation + * @param 
conjugate_on_load Whether or not conjugation of the input is to be done before the fft computation + * @param conjugate_on_store Whether or not conjugation of the input is to be done after the fft computation + * @param scale_factor_applied Whether or not scale factor is applied + * @param load_modifier_data Global memory pointer containing the load modifier data, assumed aligned to at least + * sycl::vec + * @param store_modifier_data Global memory pointer containing the store modifier data, assumed aligned to at least + * sycl::vec + * @param twiddles_loc_view View of the local memory containing the twiddles + * @param scale_factor Value of the scale factor + * @param modifier_start_offset offset to be applied to the load/store modifier pointers + * @param id_of_wi_in_fft workitem id withing the fft + * @param factor_sg Number of workitems participating for one transform + * @param factor_wi Number of complex elements per workitem for each transform + * @param sg sub group + * @return PORTFFT_INLINE + */ template PORTFFT_INLINE void sg_dft_compute(T* priv, T* private_scratch, detail::elementwise_multiply apply_load_modifier, detail::elementwise_multiply apply_store_modifier, @@ -492,16 +518,11 @@ PORTFFT_INLINE void sg_bluestein_batch_interleaved(T* priv, T* priv_scratch, Loc sycl::sub_group& sg, detail::global_data_struct<1>& global_data) { sg_dft_compute( priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, - conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, load_modifier, - store_modifier, twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); - - PORTFFT_UNROLL - for (Idx i = 0; i < 2 * factor_wi; i++) { - priv[i] = (priv[i] / (static_cast(factor_sg * factor_wi))); - } + conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, + store_modifier, twiddles_loc, static_cast(1. 
/ (static_cast(factor_sg * factor_wi))), 0, id_of_wi_in_fft, + factor_sg, factor_wi, sg); if (wi_working) { - // Store back to local memory only if (storage == complex_storage::INTERLEAVED_COMPLEX) { subgroup_impl_local_private_copy<2, Idx>( loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}}, @@ -532,14 +553,14 @@ PORTFFT_INLINE void sg_bluestein_batch_interleaved(T* priv, T* priv_scratch, Loc } } - auto conjugate_on_output = conjugate_on_store == detail::complex_conjugate::APPLIED - ? detail::complex_conjugate::NOT_APPLIED - : detail::complex_conjugate::APPLIED; - sg_dft_compute(priv, priv_scratch, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, - conjugate_on_output, scale_applied, static_cast(nullptr), load_modifier, - twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); + detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), + load_modifier, twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); + + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + detail::conjugate_inplace(priv, factor_wi); + } } template @@ -548,19 +569,14 @@ void sg_bluestein(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddlesView& detail::complex_conjugate conjugate_on_store, detail::apply_scale_factor scale_applied, T scale_factor, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, complex_storage storage, bool wi_working, Idx loc_offset_store_view, Idx loc_offset_load_view, Idx local_imag_offset, - sycl::sub_group sg, detail::global_data_struct<1>& global_data) { - // for (Idx i = 0; i < 2 * factor_wi; i++) { - // priv[i] = 2; - // } + sycl::sub_group& sg, detail::global_data_struct<1>& global_data) { sg_dft_compute( priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, - conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, 
detail::apply_scale_factor::NOT_APPLIED, load_modifier, - store_modifier, loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); + conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, + store_modifier, loc_twiddles, static_cast(1. / static_cast(factor_sg * factor_wi)), 0, id_of_wi_in_fft, + factor_sg, factor_wi, sg); - PORTFFT_UNROLL - for (Idx i = 0; i < 2 * factor_wi; i++) { - priv[i] = (priv[i] / (static_cast(factor_sg * factor_wi))); - } + sycl::group_barrier(sg); if (wi_working) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { @@ -590,14 +606,13 @@ void sg_bluestein(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddlesView& } } - auto conjugate_on_output = conjugate_on_store == detail::complex_conjugate::APPLIED - ? detail::complex_conjugate::NOT_APPLIED - : detail::complex_conjugate::APPLIED; - sg_dft_compute(priv, priv_scratch, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, - conjugate_on_output, scale_applied, static_cast(nullptr), load_modifier, - loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); + detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), + load_modifier, loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + detail::conjugate_inplace(priv, factor_wi); + } } }; // namespace portfft diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 3893682c..29cd38d4 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -147,7 +147,8 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn PORTFFT_LOG_TRACE("Allocating global memory for twiddles for workgroup implementation. 
Allocation size", mem_required_for_twiddles); Scalar* device_twiddles = - sycl::malloc_device(static_cast(mem_required_for_twiddles), desc.queue); + sycl::aligned_alloc_device(alignof(sycl::vec), + static_cast(mem_required_for_twiddles), desc.queue); // Helper Lambda to calculate twiddles auto calculate_twiddles = [](IdxGlobal N, IdxGlobal M, IdxGlobal& offset, Scalar* ptr) { diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 7cc5a2f4..d76922f9 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -391,6 +391,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag const_cast(input), const_cast(input_imag), loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, detail::transfer_direction::GLOBAL_TO_LOCAL, global_data); + } else { + // TODO: Bluestein Strided copy } } @@ -461,7 +463,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag detail::transfer_direction::PRIVATE_TO_GLOBAL); } } - } else if (is_output_batch_interleaved) { + } else if (is_output_batch_interleaved && algorithm == detail::fft_algorithm::COOLEY_TUKEY) { if (working) { global_data.log_message_global(__func__, "Storing data from private to Global with batch interleaved layout"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { @@ -554,6 +556,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag output, output_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, detail::transfer_direction::LOCAL_TO_GLOBAL, global_data); + } else { + // TODO: Blustein Strided Copy } } sycl::group_barrier(global_data.sg); @@ -574,7 +578,6 @@ struct 
committed_descriptor_impl::calculate_twiddles_struct::inn Idx factor_sg = kernel_data.factors[1]; std::size_t twiddles_alloc_size = [&]() { if (dimension_data.is_prime) { - std::cout << "DIMENSION IS INDEED PRIME " << std::endl; // sg twiddles + load_modifiers + store_modifiers return 6 * dimension_data.length; } diff --git a/test/unit_test/fft_test_utils.hpp b/test/unit_test/fft_test_utils.hpp index bc84694c..12790d7e 100644 --- a/test/unit_test/fft_test_utils.hpp +++ b/test/unit_test/fft_test_utils.hpp @@ -273,7 +273,7 @@ std::enable_if_t check_fft( const std::vector& host_reference_output, const std::vector& host_input_imag, std::vector& host_output_imag, const std::vector& host_reference_output_imag, double tolerance) { - std::cout << "I AM IN CHECK FFT USM " << std::endl; + // std::cout << "I AM IN CHECK FFT USM " << std::endl; auto committed_descriptor = desc.commit(queue); const bool is_oop = desc.placement == placement::OUT_OF_PLACE; @@ -339,14 +339,14 @@ std::enable_if_t check_fft( host_output_imag.size(), {fft_event}); } queue.wait_and_throw(); - std::cout << "PRINTING REFERENCE DATA " << std::endl; - for (auto n : host_reference_output) { - std::cout << n << " "; - } - std::cout << std::endl; - for (auto n : host_output) { - std::cout << n << " "; - } + // std::cout << "PRINTING REFERENCE DATA " << std::endl; + // for (auto n : host_reference_output) { + // std::cout << n << " "; + // } + // std::cout << std::endl; + // for (auto n : host_output) { + // std::cout << n << " "; + // } if constexpr (Storage == complex_storage::SPLIT_COMPLEX) { verify_dft(desc, host_reference_output, host_output, tolerance, host_reference_output_imag, host_output_imag); @@ -471,7 +471,11 @@ void run_test(const test_params& params) { std::accumulate(params.lengths.begin(), params.lengths.end(), 1ull, std::multiplies())); // 2 * theoretical max L2 error of Cooley-Tukey double tolerance = 2 * std::numeric_limits::epsilon() * n_elems * std::log2(n_elems); - + auto 
num_prime_sizes = std::count_if(params.lengths.begin(), params.lengths.end(), + [](const std::size_t l) { return detail::factorize(l) == std::size_t(1); }); + if (num_prime_sizes > 0) { + tolerance *= 10; + } portfft::detail::dump_host("host_input:", host_input.data(), host_input.size()); portfft::detail::dump_host("host_input_imag:", host_input.data(), host_input.size()); diff --git a/test/unit_test/instantiate_fft_tests.hpp b/test/unit_test/instantiate_fft_tests.hpp index d8bfd843..a86e3474 100644 --- a/test/unit_test/instantiate_fft_tests.hpp +++ b/test/unit_test/instantiate_fft_tests.hpp @@ -158,15 +158,16 @@ INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobalRegressionTest, FFTTest, INSTANTIATE_TEST_SUITE_P(PrimeSizedTest, FFTTest, ::testing::ConvertGenerator(::testing::Combine( - all_valid_placement_layouts, both_directions, complex_storages, ::testing::Values(1, 8), - ::testing::Values(sizes_t{29}, sizes_t{53}, sizes_t{89}))), + all_valid_placement_layouts, both_directions, complex_storages, + ::testing::Values(1, 8, 33000), ::testing::Values(sizes_t{31}, sizes_t{53}, sizes_t{89}))), test_params_print()); -INSTANTIATE_TEST_SUITE_P(PrimeSizedMultiDimensionalTest, FFTTest, - ::testing::ConvertGenerator(::testing::Combine( - all_valid_placement_layouts, both_directions, complex_storages, ::testing::Values(1, 8), - ::testing::Values(sizes_t{29, 53}, sizes_t{53, 89}, sizes_t{89, 89}))), - test_params_print()); +INSTANTIATE_TEST_SUITE_P( + PrimeSizedMultiDimensionalTest, FFTTest, + ::testing::ConvertGenerator(::testing::Combine( + all_valid_multi_dim_placement_layouts, both_directions, complex_storages, ::testing::Values(1, 8), + ::testing::Values(sizes_t{29, 53}, sizes_t{53, 89}, sizes_t{89, 89}, sizes_t{31, 89}, sizes_t{31, 53, 89}))), + test_params_print()); // Backward FFT test suite INSTANTIATE_TEST_SUITE_P(BackwardTest, FFTTest, From 5292f41d9af0efb68644f7a52ac8d22dfc7101a4 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Mon, 11 Mar 2024 08:19:58 +0000 Subject: 
[PATCH 04/22] modifier offset calculation bugfix --- src/portfft/committed_descriptor_impl.hpp | 8 ++------ src/portfft/dispatcher/global_dispatcher.hpp | 19 +++++++++---------- .../dispatcher/subgroup_dispatcher.hpp | 3 ++- .../dispatcher/workgroup_dispatcher.hpp | 3 --- src/portfft/utils.hpp | 2 -- 5 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp index 436c0886..597d28f7 100644 --- a/src/portfft/committed_descriptor_impl.hpp +++ b/src/portfft/committed_descriptor_impl.hpp @@ -237,8 +237,6 @@ class committed_descriptor_impl { {{detail::level::WORKITEM, ids, {static_cast(fft_size)}}}}; } if (detail::fits_in_sg(fft_size, SubgroupSize)) { - std::cout << "I AM NOW SELECTING SG IMPLEMENTATION " << std::endl; - ; Idx factor_sg = detail::factorize_sg(static_cast(fft_size), SubgroupSize); Idx factor_wi = static_cast(fft_size) / factor_sg; // This factorization is duplicated in the dispatch logic on the device. @@ -285,7 +283,6 @@ class committed_descriptor_impl { std::vector, std::vector>> param_vec; auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool { if (detail::fits_in_wi(factor_size)) { - std::cout << "I AM SELECTING THE WI IMPLEMENTATION " << std::endl; // Throughout we have assumed there would always be enough local memory for the WI implementation. param_vec.emplace_back(detail::level::WORKITEM, detail::get_ids(), @@ -303,20 +300,19 @@ class committed_descriptor_impl { {static_cast(factor_wi), static_cast(factor_sg)}, temp_num_sgs_in_wg, batch_interleaved_layout ? 
layout::BATCH_INTERLEAVED : layout::PACKED); std::size_t twiddle_scalars = 2 * static_cast(factor_size); - return (sizeof(Scalar) * (input_scalars + twiddle_scalars)) < static_cast(local_memory_size); + return (sizeof(Scalar) * (input_scalars + twiddle_scalars)) <= static_cast(local_memory_size); } return false; }(); if (detail::fits_in_sg(factor_size, SubgroupSize) && fits_in_local_memory_subgroup && !PORTFFT_SLOW_SG_SHUFFLES) { - std::cout << "I AM SELECTING THE SG IMPLEMENTATION " << std::endl; Idx factor_sg = detail::factorize_sg(static_cast(factor_size), SubgroupSize); Idx factor_wi = static_cast(factor_size) / factor_sg; PORTFFT_LOG_TRACE("Subgroup kernel for factor:", factor_size, "with factor_wi:", factor_wi, "and factor_sg:", factor_sg); param_vec.emplace_back(detail::level::SUBGROUP, detail::get_ids(), - std::vector{factor_sg, factor_wi}); + std::vector{factor_wi, factor_sg}); return true; } return false; diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 29cd38d4..b4ba5a76 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -147,8 +147,7 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn PORTFFT_LOG_TRACE("Allocating global memory for twiddles for workgroup implementation. 
Allocation size", mem_required_for_twiddles); Scalar* device_twiddles = - sycl::aligned_alloc_device(alignof(sycl::vec), - static_cast(mem_required_for_twiddles), desc.queue); + sycl::malloc_device(static_cast(mem_required_for_twiddles), desc.queue); // Helper Lambda to calculate twiddles auto calculate_twiddles = [](IdxGlobal N, IdxGlobal M, IdxGlobal& offset, Scalar* ptr) { @@ -170,16 +169,16 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn counter = 0; for (const auto& kernel_data : kernels) { if (kernel_data.level == detail::level::SUBGROUP) { - for (Idx i = 0; i < kernel_data.factors.at(0); i++) { - for (Idx j = 0; j < kernel_data.factors.at(1); j++) { + for (Idx i = 0; i < kernel_data.factors.at(1); i++) { + for (Idx j = 0; j < kernel_data.factors.at(0); j++) { double theta = -2 * M_PI * static_cast(i * j) / static_cast(kernel_data.factors.at(0) * kernel_data.factors.at(1)); auto twiddle = std::complex(static_cast(std::cos(theta)), static_cast(std::sin(theta))); - host_memory[static_cast(offset + static_cast(j * kernel_data.factors.at(0) + i))] = + host_memory[static_cast(offset + static_cast(j * kernel_data.factors.at(1) + i))] = twiddle.real(); host_memory[static_cast( - offset + static_cast((j + kernel_data.factors.at(1)) * kernel_data.factors.at(0) + i))] = + offset + static_cast((j + kernel_data.factors.at(0)) * kernel_data.factors.at(1) + i))] = twiddle.imag(); } } @@ -273,10 +272,10 @@ struct committed_descriptor_impl::set_spec_constants_struct::inn PORTFFT_LOG_TRACE("SpecConstFftSize:", length); in_bundle.template set_specialization_constant(length); } else if (level == detail::level::SUBGROUP) { - PORTFFT_LOG_TRACE("SubgroupFactorWISpecConst:", factors[1]); - in_bundle.template set_specialization_constant(factors[1]); - PORTFFT_LOG_TRACE("SubgroupFactorSGSpecConst:", factors[0]); - in_bundle.template set_specialization_constant(factors[0]); + PORTFFT_LOG_TRACE("SubgroupFactorWISpecConst:", factors[0]); + in_bundle.template 
set_specialization_constant(factors[0]); + PORTFFT_LOG_TRACE("SubgroupFactorSGSpecConst:", factors[1]); + in_bundle.template set_specialization_constant(factors[1]); } } }; diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index d76922f9..fd360b8d 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -235,7 +235,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); } IdxGlobal modifier_offset = - static_cast(n_reals_per_fft) * (i + static_cast(fft_idx_in_local + id_of_fft_in_sg)); + static_cast(n_reals_per_fft) * (i + static_cast(fft_idx_in_local)); if (algorithm == detail::fft_algorithm::COOLEY_TUKEY) { sg_dft_compute(priv, wi_private_scratch, multiply_on_load, multiply_on_store, conjugate_on_load, conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, @@ -397,6 +397,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } global_data.log_dump_local("data in local memory:", loc_view, n_reals_per_fft); + sycl::group_barrier(global_data.sg); if (working) { global_data.log_message_global(__func__, "loading non-transposed data from local to private memory"); diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index b4cdbfe4..db0860a7 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -380,9 +380,6 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn Idx n = factor_wi_n * factor_sg_n; Idx m = factor_wi_m * factor_sg_m; std::size_t res_size = 2 * static_cast((m + n + fft_size)); - if (dimension_data.is_prime) { - res_size += 4 * dimension_data.length; - } PORTFFT_LOG_TRACE("Allocating global memory for twiddles for workgroup implementation. 
Allocation size", res_size); Scalar* res = sycl::aligned_alloc_device( alignof(sycl::vec), res_size, desc.queue); diff --git a/src/portfft/utils.hpp b/src/portfft/utils.hpp index 68f8ab4b..2d63b0d5 100644 --- a/src/portfft/utils.hpp +++ b/src/portfft/utils.hpp @@ -97,7 +97,6 @@ std::optional factorize_input_impl(IdxGlobal factor_size, F&& check_a PORTFFT_LOG_FUNCTION_ENTRY(); IdxGlobal fact_1 = factor_size; if (check_and_select_target_level(fact_1, transposed)) { - std::cout << fact_1 << std::endl; return fact_1; } if ((detail::factorize(fact_1) == 1)) { @@ -109,7 +108,6 @@ std::optional factorize_input_impl(IdxGlobal factor_size, F&& check_a return std::nullopt; } } while (!check_and_select_target_level(fact_1)); - std::cout << fact_1 << std::endl; return fact_1; } From a3aa5d63a2b6ddba2237a36d1e5f28159b47eee8 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 12 Mar 2024 06:13:04 +0000 Subject: [PATCH 05/22] initialize local memory with zeros to avoid nans --- src/portfft/common/subgroup.hpp | 114 +++++++++++------- .../dispatcher/subgroup_dispatcher.hpp | 66 +++++----- 2 files changed, 102 insertions(+), 78 deletions(-) diff --git a/src/portfft/common/subgroup.hpp b/src/portfft/common/subgroup.hpp index f27133bf..9e05ac9d 100644 --- a/src/portfft/common/subgroup.hpp +++ b/src/portfft/common/subgroup.hpp @@ -311,6 +311,19 @@ void sg_calc_twiddles(Idx factor_sg, Idx factor_wi, Idx n, Idx k, T* sg_twiddles sg_twiddles[(k + factor_wi) * factor_sg + n] = twiddle.imag(); } +template +PORTFFT_INLINE void subgroup_impl_global2local_strided_copy(const T* global_ptr, LocView& loc_view, + std::array strides_global, + std::array strides_local, + IdxGlobal offset_global, Idx offset_local, + std::array copy_strides, + detail::global_data_struct<1> global_data) { + detail::md_view global_md_view{global_ptr, strides_global, offset_global}; + detail::md_view local_md_view{loc_view, strides_local, offset_local}; + copy_group(global_data, global_md_view, local_md_view, 
copy_strides); +} + template PORTFFT_INLINE void subgroup_impl_local2global_strided_copy(T* global_ptr, LocView& loc_view, @@ -318,15 +331,10 @@ PORTFFT_INLINE void subgroup_impl_local2global_strided_copy(T* global_ptr, LocVi std::array strides_local, IdxGlobal offset_global, Idx offset_local, std::array copy_strides, - detail::global_data_struct<1> global_data, - detail::transfer_direction direction) { + detail::global_data_struct<1> global_data) { detail::md_view global_md_view{global_ptr, strides_global, offset_global}; detail::md_view local_md_view{loc_view, strides_local, offset_local}; - if (direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { - copy_group(global_data, global_md_view, local_md_view, copy_strides); - } else if (direction == detail::transfer_direction::LOCAL_TO_GLOBAL) { - copy_group(global_data, local_md_view, global_md_view, copy_strides); - } + copy_group(global_data, local_md_view, global_md_view, copy_strides); } template strides_global, std::array strides_local, IdxGlobal offset_global, Idx local_offset, Idx local_imag_offset, - std::array copy_strides, detail::global_data_struct<1> global_data, - detail::transfer_direction direction) { + std::array copy_strides, detail::global_data_struct<1> global_data) { detail::md_view global_md_real_view{global_ptr, strides_global, offset_global}; detail::md_view global_md_imag_view{global_imag_ptr, strides_global, offset_global}; detail::md_view local_md_real_view{loc_view, strides_local, local_offset}; detail::md_view local_md_imag_view{loc_view, strides_local, local_offset + local_imag_offset}; - if (direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { - copy_group(global_data, global_md_real_view, local_md_real_view, copy_strides); - copy_group(global_data, global_md_imag_view, local_md_imag_view, copy_strides); - } else if (direction == detail::transfer_direction::LOCAL_TO_GLOBAL) { - copy_group(global_data, local_md_real_view, global_md_real_view, copy_strides); - 
copy_group(global_data, local_md_imag_view, global_md_imag_view, copy_strides); - } + copy_group(global_data, local_md_real_view, global_md_real_view, copy_strides); + copy_group(global_data, local_md_imag_view, global_md_imag_view, copy_strides); +} + +template +PORTFFT_INLINE void subgroup_impl_global2local_strided_copy( + const T* global_ptr, const T* global_imag_ptr, LocView& loc_view, std::array strides_global, + std::array strides_local, IdxGlobal offset_global, Idx local_offset, Idx local_imag_offset, + std::array copy_strides, detail::global_data_struct<1> global_data) { + detail::md_view global_md_real_view{global_ptr, strides_global, offset_global}; + detail::md_view global_md_imag_view{global_imag_ptr, strides_global, offset_global}; + detail::md_view local_md_real_view{loc_view, strides_local, local_offset}; + detail::md_view local_md_imag_view{loc_view, strides_local, local_offset + local_imag_offset}; + copy_group(global_data, global_md_real_view, local_md_real_view, copy_strides); + copy_group(global_data, global_md_imag_view, local_md_imag_view, copy_strides); } template @@ -390,42 +406,54 @@ PORTFFT_INLINE void subgroup_impl_local_private_copy( } template -PORTFFT_INLINE void subgroup_impl_bluestein_localglobal_packed_copy( +PORTFFT_INLINE void subgroup_impl_bluestein_local2global_packed_copy( TIn* global_ptr, TIn* global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, sycl::sub_group& sg, - complex_storage storage, detail::transfer_direction direction, detail::global_data_struct<1>& global_data) { + complex_storage storage, detail::global_data_struct<1>& global_data) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { PORTFFT_UNROLL for (Idx i = 0; i < n_ffts_in_sg; i++) { - if (direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { - global2local( - global_data, global_ptr, loc_view, 2 * committed_size, - static_cast(2 * i * 
committed_size) + global_ptr_offset, 2 * i * fft_size + loc_offset); - } else if (direction == detail::transfer_direction::LOCAL_TO_GLOBAL) { - local2global(global_data, loc_view, global_ptr, 2 * committed_size, - 2 * i * fft_size + loc_offset, - global_ptr_offset + 2 * i * committed_size); - } + local2global(global_data, loc_view, global_ptr, 2 * committed_size, + 2 * i * fft_size + loc_offset, + global_ptr_offset + 2 * i * committed_size); } } else { PORTFFT_UNROLL for (Idx i = 0; i < n_ffts_in_sg; i++) { - if (direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { - global2local( - global_data, global_ptr, loc_view, committed_size, - static_cast(i * committed_size) + global_ptr_offset, i * fft_size + loc_offset); - global2local( - global_data, global_imag_ptr, loc_view, committed_size, - static_cast(i * committed_size) + global_ptr_offset, - i * fft_size + loc_offset + local_imag_offset); - } else if (direction == detail::transfer_direction::LOCAL_TO_GLOBAL) { - local2global(global_data, loc_view, global_ptr, committed_size, - i * fft_size + loc_offset, - global_ptr_offset + i * committed_size); - local2global(global_data, loc_view, global_imag_ptr, committed_size, - i * fft_size + loc_offset + local_imag_offset, - global_ptr_offset + i * committed_size); - } + local2global(global_data, loc_view, global_ptr, committed_size, + i * fft_size + loc_offset, + global_ptr_offset + i * committed_size); + local2global(global_data, loc_view, global_imag_ptr, committed_size, + i * fft_size + loc_offset + local_imag_offset, + global_ptr_offset + i * committed_size); + } + } + + sycl::group_barrier(sg); +} + +template +PORTFFT_INLINE void subgroup_impl_bluestein_global2local_packed_copy( + const TIn* global_ptr, const TIn* global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, + IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, sycl::sub_group& sg, + complex_storage storage, detail::global_data_struct<1>& global_data) { 
+ if (storage == complex_storage::INTERLEAVED_COMPLEX) { + PORTFFT_UNROLL + for (Idx i = 0; i < n_ffts_in_sg; i++) { + global2local( + global_data, global_ptr, loc_view, 2 * committed_size, + static_cast(2 * i * committed_size) + global_ptr_offset, 2 * i * fft_size + loc_offset); + } + } else { + PORTFFT_UNROLL + for (Idx i = 0; i < n_ffts_in_sg; i++) { + global2local( + global_data, global_ptr, loc_view, committed_size, + static_cast(i * committed_size) + global_ptr_offset, i * fft_size + loc_offset); + global2local( + global_data, global_imag_ptr, loc_view, committed_size, + static_cast(i * committed_size) + global_ptr_offset, + i * fft_size + loc_offset + local_imag_offset); } } diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index fd360b8d..f466f581 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -197,15 +197,14 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "loading transposed data from global to local memory"); // load / store in a transposed manner if (storage == complex_storage::INTERLEAVED_COMPLEX) { - subgroup_impl_local2global_strided_copy( - const_cast(input), loc_view, {2 * n_transforms, static_cast(1)}, - {2 * max_num_batches_local_mem, 1}, 2 * i, 0, {committed_length, 2 * num_batches_in_local_mem}, global_data, - detail::transfer_direction::GLOBAL_TO_LOCAL); + subgroup_impl_global2local_strided_copy( + input, loc_view, {2 * n_transforms, static_cast(1)}, + {2 * max_num_batches_local_mem, 1}, 2 * i, 0, {committed_length, 2 * num_batches_in_local_mem}, global_data); } else { - subgroup_impl_local2global_strided_copy( - const_cast(input), const_cast(input_imag), loc_view, {n_transforms, static_cast(1)}, + subgroup_impl_global2local_strided_copy( + input, input_imag, loc_view, {n_transforms, static_cast(1)}, {max_num_batches_local_mem, 1}, i, 0, 
local_imag_offset, {committed_length, num_batches_in_local_mem}, - global_data, detail::transfer_direction::GLOBAL_TO_LOCAL); + global_data); } sycl::group_barrier(global_data.it.get_group()); @@ -312,13 +311,11 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (storage == complex_storage::INTERLEAVED_COMPLEX) { subgroup_impl_local2global_strided_copy( output, loc_view, {output_stride * 2, output_distance * 2, 1}, {max_num_batches_local_mem * 2, 2, 1}, - i * output_distance * 2, 0, {committed_length, num_batches_in_local_mem, 2}, global_data, - detail::transfer_direction::LOCAL_TO_GLOBAL); + i * output_distance * 2, 0, {committed_length, num_batches_in_local_mem, 2}, global_data); } else { subgroup_impl_local2global_strided_copy( output, output_imag, loc_view, {output_stride, output_distance}, {max_num_batches_local_mem, 1}, - i * output_distance, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, global_data, - detail::transfer_direction::LOCAL_TO_GLOBAL); + i * output_distance, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, global_data); } } else { global_data.log_message_global( @@ -326,13 +323,11 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (storage == complex_storage::INTERLEAVED_COMPLEX) { subgroup_impl_local2global_strided_copy( output, loc_view, {2 * n_transforms, static_cast(1)}, {2 * max_num_batches_local_mem, 1}, - 2 * i, 0, {committed_length, 2 * num_batches_in_local_mem}, global_data, - detail::transfer_direction::LOCAL_TO_GLOBAL); + 2 * i, 0, {committed_length, 2 * num_batches_in_local_mem}, global_data); } else { subgroup_impl_local2global_strided_copy( output, output_imag, loc_view, {n_transforms, static_cast(1)}, - {max_num_batches_local_mem, 1}, i, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, - global_data, detail::transfer_direction::LOCAL_TO_GLOBAL); + {max_num_batches_local_mem, 1}, i, 0, local_imag_offset, 
{committed_length, num_batches_in_local_mem}, global_data); } } } @@ -365,17 +360,15 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } else { if (storage == complex_storage::INTERLEAVED_COMPLEX) { global_data.log_message_global(__func__, "storing data from unpacked global memory to local"); - subgroup_impl_local2global_strided_copy( - const_cast(input), loc_view, {input_distance * 2, input_stride * 2, 1}, + subgroup_impl_global2local_strided_copy( + input, loc_view, {input_distance * 2, input_stride * 2, 1}, {committed_length * 2, 2, 1}, input_distance * 2 * (i - static_cast(id_of_fft_in_sg)), - local_offset, {n_ffts_worked_on_by_sg, committed_length, 2}, global_data, - detail::transfer_direction::GLOBAL_TO_LOCAL); + local_offset, {n_ffts_worked_on_by_sg, committed_length, 2}, global_data); } else { - subgroup_impl_local2global_strided_copy( - const_cast(input), const_cast(input_imag), loc_view, {input_distance, input_stride}, + subgroup_impl_global2local_strided_copy( + input, input_imag, loc_view, {input_distance, input_stride}, {committed_length, 1}, input_distance * (i - static_cast(id_of_fft_in_sg)), local_offset, - local_imag_offset, {n_ffts_worked_on_by_sg, committed_length}, global_data, - detail::transfer_direction::GLOBAL_TO_LOCAL); + local_imag_offset, {n_ffts_worked_on_by_sg, committed_length}, global_data); } } } else { @@ -384,13 +377,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag ? 2 * committed_length * (i - static_cast(id_of_fft_in_sg)) : committed_length * (i - static_cast(id_of_fft_in_sg)); auto loc_view_offset = storage == complex_storage::INTERLEAVED_COMPLEX - ? 2 * factor_sg * factor_wi * subgroup_id - : factor_sg * factor_wi * subgroup_id; + ? 
2 * factor_sg * factor_wi * subgroup_id * n_ffts_per_sg + : factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; - subgroup_impl_bluestein_localglobal_packed_copy( - const_cast(input), const_cast(input_imag), loc_view, committed_length, factor_sg * factor_wi, + subgroup_impl_bluestein_global2local_packed_copy( + input, input_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, - detail::transfer_direction::GLOBAL_TO_LOCAL, global_data); + global_data); } else { // TODO: Bluestein Strided copy } @@ -534,14 +527,12 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "storing data from local to unpacked global memory"); subgroup_impl_local2global_strided_copy( output, loc_view, {output_distance * 2, output_stride * 2, 1}, {committed_length * 2, 2, 1}, - global_output_offset, local_offset, {n_ffts_worked_on_by_sg, fft_size, 2}, global_data, - detail::transfer_direction::LOCAL_TO_GLOBAL); + global_output_offset, local_offset, {n_ffts_worked_on_by_sg, fft_size, 2}, global_data); } else { const IdxGlobal global_output_offset = output_distance * (i - static_cast(id_of_fft_in_sg)); subgroup_impl_local2global_strided_copy( output, output_imag, loc_view, {output_distance, output_stride}, {committed_length, 1}, - global_output_offset, local_offset, local_imag_offset, {n_ffts_worked_on_by_sg, committed_length}, - global_data, detail::transfer_direction::LOCAL_TO_GLOBAL); + global_output_offset, local_offset, local_imag_offset, {n_ffts_worked_on_by_sg, committed_length}, global_data); } } } else { @@ -553,10 +544,10 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag ? 
2 * factor_sg * factor_wi * subgroup_id : factor_sg * factor_wi * subgroup_id; auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; - subgroup_impl_bluestein_localglobal_packed_copy( + subgroup_impl_bluestein_local2global_packed_copy( output, output_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, - detail::transfer_direction::LOCAL_TO_GLOBAL, global_data); + global_data); } else { // TODO: Blustein Strided Copy } @@ -672,9 +663,14 @@ struct committed_descriptor_impl::run_kernel_struct(&in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, &in_imag_acc_or_usm[0] + input_offset, - &out_imag_acc_or_usm[0] + output_offset, &loc[0], &loc_twiddles[0], + &out_imag_acc_or_usm[0] + output_offset, loc_ptr, &loc_twiddles[0], n_transforms, twiddles, global_data, kh, twiddles + 2 * fft_size, twiddles + 4 * fft_size); } From 5b89d107f8d7a78af67eebff02211e9ecd32b495 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 12 Mar 2024 06:21:13 +0000 Subject: [PATCH 06/22] format --- .../dispatcher/subgroup_dispatcher.hpp | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index f466f581..d23e5cb4 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -198,13 +198,12 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag // load / store in a transposed manner if (storage == complex_storage::INTERLEAVED_COMPLEX) { subgroup_impl_global2local_strided_copy( - input, loc_view, {2 * n_transforms, static_cast(1)}, - {2 * max_num_batches_local_mem, 1}, 2 * i, 0, {committed_length, 2 * num_batches_in_local_mem}, global_data); + input, loc_view, {2 * n_transforms, static_cast(1)}, {2 * max_num_batches_local_mem, 1}, 2 * i, + 0, 
{committed_length, 2 * num_batches_in_local_mem}, global_data); } else { subgroup_impl_global2local_strided_copy( - input, input_imag, loc_view, {n_transforms, static_cast(1)}, - {max_num_batches_local_mem, 1}, i, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, - global_data); + input, input_imag, loc_view, {n_transforms, static_cast(1)}, {max_num_batches_local_mem, 1}, i, + 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, global_data); } sycl::group_barrier(global_data.it.get_group()); @@ -327,7 +326,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } else { subgroup_impl_local2global_strided_copy( output, output_imag, loc_view, {n_transforms, static_cast(1)}, - {max_num_batches_local_mem, 1}, i, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, global_data); + {max_num_batches_local_mem, 1}, i, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, + global_data); } } } @@ -361,14 +361,14 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (storage == complex_storage::INTERLEAVED_COMPLEX) { global_data.log_message_global(__func__, "storing data from unpacked global memory to local"); subgroup_impl_global2local_strided_copy( - input, loc_view, {input_distance * 2, input_stride * 2, 1}, - {committed_length * 2, 2, 1}, input_distance * 2 * (i - static_cast(id_of_fft_in_sg)), - local_offset, {n_ffts_worked_on_by_sg, committed_length, 2}, global_data); + input, loc_view, {input_distance * 2, input_stride * 2, 1}, {committed_length * 2, 2, 1}, + input_distance * 2 * (i - static_cast(id_of_fft_in_sg)), local_offset, + {n_ffts_worked_on_by_sg, committed_length, 2}, global_data); } else { subgroup_impl_global2local_strided_copy( - input, input_imag, loc_view, {input_distance, input_stride}, - {committed_length, 1}, input_distance * (i - static_cast(id_of_fft_in_sg)), local_offset, - local_imag_offset, {n_ffts_worked_on_by_sg, 
committed_length}, global_data); + input, input_imag, loc_view, {input_distance, input_stride}, {committed_length, 1}, + input_distance * (i - static_cast(id_of_fft_in_sg)), local_offset, local_imag_offset, + {n_ffts_worked_on_by_sg, committed_length}, global_data); } } } else { @@ -381,9 +381,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag : factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; subgroup_impl_bluestein_global2local_packed_copy( - input, input_imag, loc_view, committed_length, factor_sg * factor_wi, - global_ptr_offset, loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, - global_data); + input, input_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, loc_view_offset, + loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, global_data); } else { // TODO: Bluestein Strided copy } @@ -532,7 +531,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag const IdxGlobal global_output_offset = output_distance * (i - static_cast(id_of_fft_in_sg)); subgroup_impl_local2global_strided_copy( output, output_imag, loc_view, {output_distance, output_stride}, {committed_length, 1}, - global_output_offset, local_offset, local_imag_offset, {n_ffts_worked_on_by_sg, committed_length}, global_data); + global_output_offset, local_offset, local_imag_offset, {n_ffts_worked_on_by_sg, committed_length}, + global_data); } } } else { @@ -546,8 +546,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; subgroup_impl_bluestein_local2global_packed_copy( output, output_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, - loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, - global_data); + loc_view_offset, 
loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, global_data); } else { // TODO: Blustein Strided Copy } @@ -664,7 +663,8 @@ struct committed_descriptor_impl::run_kernel_struct Date: Tue, 12 Mar 2024 15:35:52 +0000 Subject: [PATCH 07/22] not copy in between an aligned pointer --- .../dispatcher/subgroup_dispatcher.hpp | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index d23e5cb4..51572bb6 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -32,6 +32,8 @@ #include "portfft/enums.hpp" #include "portfft/specialization_constant.hpp" +#include + namespace portfft { namespace detail { /** @@ -578,27 +580,24 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn kernel_data.length * 2); Scalar* res = sycl::aligned_alloc_device( alignof(sycl::vec), twiddles_alloc_size, desc.queue); - sycl::range<2> kernel_range({static_cast(factor_sg), static_cast(factor_wi)}); - desc.queue.submit([&](sycl::handler& cgh) { - PORTFFT_LOG_TRACE("Launching twiddle calculation kernel for subgroup implementation with global size", factor_sg, - factor_wi); - cgh.parallel_for(kernel_range, [=](sycl::item<2> it) { - Idx n = static_cast(it.get_id(0)); - Idx k = static_cast(it.get_id(1)); - sg_calc_twiddles(factor_sg, factor_wi, n, k, res); - }); - }); + std::vector host_twiddles(twiddles_alloc_size); + + for (Idx i = 0; i < factor_sg; i++) { + for (Idx j = 0; j < factor_wi; j++) { + double theta = -2 * M_PI * static_cast(i * j) / static_cast(factor_wi * factor_sg); + auto twiddle = std::complex(static_cast(std::cos(theta)), static_cast(std::sin(theta))); + host_twiddles[static_cast(j * factor_sg + i)] = twiddle.real(); + host_twiddles[static_cast((j + factor_wi) * factor_sg + i)] = twiddle.imag(); + } + } if (dimension_data.is_prime) { - std::vector 
bluestein_twiddles_host_ptr(4 * dimension_data.length, 0); - detail::populate_bluestein_input_modifiers(bluestein_twiddles_host_ptr.data(), dimension_data.committed_length, - dimension_data.length); - detail::populate_fft_chirp_signal(bluestein_twiddles_host_ptr.data() + 2 * dimension_data.length, + detail::populate_bluestein_input_modifiers(host_twiddles.data() + 2 * factor_sg * factor_wi, + dimension_data.committed_length, dimension_data.length); + detail::populate_fft_chirp_signal(host_twiddles.data() + 4 * factor_sg * factor_wi, dimension_data.committed_length, dimension_data.length); - desc.queue.copy(bluestein_twiddles_host_ptr.data(), res + 2 * dimension_data.length, 4 * dimension_data.length) - .wait(); } - desc.queue.wait(); // waiting once here can be better than depending on the event - // for all future calls to compute + + desc.queue.copy(host_twiddles.data(), res, twiddles_alloc_size).wait(); return res; } }; From 5afe834b3e37d0b975297ce941b911feb57a62e4 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 12 Mar 2024 21:59:41 +0000 Subject: [PATCH 08/22] prevent OOB read/writes in packed format --- src/portfft/dispatcher/subgroup_dispatcher.hpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 51572bb6..d0b7b9cc 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -382,9 +382,15 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag ? 
2 * factor_sg * factor_wi * subgroup_id * n_ffts_per_sg : factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; + auto n_ffts_to_copy = [=]() { + if (i + static_cast(n_ffts_worked_on_by_sg) < n_transforms) { + return n_ffts_worked_on_by_sg; + } + return static_cast(n_transforms - i); + }(); subgroup_impl_bluestein_global2local_packed_copy( input, input_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, loc_view_offset, - loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, global_data); + loc_view_imag_offset, n_ffts_to_copy, global_data.sg, storage, global_data); } else { // TODO: Bluestein Strided copy } @@ -546,9 +552,15 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag ? 2 * factor_sg * factor_wi * subgroup_id : factor_sg * factor_wi * subgroup_id; auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; + auto n_ffts_to_copy = [=]() { + if (i + static_cast(n_ffts_worked_on_by_sg) < n_transforms) { + return n_ffts_worked_on_by_sg; + } + return static_cast(n_transforms - i); + }(); subgroup_impl_bluestein_local2global_packed_copy( output, output_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, - loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, global_data.sg, storage, global_data); + loc_view_offset, loc_view_imag_offset, n_ffts_to_copy, global_data.sg, storage, global_data); } else { // TODO: Blustein Strided Copy } From 8a83350cf57909c47a3143d57dd9921587f9ffbc Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Wed, 13 Mar 2024 07:17:08 +0000 Subject: [PATCH 09/22] prevent OOB read writes in PACKED bluestein condition --- src/portfft/common/subgroup.hpp | 16 ++++++++-------- .../dispatcher/subgroup_dispatcher.hpp | 19 +++++-------------- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/portfft/common/subgroup.hpp b/src/portfft/common/subgroup.hpp index 
9e05ac9d..7ec31b29 100644 --- a/src/portfft/common/subgroup.hpp +++ b/src/portfft/common/subgroup.hpp @@ -408,18 +408,18 @@ PORTFFT_INLINE void subgroup_impl_local_private_copy( template PORTFFT_INLINE void subgroup_impl_bluestein_local2global_packed_copy( TIn* global_ptr, TIn* global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, - IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, sycl::sub_group& sg, - complex_storage storage, detail::global_data_struct<1>& global_data) { + IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, IdxGlobal batch_start, + IdxGlobal n_transforms, sycl::sub_group& sg, complex_storage storage, detail::global_data_struct<1>& global_data) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg; i++) { + for (Idx i = 0; i < n_ffts_in_sg && ((i + batch_start) < n_transforms); i++) { local2global(global_data, loc_view, global_ptr, 2 * committed_size, 2 * i * fft_size + loc_offset, global_ptr_offset + 2 * i * committed_size); } } else { PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg; i++) { + for (Idx i = 0; i < n_ffts_in_sg && ((i + batch_start) < n_transforms); i++) { local2global(global_data, loc_view, global_ptr, committed_size, i * fft_size + loc_offset, global_ptr_offset + i * committed_size); @@ -435,18 +435,18 @@ PORTFFT_INLINE void subgroup_impl_bluestein_local2global_packed_copy( template PORTFFT_INLINE void subgroup_impl_bluestein_global2local_packed_copy( const TIn* global_ptr, const TIn* global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, - IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, sycl::sub_group& sg, - complex_storage storage, detail::global_data_struct<1>& global_data) { + IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, IdxGlobal batch_start, + IdxGlobal n_transforms, sycl::sub_group& sg, 
complex_storage storage, detail::global_data_struct<1>& global_data) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg; i++) { + for (Idx i = 0; i < n_ffts_in_sg && ((i + batch_start) < n_transforms); i++) { global2local( global_data, global_ptr, loc_view, 2 * committed_size, static_cast(2 * i * committed_size) + global_ptr_offset, 2 * i * fft_size + loc_offset); } } else { PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg; i++) { + for (Idx i = 0; i < n_ffts_in_sg && (i + batch_start < n_transforms); i++) { global2local( global_data, global_ptr, loc_view, committed_size, static_cast(i * committed_size) + global_ptr_offset, i * fft_size + loc_offset); diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index d0b7b9cc..63420e71 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -382,15 +382,10 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag ? 2 * factor_sg * factor_wi * subgroup_id * n_ffts_per_sg : factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; - auto n_ffts_to_copy = [=]() { - if (i + static_cast(n_ffts_worked_on_by_sg) < n_transforms) { - return n_ffts_worked_on_by_sg; - } - return static_cast(n_transforms - i); - }(); + subgroup_impl_bluestein_global2local_packed_copy( input, input_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, loc_view_offset, - loc_view_imag_offset, n_ffts_to_copy, global_data.sg, storage, global_data); + loc_view_imag_offset, n_ffts_worked_on_by_sg, i, n_transforms, global_data.sg, storage, global_data); } else { // TODO: Bluestein Strided copy } @@ -552,15 +547,11 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag ? 
2 * factor_sg * factor_wi * subgroup_id : factor_sg * factor_wi * subgroup_id; auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; - auto n_ffts_to_copy = [=]() { - if (i + static_cast(n_ffts_worked_on_by_sg) < n_transforms) { - return n_ffts_worked_on_by_sg; - } - return static_cast(n_transforms - i); - }(); + subgroup_impl_bluestein_local2global_packed_copy( output, output_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, - loc_view_offset, loc_view_imag_offset, n_ffts_to_copy, global_data.sg, storage, global_data); + loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, i, n_transforms, global_data.sg, storage, + global_data); } else { // TODO: Blustein Strided Copy } From 011e78071f07c9af037b023d058f424e222c6867 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Wed, 13 Mar 2024 22:41:08 +0000 Subject: [PATCH 10/22] refactor --- src/portfft/common/subgroup.hpp | 189 +++------------- src/portfft/common/transfers.hpp | 104 +++++++++ src/portfft/defines.hpp | 8 + .../dispatcher/subgroup_dispatcher.hpp | 211 ++++++++---------- 4 files changed, 240 insertions(+), 272 deletions(-) diff --git a/src/portfft/common/subgroup.hpp b/src/portfft/common/subgroup.hpp index 7ec31b29..6e22b30d 100644 --- a/src/portfft/common/subgroup.hpp +++ b/src/portfft/common/subgroup.hpp @@ -311,153 +311,29 @@ void sg_calc_twiddles(Idx factor_sg, Idx factor_wi, Idx n, Idx k, T* sg_twiddles sg_twiddles[(k + factor_wi) * factor_sg + n] = twiddle.imag(); } -template -PORTFFT_INLINE void subgroup_impl_global2local_strided_copy(const T* global_ptr, LocView& loc_view, - std::array strides_global, - std::array strides_local, - IdxGlobal offset_global, Idx offset_local, - std::array copy_strides, - detail::global_data_struct<1> global_data) { - detail::md_view global_md_view{global_ptr, strides_global, offset_global}; - detail::md_view local_md_view{loc_view, strides_local, offset_local}; - copy_group(global_data, global_md_view, local_md_view, 
copy_strides); -} - -template -PORTFFT_INLINE void subgroup_impl_local2global_strided_copy(T* global_ptr, LocView& loc_view, - std::array strides_global, - std::array strides_local, - IdxGlobal offset_global, Idx offset_local, - std::array copy_strides, - detail::global_data_struct<1> global_data) { - detail::md_view global_md_view{global_ptr, strides_global, offset_global}; - detail::md_view local_md_view{loc_view, strides_local, offset_local}; - copy_group(global_data, local_md_view, global_md_view, copy_strides); -} - -template -PORTFFT_INLINE void subgroup_impl_local2global_strided_copy( - T* global_ptr, T* global_imag_ptr, LocView& loc_view, std::array strides_global, - std::array strides_local, IdxGlobal offset_global, Idx local_offset, Idx local_imag_offset, - std::array copy_strides, detail::global_data_struct<1> global_data) { - detail::md_view global_md_real_view{global_ptr, strides_global, offset_global}; - detail::md_view global_md_imag_view{global_imag_ptr, strides_global, offset_global}; - detail::md_view local_md_real_view{loc_view, strides_local, local_offset}; - detail::md_view local_md_imag_view{loc_view, strides_local, local_offset + local_imag_offset}; - copy_group(global_data, local_md_real_view, global_md_real_view, copy_strides); - copy_group(global_data, local_md_imag_view, global_md_imag_view, copy_strides); -} - -template -PORTFFT_INLINE void subgroup_impl_global2local_strided_copy( - const T* global_ptr, const T* global_imag_ptr, LocView& loc_view, std::array strides_global, - std::array strides_local, IdxGlobal offset_global, Idx local_offset, Idx local_imag_offset, - std::array copy_strides, detail::global_data_struct<1> global_data) { - detail::md_view global_md_real_view{global_ptr, strides_global, offset_global}; - detail::md_view global_md_imag_view{global_imag_ptr, strides_global, offset_global}; - detail::md_view local_md_real_view{loc_view, strides_local, local_offset}; - detail::md_view local_md_imag_view{loc_view, 
strides_local, local_offset + local_imag_offset}; - copy_group(global_data, global_md_real_view, local_md_real_view, copy_strides); - copy_group(global_data, global_md_imag_view, local_md_imag_view, copy_strides); -} - -template -PORTFFT_INLINE void subgroup_impl_local_private_copy( - PtrView& ptr_view, PtrView& ptr_imag_view, T* priv, - std::array, 2> ptr_view_strides_offsets, - std::array, 2> priv_view_strides_offsets, - std::array, 2> ptr_imag_view_strides_offsets, - std::array, 2> priv_imag_view_strides_offsets, Idx num_elements_to_copy, - detail::global_data_struct<1> global_data, detail::transfer_direction direction) { - detail::strided_view ptr_strided_real_view{ptr_view, std::get<0>(ptr_view_strides_offsets), - std::get<1>(ptr_view_strides_offsets)}; - detail::strided_view ptr_strided_imag_view{ptr_imag_view, std::get<0>(ptr_imag_view_strides_offsets), - std::get<1>(ptr_imag_view_strides_offsets)}; - detail::strided_view priv_strided_real_view{priv, std::get<0>(priv_view_strides_offsets), - std::get<1>(priv_view_strides_offsets)}; - detail::strided_view priv_strided_imag_view{priv, std::get<0>(priv_imag_view_strides_offsets), - std::get<1>(priv_imag_view_strides_offsets)}; - if (direction == detail::transfer_direction::LOCAL_TO_PRIVATE) { - copy_wi(global_data, ptr_strided_real_view, priv_strided_real_view, num_elements_to_copy); - copy_wi(global_data, ptr_strided_imag_view, priv_strided_imag_view, num_elements_to_copy); - } else if (direction == detail::transfer_direction::PRIVATE_TO_LOCAL || - direction == detail::transfer_direction::PRIVATE_TO_GLOBAL) { - copy_wi(global_data, priv_strided_real_view, ptr_strided_real_view, num_elements_to_copy); - copy_wi(global_data, priv_strided_imag_view, ptr_strided_imag_view, num_elements_to_copy); - } -} - -template -PORTFFT_INLINE void subgroup_impl_local_private_copy( - PtrView& ptr_view, T* priv, std::array, 2> ptr_view_strides_offsets, - Idx num_elements_to_copy, detail::global_data_struct<1> global_data, 
detail::transfer_direction direction) { - detail::strided_view ptr_strided_view{ptr_view, std::get<0>(ptr_view_strides_offsets), - std::get<1>(ptr_view_strides_offsets)}; - if (direction == detail::transfer_direction::LOCAL_TO_PRIVATE) { - copy_wi<2>(global_data, ptr_strided_view, priv, num_elements_to_copy); - } else if (direction == detail::transfer_direction::PRIVATE_TO_LOCAL || - direction == detail::transfer_direction::PRIVATE_TO_GLOBAL) { - copy_wi<2>(global_data, priv, ptr_strided_view, num_elements_to_copy); - } -} - -template -PORTFFT_INLINE void subgroup_impl_bluestein_local2global_packed_copy( - TIn* global_ptr, TIn* global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, - IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, IdxGlobal batch_start, - IdxGlobal n_transforms, sycl::sub_group& sg, complex_storage storage, detail::global_data_struct<1>& global_data) { - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg && ((i + batch_start) < n_transforms); i++) { - local2global(global_data, loc_view, global_ptr, 2 * committed_size, - 2 * i * fft_size + loc_offset, - global_ptr_offset + 2 * i * committed_size); - } - } else { - PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg && ((i + batch_start) < n_transforms); i++) { - local2global(global_data, loc_view, global_ptr, committed_size, - i * fft_size + loc_offset, - global_ptr_offset + i * committed_size); - local2global(global_data, loc_view, global_imag_ptr, committed_size, - i * fft_size + loc_offset + local_imag_offset, - global_ptr_offset + i * committed_size); - } +template +PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( + TIn global_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, IdxGlobal global_ptr_offset, Idx loc_offset, + Idx n_ffts_in_sg, IdxGlobal batch_start, IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { + PORTFFT_UNROLL + for (Idx i = 
0; i < n_ffts_in_sg && ((i + batch_start) < n_transforms); i++) { + local_global_packed_copy( + global_ptr, loc_view, global_ptr_offset + static_cast(2 * i * committed_size), + 2 * i * fft_size + loc_offset, 2 * committed_size, global_data); } - - sycl::group_barrier(sg); } -template -PORTFFT_INLINE void subgroup_impl_bluestein_global2local_packed_copy( - const TIn* global_ptr, const TIn* global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, +template +PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( + TIn global_ptr, TIn global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, IdxGlobal batch_start, - IdxGlobal n_transforms, sycl::sub_group& sg, complex_storage storage, detail::global_data_struct<1>& global_data) { - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg && ((i + batch_start) < n_transforms); i++) { - global2local( - global_data, global_ptr, loc_view, 2 * committed_size, - static_cast(2 * i * committed_size) + global_ptr_offset, 2 * i * fft_size + loc_offset); - } - } else { - PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg && (i + batch_start < n_transforms); i++) { - global2local( - global_data, global_ptr, loc_view, committed_size, - static_cast(i * committed_size) + global_ptr_offset, i * fft_size + loc_offset); - global2local( - global_data, global_imag_ptr, loc_view, committed_size, - static_cast(i * committed_size) + global_ptr_offset, - i * fft_size + loc_offset + local_imag_offset); - } + IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { + PORTFFT_UNROLL + for (Idx i = 0; i < n_ffts_in_sg && (i + batch_start < n_transforms); i++) { + local_global_packed_copy( + global_ptr, global_imag_ptr, loc_view, static_cast(i * committed_size) + global_ptr_offset, + i * fft_size + loc_offset, local_imag_offset, committed_size, global_data); } - 
- sycl::group_barrier(sg); } /** @@ -552,15 +428,14 @@ PORTFFT_INLINE void sg_bluestein_batch_interleaved(T* priv, T* priv_scratch, Loc if (wi_working) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { - subgroup_impl_local_private_copy<2, Idx>( + local_private_strided_copy<2, Idx>( loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}}, factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); } else { - subgroup_impl_local_private_copy<2, 1, Idx>( + local_private_strided_copy<2, Idx>( loc_view, loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}}, - {{{2}, {0}}}, {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}}, - {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); } } @@ -568,16 +443,15 @@ PORTFFT_INLINE void sg_bluestein_batch_interleaved(T* priv, T* priv_scratch, Loc if (wi_working) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { const Idx fft_element = 2 * id_of_wi_in_fft * factor_wi; - subgroup_impl_local_private_copy<1, Idx>( + local_private_strided_copy<1, Idx>( loc_view, priv, {{{max_num_batches_local_mem}, {fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local}}}, factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); } else { - subgroup_impl_local_private_copy<2, 1, Idx>( + local_private_strided_copy<2, Idx>( loc_view, loc_view, priv, {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}}, - {{{2}, {0}}}, {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}}, - {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); } } @@ -608,8 +482,8 @@ void sg_bluestein(T* priv, T* 
priv_scratch, LocView& loc_view, LocTwiddlesView& if (wi_working) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { - subgroup_impl_local_private_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_offset_store_view}}}, factor_wi, - global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + local_private_strided_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_offset_store_view}}}, factor_wi, + global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); } else { detail::strided_view priv_real_view{priv, 2}; detail::strided_view priv_imag_view{priv, 2, 1}; @@ -624,13 +498,12 @@ void sg_bluestein(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddlesView& if (wi_working) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { - subgroup_impl_local_private_copy<1, Idx>(loc_view, priv, {{{1}, {loc_offset_load_view}}}, factor_wi, global_data, - detail::transfer_direction::LOCAL_TO_PRIVATE); + local_private_strided_copy<1, Idx>(loc_view, priv, {{{1}, {loc_offset_load_view}}}, factor_wi, global_data, + detail::transfer_direction::LOCAL_TO_PRIVATE); } else { - subgroup_impl_local_private_copy<1, 1, Idx>(loc_view, loc_view, priv, {{{1}, {loc_offset_load_view}}}, - {{{2}, {0}}}, {{{1}, {loc_offset_load_view + local_imag_offset}}}, - {{{2}, {1}}}, factor_wi, global_data, - detail::transfer_direction::LOCAL_TO_PRIVATE); + local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{{1}, {loc_offset_load_view}}}, + {{{1}, {loc_offset_load_view + local_imag_offset}}}, factor_wi, global_data, + detail::transfer_direction::LOCAL_TO_PRIVATE); } } diff --git a/src/portfft/common/transfers.hpp b/src/portfft/common/transfers.hpp index 99ce928d..ab7dbae1 100644 --- a/src/portfft/common/transfers.hpp +++ b/src/portfft/common/transfers.hpp @@ -488,6 +488,110 @@ PORTFFT_INLINE void local2global(detail::global_data_struct<1> global_data, Loca global_data, global, local, total_num_elems, global_offset, local_offset); } +template +PORTFFT_INLINE void local_global_packed_copy(T* 
global_ptr, LocView& loc_view, IdxGlobal global_offset, + Idx local_offset, Idx n_elements_to_copy, + detail::global_data_struct<1>& global_data) { + if constexpr (Direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + global2local(global_data, global_ptr, loc_view, n_elements_to_copy, global_offset, + local_offset); + } else { + local2global(global_data, loc_view, global_ptr, n_elements_to_copy, local_offset, + global_offset); + } +} + +template +PORTFFT_INLINE void local_global_packed_copy(T* global_ptr, T* global_imag_ptr, LocView& loc_view, + IdxGlobal global_offset, Idx local_offset, Idx local_imag_offset, + Idx n_elements_to_copy, detail::global_data_struct<1>& global_data) { + if constexpr (Direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + global2local(global_data, global_ptr, loc_view, n_elements_to_copy, global_offset, + local_offset); + global2local(global_data, global_imag_ptr, loc_view, n_elements_to_copy, global_offset, + local_offset + local_imag_offset); + } else { + local2global(global_data, loc_view, global_ptr, n_elements_to_copy, local_offset, + global_offset); + local2global(global_data, loc_view, global_imag_ptr, n_elements_to_copy, + local_offset + local_imag_offset, global_offset); + } +} + +template +PORTFFT_INLINE void local_global_strided_copy(T* global_ptr, LocView& loc_view, + std::array strides_global, + std::array strides_local, IdxGlobal offset_global, + Idx offset_local, std::array copy_lengths, + detail::global_data_struct<1> global_data) { + detail::md_view global_md_view{global_ptr, strides_global, offset_global}; + detail::md_view local_md_view{loc_view, strides_local, offset_local}; + if constexpr (Direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + copy_group(global_data, global_md_view, local_md_view, copy_lengths); + } else { + copy_group(global_data, local_md_view, global_md_view, copy_lengths); + } +} + +template +PORTFFT_INLINE void local_global_strided_copy(T* global_ptr, T* global_imag_ptr, 
LocView& loc_view, + std::array strides_global, + std::array strides_local, IdxGlobal offset_global, + Idx local_offset, Idx local_imag_offset, + std::array copy_lengths, + detail::global_data_struct<1> global_data) { + detail::md_view global_md_real_view{global_ptr, strides_global, offset_global}; + detail::md_view global_md_imag_view{global_imag_ptr, strides_global, offset_global}; + detail::md_view local_md_real_view{loc_view, strides_local, local_offset}; + detail::md_view local_md_imag_view{loc_view, strides_local, local_offset + local_imag_offset}; + if constexpr (Direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + copy_group(global_data, global_md_real_view, local_md_real_view, copy_lengths); + copy_group(global_data, global_md_imag_view, local_md_imag_view, copy_lengths); + } else { + copy_group(global_data, local_md_real_view, global_md_real_view, copy_lengths); + copy_group(global_data, local_md_imag_view, global_md_imag_view, copy_lengths); + } +} + +template +PORTFFT_INLINE void local_private_strided_copy(PtrView& ptr_view, T* priv, + so_array ptr_view_strides_offsets, + Idx num_elements_to_copy, detail::global_data_struct<1> global_data, + detail::transfer_direction direction) { + detail::strided_view ptr_strided_view{ptr_view, std::get<0>(ptr_view_strides_offsets), + std::get<1>(ptr_view_strides_offsets)}; + if (direction == detail::transfer_direction::LOCAL_TO_PRIVATE) { + copy_wi<2>(global_data, ptr_strided_view, priv, num_elements_to_copy); + } else if (direction == detail::transfer_direction::PRIVATE_TO_LOCAL || + direction == detail::transfer_direction::PRIVATE_TO_GLOBAL) { + copy_wi<2>(global_data, priv, ptr_strided_view, num_elements_to_copy); + } +} + +template +PORTFFT_INLINE void local_private_strided_copy(PtrView& ptr_view, PtrView& ptr_imag_view, T* priv, + so_array ptr_view_strides_offsets, + so_array ptr_imag_view_strides_offsets, + Idx num_elements_to_copy, detail::global_data_struct<1> global_data, + 
detail::transfer_direction direction) { + detail::strided_view ptr_strided_real_view{ptr_view, std::get<0>(ptr_view_strides_offsets), + std::get<1>(ptr_view_strides_offsets)}; + detail::strided_view ptr_strided_imag_view{ptr_imag_view, std::get<0>(ptr_imag_view_strides_offsets), + std::get<1>(ptr_imag_view_strides_offsets)}; + detail::strided_view priv_strided_real_view{priv, 2}; + detail::strided_view priv_strided_imag_view{priv, 2, 1}; + if (direction == detail::transfer_direction::LOCAL_TO_PRIVATE) { + copy_wi(global_data, ptr_strided_real_view, priv_strided_real_view, num_elements_to_copy); + copy_wi(global_data, ptr_strided_imag_view, priv_strided_imag_view, num_elements_to_copy); + } else if (direction == detail::transfer_direction::PRIVATE_TO_LOCAL || + direction == detail::transfer_direction::PRIVATE_TO_GLOBAL) { + copy_wi(global_data, priv_strided_real_view, ptr_strided_real_view, num_elements_to_copy); + copy_wi(global_data, priv_strided_imag_view, ptr_strided_imag_view, num_elements_to_copy); + } +} + } // namespace portfft #endif diff --git a/src/portfft/defines.hpp b/src/portfft/defines.hpp index 9fcd41fd..8af51ba0 100644 --- a/src/portfft/defines.hpp +++ b/src/portfft/defines.hpp @@ -48,6 +48,14 @@ namespace portfft { using Idx = std::int32_t; using IdxGlobal = std::int64_t; +/** + * An array of 2 arrays containing N elements of Type, containing strides (s) and offset (o) for a view + * @tparam Type Type of elements + * @tparam N Number of elements in each of the two arrays + */ +template +using so_array = std::array, 2>; + } // namespace portfft #endif diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 63420e71..4ff6edb2 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -199,11 +199,11 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "loading 
transposed data from global to local memory"); // load / store in a transposed manner if (storage == complex_storage::INTERLEAVED_COMPLEX) { - subgroup_impl_global2local_strided_copy( + local_global_strided_copy( input, loc_view, {2 * n_transforms, static_cast(1)}, {2 * max_num_batches_local_mem, 1}, 2 * i, 0, {committed_length, 2 * num_batches_in_local_mem}, global_data); } else { - subgroup_impl_global2local_strided_copy( + local_global_strided_copy( input, input_imag, loc_view, {n_transforms, static_cast(1)}, {max_num_batches_local_mem, 1}, i, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, global_data); } @@ -221,16 +221,16 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "loading batch_interleaved data from local to private memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { const Idx fft_element = 2 * id_of_wi_in_fft * factor_wi; - subgroup_impl_local_private_copy<1, Idx>( + local_private_strided_copy<1, Idx>( loc_view, priv, {{{max_num_batches_local_mem}, {fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local}}}, factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); } else { - subgroup_impl_local_private_copy<2, 1, Idx>( + local_private_strided_copy<2, Idx>( loc_view, loc_view, priv, - {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}}, {{{2}, {0}}}, + {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}}, {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}}, - {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); } global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); } @@ -257,25 +257,18 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag 
global_data.log_message_global( __func__, "storing transposed data from private to packed global memory (SubgroupSize == FactorSG)"); // Store directly from registers for fully coalesced accesses + IdxGlobal output_offset = static_cast(i + static_cast(fft_idx_in_local)) * + static_cast(2 * fft_size) + + static_cast(2 * id_of_wi_in_fft); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - subgroup_impl_local_private_copy<1, IdxGlobal>( - output, priv, - {{{static_cast(factor_sg)}, - {static_cast(i + static_cast(fft_idx_in_local)) * - static_cast(2 * fft_size) + - static_cast(2 * id_of_wi_in_fft)}}}, - factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_GLOBAL); + so_array output_stride_offset{{{static_cast(factor_sg)}, {output_offset}}}; + local_private_strided_copy<1, IdxGlobal>(output, priv, output_stride_offset, factor_wi, global_data, + detail::transfer_direction::PRIVATE_TO_GLOBAL); } else { - subgroup_impl_local_private_copy<1, 1, IdxGlobal>( - output, output_imag, priv, - {{{static_cast(factor_sg)}, - {(i + static_cast(fft_idx_in_local)) * static_cast(fft_size) + - static_cast(id_of_wi_in_fft)}}}, - {{{2}, {0}}}, - {{{static_cast(factor_sg)}, - {(i + static_cast(fft_idx_in_local)) * static_cast(fft_size) + - static_cast(id_of_wi_in_fft)}}}, - {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_GLOBAL); + so_array output_stride_offset{{{static_cast(factor_sg)}, {output_offset / 2}}}; + local_private_strided_copy<1, IdxGlobal>(output, output_imag, priv, output_stride_offset, + output_stride_offset, factor_wi, global_data, + detail::transfer_direction::PRIVATE_TO_GLOBAL); } } } else { @@ -286,19 +279,20 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag "FactorSG or not packed output layout)"); // Store back to local memory only if (storage == complex_storage::INTERLEAVED_COMPLEX) { - subgroup_impl_local_private_copy<2, Idx>( + local_private_strided_copy<2, Idx>( loc_view, priv, 
{{{factor_sg, max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}}, factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); } else { - subgroup_impl_local_private_copy<2, 1, Idx>( + local_private_strided_copy<2, Idx>( loc_view, loc_view, priv, - {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}}, {{{2}, {0}}}, + {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}}, {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}}, - {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); } } } + sycl::group_barrier(global_data.sg); } sycl::group_barrier(global_data.it.get_group()); if (!store_directly_from_private) { @@ -310,11 +304,11 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag "storing data from batch interleaved local memory to not batch interleaved " "global memory (SubgroupSize != FactorSG)"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - subgroup_impl_local2global_strided_copy( + local_global_strided_copy( output, loc_view, {output_stride * 2, output_distance * 2, 1}, {max_num_batches_local_mem * 2, 2, 1}, i * output_distance * 2, 0, {committed_length, num_batches_in_local_mem, 2}, global_data); } else { - subgroup_impl_local2global_strided_copy( + local_global_strided_copy( output, output_imag, loc_view, {output_stride, output_distance}, {max_num_batches_local_mem, 1}, i * output_distance, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, global_data); } @@ -322,11 +316,11 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_message_global( __func__, "storing data from batch interleaved local memory to batch interleaved global memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - 
subgroup_impl_local2global_strided_copy( + local_global_strided_copy( output, loc_view, {2 * n_transforms, static_cast(1)}, {2 * max_num_batches_local_mem, 1}, 2 * i, 0, {committed_length, 2 * num_batches_in_local_mem}, global_data); } else { - subgroup_impl_local2global_strided_copy( + local_global_strided_copy( output, output_imag, loc_view, {n_transforms, static_cast(1)}, {max_num_batches_local_mem, 1}, i, 0, local_imag_offset, {committed_length, num_batches_in_local_mem}, global_data); @@ -344,30 +338,26 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "loading non-transposed data from global to local memory"); if (algorithm == detail::fft_algorithm::COOLEY_TUKEY) { if (is_input_packed) { + IdxGlobal global_ptr_offset = + static_cast(n_io_reals_per_fft) * (i - static_cast(id_of_fft_in_sg)); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - global2local( - global_data, input, loc_view, n_ffts_worked_on_by_sg * n_reals_per_fft, - static_cast(n_reals_per_fft) * (i - static_cast(id_of_fft_in_sg)), - subgroup_id * n_reals_per_sg); + local_global_packed_copy( + input, loc_view, global_ptr_offset, subgroup_id * n_reals_per_sg, + n_ffts_worked_on_by_sg * n_reals_per_fft, global_data); } else { - global2local( - global_data, input, loc_view, n_ffts_worked_on_by_sg * fft_size, - static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), - subgroup_id * n_cplx_per_sg); - global2local( - global_data, input_imag, loc_view, n_ffts_worked_on_by_sg * fft_size, - static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), - local_imag_offset + subgroup_id * n_cplx_per_sg); + local_global_packed_copy( + input, input_imag, loc_view, global_ptr_offset, subgroup_id * n_cplx_per_sg, local_imag_offset, + n_ffts_worked_on_by_sg * fft_size, global_data); } } else { if (storage == complex_storage::INTERLEAVED_COMPLEX) { global_data.log_message_global(__func__, "storing data from unpacked global memory to 
local"); - subgroup_impl_global2local_strided_copy( + local_global_strided_copy( input, loc_view, {input_distance * 2, input_stride * 2, 1}, {committed_length * 2, 2, 1}, input_distance * 2 * (i - static_cast(id_of_fft_in_sg)), local_offset, {n_ffts_worked_on_by_sg, committed_length, 2}, global_data); } else { - subgroup_impl_global2local_strided_copy( + local_global_strided_copy( input, input_imag, loc_view, {input_distance, input_stride}, {committed_length, 1}, input_distance * (i - static_cast(id_of_fft_in_sg)), local_offset, local_imag_offset, {n_ffts_worked_on_by_sg, committed_length}, global_data); @@ -375,17 +365,20 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } } else { if (is_input_packed) { - auto global_ptr_offset = storage == complex_storage::INTERLEAVED_COMPLEX - ? 2 * committed_length * (i - static_cast(id_of_fft_in_sg)) - : committed_length * (i - static_cast(id_of_fft_in_sg)); - auto loc_view_offset = storage == complex_storage::INTERLEAVED_COMPLEX - ? 
2 * factor_sg * factor_wi * subgroup_id * n_ffts_per_sg - : factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; - auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; - - subgroup_impl_bluestein_global2local_packed_copy( - input, input_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, loc_view_offset, - loc_view_imag_offset, n_ffts_worked_on_by_sg, i, n_transforms, global_data.sg, storage, global_data); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + auto global_ptr_offset = 2 * committed_length * (i - static_cast(id_of_fft_in_sg)); + auto local_view_offset = 2 * factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; + subgroup_impl_bluestein_local_global_packed_copy( + input, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, local_view_offset, + n_ffts_worked_on_by_sg, i, n_transforms, global_data); + } else { + auto global_ptr_offset = 2 * committed_length * (i - static_cast(id_of_fft_in_sg)); + auto local_view_offset = 2 * factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; + auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; + subgroup_impl_bluestein_local_global_packed_copy( + input, input_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, + local_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, i, n_transforms, global_data); + } } else { // TODO: Bluestein Strided copy } @@ -397,14 +390,14 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (working) { global_data.log_message_global(__func__, "loading non-transposed data from local to private memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - subgroup_impl_local_private_copy<1, Idx>( + local_private_strided_copy<1, Idx>( loc_view, priv, {{{1}, {subgroup_id * n_reals_per_sg + subgroup_local_id * n_reals_per_wi}}}, factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); } else { - subgroup_impl_local_private_copy<1, 1, Idx>( + 
local_private_strided_copy<1, Idx>( loc_view, loc_view, priv, {{{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi}}}, - {{{2}, {0}}}, {{{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi + local_imag_offset}}}, - {{{2}, {1}}}, factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + {{{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi + local_imag_offset}}}, factor_wi, + global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); } global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); } @@ -446,16 +439,16 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag IdxGlobal output_offset = i * static_cast(n_reals_per_sg) + static_cast(id_of_fft_in_sg * n_reals_per_fft) + static_cast(id_of_wi_in_fft * 2); - subgroup_impl_local_private_copy<1, IdxGlobal>( + local_private_strided_copy<1, IdxGlobal>( output, priv, {{{static_cast(factor_sg)}, {output_offset}}}, factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_GLOBAL); } else { IdxGlobal output_offset = i * static_cast(n_cplx_per_sg) + static_cast(id_of_fft_in_sg * fft_size) + static_cast(id_of_wi_in_fft); - subgroup_impl_local_private_copy<1, 1, IdxGlobal>( - output, output_imag, priv, {{{static_cast(factor_sg)}, {output_offset}}}, {{{2}, {0}}}, - {{{static_cast(factor_sg)}, {output_offset}}}, {{{2}, {1}}}, factor_wi, global_data, + local_private_strided_copy<1, IdxGlobal>( + output, output_imag, priv, {{{static_cast(factor_sg)}, {output_offset}}}, + {{{static_cast(factor_sg)}, {output_offset}}}, factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_GLOBAL); } } @@ -463,19 +456,17 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (working) { global_data.log_message_global(__func__, "Storing data from private to Global with batch interleaved layout"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::strided_view output_view{output, 
std::array{static_cast(factor_sg), n_transforms}, - std::array{static_cast(2 * id_of_wi_in_fft), 2 * i}}; - copy_wi<2>(global_data, priv, output_view, factor_wi); + local_private_strided_copy<2, IdxGlobal>(output, priv, + {{{static_cast(factor_sg), n_transforms}, + {static_cast(2 * id_of_wi_in_fft), 2 * i}}}, + factor_wi, global_data, + detail::transfer_direction::PRIVATE_TO_GLOBAL); } else { - detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - detail::strided_view output_real_view{output, std::array{static_cast(factor_sg), n_transforms}, - std::array{static_cast(id_of_wi_in_fft), i}}; - detail::strided_view output_imag_view{output_imag, - std::array{static_cast(factor_sg), n_transforms}, - std::array{static_cast(id_of_wi_in_fft), i}}; - copy_wi(global_data, priv_real_view, output_real_view, factor_wi); - copy_wi(global_data, priv_imag_view, output_imag_view, factor_wi); + so_array global_stride_offset{ + {{static_cast(factor_sg), n_transforms}, {static_cast(id_of_wi_in_fft), i}}}; + local_private_strided_copy<2, IdxGlobal>(output, output_imag, priv, global_stride_offset, + global_stride_offset, factor_wi, global_data, + detail::transfer_direction::PRIVATE_TO_GLOBAL); } } } else { @@ -485,23 +476,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (storage == complex_storage::INTERLEAVED_COMPLEX) { Idx loc_view_offset = subgroup_id * n_reals_per_sg + id_of_fft_in_sg * n_reals_per_fft + 2 * id_of_wi_in_fft; - subgroup_impl_local_private_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_view_offset}}}, factor_wi, - global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + local_private_strided_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_view_offset}}}, factor_wi, + global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); } else { - detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - detail::strided_view local_real_view{ - 
loc_view, factor_sg, subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft}; - detail::strided_view local_imag_view{ - loc_view, factor_sg, - subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft + local_imag_offset}; - copy_wi(global_data, priv_real_view, local_real_view, factor_wi); - copy_wi(global_data, priv_imag_view, local_imag_view, factor_wi); - // Idx loc_view_offset = subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; - // subgroup_impl_local_private_copy<1, 1, Idx>( - // loc_view, loc_view, priv, {{{factor_sg}, {local_offset}}}, {{{2}, {0}}}, - // {{{factor_sg}, {loc_view_offset + local_imag_offset}}}, {{{2}, {1}}}, factor_wi, global_data, - // detail::transfer_direction::PRIVATE_TO_LOCAL); + Idx loc_view_offset = subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; + local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{{factor_sg}, {loc_view_offset}}}, + {{{factor_sg}, {loc_view_offset + local_imag_offset}}}, factor_wi, + global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); } } sycl::group_barrier(global_data.sg); @@ -512,27 +493,25 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (is_output_packed) { const IdxGlobal global_output_offset = n_io_reals_per_fft * (i - static_cast(id_of_fft_in_sg)); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - local2global(global_data, loc_view, output, - n_ffts_worked_on_by_sg * n_reals_per_fft, local_offset, - global_output_offset); + local_global_packed_copy( + output, loc_view, global_output_offset, local_offset, n_ffts_worked_on_by_sg * n_reals_per_fft, + global_data); } else { - local2global( - global_data, loc_view, output, n_ffts_worked_on_by_sg * fft_size, local_offset, global_output_offset); - local2global(global_data, loc_view, output_imag, - n_ffts_worked_on_by_sg * fft_size, - local_offset + local_imag_offset, global_output_offset); + 
local_global_packed_copy( + output, output_imag, loc_view, global_output_offset, local_offset, local_imag_offset, + n_ffts_worked_on_by_sg * fft_size, global_data); } } else { if (storage == complex_storage::INTERLEAVED_COMPLEX) { const IdxGlobal global_output_offset = 2 * output_distance * (i - static_cast(id_of_fft_in_sg)); global_data.log_message_global(__func__, "storing data from local to unpacked global memory"); - subgroup_impl_local2global_strided_copy( + local_global_strided_copy( output, loc_view, {output_distance * 2, output_stride * 2, 1}, {committed_length * 2, 2, 1}, global_output_offset, local_offset, {n_ffts_worked_on_by_sg, fft_size, 2}, global_data); } else { const IdxGlobal global_output_offset = output_distance * (i - static_cast(id_of_fft_in_sg)); - subgroup_impl_local2global_strided_copy( + local_global_strided_copy( output, output_imag, loc_view, {output_distance, output_stride}, {committed_length, 1}, global_output_offset, local_offset, local_imag_offset, {n_ffts_worked_on_by_sg, committed_length}, global_data); @@ -540,18 +519,22 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } } else { if (is_output_packed) { - auto global_ptr_offset = storage == complex_storage::INTERLEAVED_COMPLEX - ? 2 * committed_length * (i - static_cast(id_of_fft_in_sg)) - : committed_length * (i - static_cast(id_of_fft_in_sg)); - auto loc_view_offset = storage == complex_storage::INTERLEAVED_COMPLEX - ? 
2 * factor_sg * factor_wi * subgroup_id - : factor_sg * factor_wi * subgroup_id; - auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; - - subgroup_impl_bluestein_local2global_packed_copy( - output, output_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, - loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, i, n_transforms, global_data.sg, storage, - global_data); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + auto global_ptr_offset = 2 * committed_length * (i - static_cast(id_of_fft_in_sg)); + auto loc_view_offset = 2 * factor_sg * factor_wi * subgroup_id; + subgroup_impl_bluestein_local_global_packed_copy( + output, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, loc_view_offset, + n_ffts_worked_on_by_sg, i, n_transforms, global_data); + } else { + auto global_ptr_offset = committed_length * (i - static_cast(id_of_fft_in_sg)); + auto loc_view_offset = factor_sg * factor_wi * subgroup_id; + auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; + subgroup_impl_bluestein_local_global_packed_copy( + output, output_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, + loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, i, n_transforms, global_data); + } } else { // TODO: Blustein Strided Copy } From 8c3b40b653138238df0e0b4bd6209776a9fa73ce Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Wed, 13 Mar 2024 23:08:14 +0000 Subject: [PATCH 11/22] bugfix after refactor --- src/portfft/dispatcher/subgroup_dispatcher.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 4ff6edb2..2dd8a1b2 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -372,8 +372,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag input, loc_view, 
committed_length, factor_sg * factor_wi, global_ptr_offset, local_view_offset, n_ffts_worked_on_by_sg, i, n_transforms, global_data); } else { - auto global_ptr_offset = 2 * committed_length * (i - static_cast(id_of_fft_in_sg)); - auto local_view_offset = 2 * factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; + auto global_ptr_offset = committed_length * (i - static_cast(id_of_fft_in_sg)); + auto local_view_offset = factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; subgroup_impl_bluestein_local_global_packed_copy( input, input_imag, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, @@ -521,14 +521,14 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (is_output_packed) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { auto global_ptr_offset = 2 * committed_length * (i - static_cast(id_of_fft_in_sg)); - auto loc_view_offset = 2 * factor_sg * factor_wi * subgroup_id; + auto loc_view_offset = 2 * factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; subgroup_impl_bluestein_local_global_packed_copy( output, loc_view, committed_length, factor_sg * factor_wi, global_ptr_offset, loc_view_offset, n_ffts_worked_on_by_sg, i, n_transforms, global_data); } else { auto global_ptr_offset = committed_length * (i - static_cast(id_of_fft_in_sg)); - auto loc_view_offset = factor_sg * factor_wi * subgroup_id; + auto loc_view_offset = factor_sg * factor_wi * subgroup_id * n_ffts_per_sg; auto loc_view_imag_offset = factor_sg * factor_wi * n_sgs_in_wg; subgroup_impl_bluestein_local_global_packed_copy( From 20e3c78dc56b8db86d40c3982c9b041c6df9ed03 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 14 Mar 2024 07:50:46 +0000 Subject: [PATCH 12/22] doxygens and logging, and lower tolerance value --- src/portfft/common/host_dft.hpp | 7 +- src/portfft/common/subgroup.hpp | 224 ++++++++++++++---- src/portfft/common/transfers.hpp | 148 +++++++++++- 
.../dispatcher/subgroup_dispatcher.hpp | 77 +++--- src/portfft/utils.hpp | 3 +- test/unit_test/fft_test_utils.hpp | 10 +- 6 files changed, 363 insertions(+), 106 deletions(-) diff --git a/src/portfft/common/host_dft.hpp b/src/portfft/common/host_dft.hpp index a61e361c..9be0f011 100644 --- a/src/portfft/common/host_dft.hpp +++ b/src/portfft/common/host_dft.hpp @@ -41,9 +41,9 @@ void host_naive_dft(std::complex* input, std::complex* output, std::size_t for (std::size_t i = 0; i < fft_size; i++) { complex_t temp = complex_t(0, 0); for (std::size_t j = 0; j < fft_size; j++) { - complex_t multiplier = - complex_t(static_cast(std::cos((-2 * M_PI * static_cast(i * j)) / static_cast(fft_size))), - static_cast(std::sin((-2 * M_PI * static_cast(i * j)) / static_cast(fft_size)))); + // Not using sycl::cospi / sycl::sinpi as std::cos/std::sin provides better accuracy in float and double tests + double theta = -2 * M_PI * static_cast(i * j) / static_cast(fft_size); + complex_t multiplier = complex_t(static_cast(std::cos(theta)), static_cast(std::sin(theta))); temp += input[j] * multiplier; } output[i] = temp; @@ -83,6 +83,7 @@ void host_cooley_tukey(std::complex* input, std::complex* output, std::siz for (std::size_t i = 0; i < n; i++) { for (std::size_t j = 0; j < m; j++) { + // Not using sycl::cospi / sycl::sinpi as std::cos/std::sin provides better accuracy in float and double tests double theta = -2 * M_PI * static_cast(i * j) / static_cast(n * m); output[i * m + j] *= std::complex(static_cast(std::cos(theta)), static_cast(std::sin(theta))); } diff --git a/src/portfft/common/subgroup.hpp b/src/portfft/common/subgroup.hpp index 6e22b30d..27181113 100644 --- a/src/portfft/common/subgroup.hpp +++ b/src/portfft/common/subgroup.hpp @@ -311,25 +311,72 @@ void sg_calc_twiddles(Idx factor_sg, Idx factor_wi, Idx n, Idx k, T* sg_twiddles sg_twiddles[(k + factor_wi) * factor_sg + n] = twiddle.imag(); } +/** + * Function to copy data between local and global memory as required by the 
subgroup level Bluestein algorithm, + * when the data in both local and global memory is in packed format, when the storage scheme is INTERLEAVED_COMPLEX + * + * @tparam SubgroupSize Subgroup size + * @tparam Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or + * transfer_direction::GLOBAL_TO_LOCAL + * @tparam TIn Global memory Type + * @tparam LocView Type of the view constructed for local memory + * @param global_ptr global memory pointer + * @param loc_view View of the local memory + * @param committed_size Size of the DFT as committed, also the number of complex elements in each transform present in + * global memory + * @param fft_size The padded DFT size, also the number of complex elements in each transform that resides + * in local memory + * @param global_ptr_offset Offset to be applied to the global memory pointer + * @param loc_offset Offset to be applied to the local memory view + * @param n_ffts_in_sg Number of ffts that can be calculated by a single subgroup + * @param transform_id Id of the transform in the kernel + * @param n_transforms Total number of transforms in the kernel + * @param global_data global_data_struct associated with the kernel launch + */ template PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( TIn global_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, IdxGlobal global_ptr_offset, Idx loc_offset, - Idx n_ffts_in_sg, IdxGlobal batch_start, IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { + Idx n_ffts_in_sg, IdxGlobal transform_id, IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg && ((i + batch_start) < n_transforms); i++) { + for (Idx i = 0; i < n_ffts_in_sg && ((i + transform_id) < n_transforms); i++) { local_global_packed_copy( global_ptr, loc_view, global_ptr_offset + static_cast(2 * i * committed_size), 2 * i * fft_size + loc_offset, 2 *
committed_size, global_data); } } +/** + * Function to copy data between local and global memory as required by the subgroup level Bluestein algorithm, + * when the data in both local and global memory is in packed format,when the storage scheme is SPLIT_COMPLEX + * + * @tparam SubgroupSize Subgroup size + * @tparam Direction Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or + * transfer_direction::GLOBAL_TO_LOCAL + * @tparam TIn Global memory Type + * @tparam LocView Type of the view constructed for local memory + * @param global_ptr global memory pointer containing the real part of the data + * @param global_imag_ptr global memory pointer containing the imaginary part of the data + * @param loc_view View of the local memory + * @param committed_size Size of the DFT as committed, also the number of complex elements in each transform present in + * global memory + * @param fft_size The padded DFT size, also the number of elements of complex elements in each transform that resides + * in local memory + * @param global_ptr_offset Offset to be applied to the global memory pointer + * @param loc_offset Offset to be applied to the local memory view + * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is + * stored + * @param n_ffts_in_sg Number of ffts that can be calculated by a single subgroup + * @param transform_id Id of the transform in the kernel + * @param n_transforms Total number of transforms in the kernel + * @param global_data global_data_struct associated with the kernel launch + */ template PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( TIn global_ptr, TIn global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, - IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, IdxGlobal batch_start, + IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, IdxGlobal 
transform_id, IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg && (i + batch_start < n_transforms); i++) { + for (Idx i = 0; i < n_ffts_in_sg && (i + transform_id < n_transforms); i++) { local_global_packed_copy( global_ptr, global_imag_ptr, loc_view, static_cast(i * committed_size) + global_ptr_offset, i * fft_size + loc_offset, local_imag_offset, committed_size, global_data); @@ -337,7 +384,7 @@ PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( } /** - * Performs all the computations to be done in the private memory for the subgroup implementation + * Performs all the computations to be done in the private memory for the subgroup level FFT Implementation * * @tparam SubgroupSize Subgroup Size * @tparam T Scalar Type @@ -359,8 +406,7 @@ PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( * @param id_of_wi_in_fft workitem id withing the fft * @param factor_sg Number of workitems participating for one transform * @param factor_wi Number of complex elements per workitem for each transform - * @param sg sub group - * @return PORTFFT_INLINE + * @param global_data global_data_struct associated with the kernel launch */ template PORTFFT_INLINE void sg_dft_compute(T* priv, T* private_scratch, detail::elementwise_multiply apply_load_modifier, @@ -370,13 +416,15 @@ PORTFFT_INLINE void sg_dft_compute(T* priv, T* private_scratch, detail::elementw detail::apply_scale_factor scale_factor_applied, const T* load_modifier_data, const T* store_modifier_data, LocView& twiddles_loc_view, T scale_factor, IdxGlobal modifier_start_offset, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, - sycl::sub_group& sg) { + detail::global_data_struct<1>& global_data) { using vec2_t = sycl::vec; vec2_t modifier_vec; if (conjugate_on_load == detail::complex_conjugate::APPLIED) { + global_data.log_message(__func__, "Applying complex conjugate before computation of the FFT"); 
detail::conjugate_inplace(priv, factor_wi); } if (apply_load_modifier == detail::elementwise_multiply::APPLIED) { + global_data.log_message(__func__, "Applying load modifiers"); PORTFFT_UNROLL for (Idx j = 0; j < factor_wi; j++) { modifier_vec = *reinterpret_cast( @@ -385,13 +433,15 @@ PORTFFT_INLINE void sg_dft_compute(T* priv, T* private_scratch, detail::elementw priv[2 * j + 1]); } } - sg_dft(priv, sg, factor_wi, factor_sg, twiddles_loc_view, private_scratch); + sg_dft(priv, global_data.sg, factor_wi, factor_sg, twiddles_loc_view, private_scratch); if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + global_data.log_message(__func__, "Applying complex conjugate after computation of the FFT"); detail::conjugate_inplace(priv, factor_wi); } if (apply_store_modifier == detail::elementwise_multiply::APPLIED) { + global_data.log_message(__func__, "Applying store modifiers"); PORTFFT_UNROLL for (Idx j = 0; j < factor_wi; j++) { modifier_vec = *reinterpret_cast( @@ -402,6 +452,7 @@ PORTFFT_INLINE void sg_dft_compute(T* priv, T* private_scratch, detail::elementw } if (scale_factor_applied == detail::apply_scale_factor::APPLIED) { + global_data.log_message(__func__, "Applying scale factor"); PORTFFT_UNROLL for (Idx j = 0; j < factor_wi; j++) { priv[2 * j] *= scale_factor; @@ -410,108 +461,179 @@ PORTFFT_INLINE void sg_dft_compute(T* priv, T* private_scratch, detail::elementw } } +/** + * Implements the Subgroup level Bluestein algorithm with an addition trip to local memory, when the layout of the data + * in local memory is in BATCH_INTERLEAVED format + * + * @tparam SubgroupSize Subgroup Size + * @tparam T Scalar Type + * @tparam LocTwiddlesView Type of view of the local memory containing the twiddles + * @tparam LocView Type of view of the local memory which stores the data + * @param priv private memory array on which the computations will be done + * @param private_scratch Scratch private memory to be passed to the wi_dft as a part of sg_dft + * @param 
loc_view view of the local memory to store the data + * @param load_modifier Global memory pointer containing the load modifier data, assumed aligned to at least + * sycl::vec + * @param store_modifier Global memory pointer containing the store modifier data, assumed aligned to at least + * sycl::vec + * @param twiddles_loc view of the local memory containing the twiddles + * @param conjugate_on_load Whether or not conjugation of the input is to be done before the fft computation + * @param conjugate_on_store Whether or not conjugation of the input is to be done after the fft computation + * @param scale_applied Whether or not scale factor is applied + * @param scale_factor Value of the scaling factor + * @param id_of_wi_in_fft Id of the workitem in the FFT + * @param factor_sg Number of workitems participating for one transform + * @param factor_wi Number of complex elements per workitem for each transform + * @param storage storage scheme of complex values in local memory, SPLIT_COMPLEX or INTERLEAVED_COMPLEX + * @param wi_working Whether or not the workitem participates in the data transfers + * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is + * stored + * @param max_num_batches_local_mem Maximum number of transforms that can be stored in local memory + * @param fft_idx_in_local Id of the transform in local memory + * @param global_data global_data_struct associated with kernel launch + */ template -PORTFFT_INLINE void sg_bluestein_batch_interleaved(T* priv, T* priv_scratch, LocView& loc_view, const T* load_modifier, - const T* store_modifier, LocTwiddlesView& twiddles_loc, - detail::complex_conjugate conjugate_on_load, - detail::complex_conjugate conjugate_on_store, - detail::apply_scale_factor scale_applied, T scale_factor, - Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, - complex_storage storage, bool wi_working, Idx local_imag_offset, - Idx max_num_batches_local_mem, Idx fft_idx_in_local, 
- sycl::sub_group& sg, detail::global_data_struct<1>& global_data) { +PORTFFT_INLINE void sg_bluestein_batch_interleaved( + T* priv, T* priv_scratch, LocView& loc_view, const T* load_modifier, const T* store_modifier, + LocTwiddlesView& twiddles_loc, detail::complex_conjugate conjugate_on_load, + detail::complex_conjugate conjugate_on_store, detail::apply_scale_factor scale_applied, T scale_factor, + Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, complex_storage storage, bool wi_working, Idx local_imag_offset, + Idx max_num_batches_local_mem, Idx fft_idx_in_local, detail::global_data_struct<1>& global_data) { + global_data.log_message_global(__func__, "computing forward FFT and applying scaling factor for the backward phase"); sg_dft_compute( priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, store_modifier, twiddles_loc, static_cast(1. / (static_cast(factor_sg * factor_wi))), 0, id_of_wi_in_fft, - factor_sg, factor_wi, sg); + factor_sg, factor_wi, global_data); + // TODO: Currently local memory is being used to load the data back in natural order for the backward phase, as result + // of sg_dft is transposed However, the Ideal way to this is using shuffles. 
Implement a batched matrix transpose to + // transpose a matrix stored in the private memory of workitems of a subgroup using shuffles only This we way can even + // the 2 sg_bluestein functions that we have today if (wi_working) { + global_data.log_message(__func__, "storing result of the forward phase back to local memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { local_private_strided_copy<2, Idx>( loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}}, - factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + factor_wi, detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } else { local_private_strided_copy<2, Idx>( loc_view, loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}}, {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}}, - factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + factor_wi, detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } } - sycl::group_barrier(sg); + sycl::group_barrier(global_data.sg); if (wi_working) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { + global_data.log_message(__func__, "loading back the result from local memory for the backward phase"); const Idx fft_element = 2 * id_of_wi_in_fft * factor_wi; local_private_strided_copy<1, Idx>( loc_view, priv, {{{max_num_batches_local_mem}, {fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local}}}, factor_wi, - global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } else { local_private_strided_copy<2, Idx>( loc_view, loc_view, priv, {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}}, {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}}, - factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + 
factor_wi, detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } } - + global_data.log_message(__func__, "computing backward FFT and applying user provided scale value"); sg_dft_compute(priv, priv_scratch, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), - load_modifier, twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); + load_modifier, twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, + global_data); if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + global_data.log_message(__func__, "Applying complex conjugate on the output"); detail::conjugate_inplace(priv, factor_wi); } } +/** + * + * Implements the Subgroup level Bluestein algorithm with an addition trip to local memory, when the layout of the data + * in local memory is in BATCH_INTERLEAVED format + * + * @tparam SubgroupSize Subgroup Size + * @tparam T Scalar Type + * @tparam LocTwiddlesView Type of view of the local memory containing the twiddles + * @tparam LocView Type of view of the local memory which stores the data + * @param priv private memory array on which the computations will be done + * @param private_scratch Scratch private memory to be passed to the wi_dft as a part of sg_dft + * @param loc_view view of the local memory to store the data + * @param load_modifier Global memory pointer containing the load modifier data, assumed aligned to at least + * sycl::vec + * @param store_modifier Global memory pointer containing the store modifier data, assumed aligned to at least + * sycl::vec + * @param loc_twiddles view of the local memory containing the twiddles + * @param conjugate_on_load Whether or not conjugation of the input is to be done before the fft computation + * @param conjugate_on_store Whether or not conjugation of the input is to be done after the fft computation + * @param 
scale_applied Whether or not scale factor is applied + * @param scale_factor Value of the scaling factor + * @param id_of_wi_in_fft Id of the workitem in the FFT + * @param factor_sg Number of workitems participating for one transform + * @param factor_wi Number of complex elements per workitem for each transform + * @param storage storage scheme of complex values in local memory, SPLIT_COMPLEX or INTERLEAVED_COMPLEX + * @param wi_working Whether or not the workitem participates in the data transfers + * @param loc_view_store_offset Offset to be applied to local memory view when storing the data back to local memory + * after forward fft phase + * @param loc_view_load_offset Offset to be applied to local memory view when loading the data back from local memory + * for the backward fft phase + * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is + * stored + * @param global_data global_data_struct associated with kernel launch + */ template -void sg_bluestein(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddlesView& loc_twiddles, const T* load_modifier, - const T* store_modifier, detail::complex_conjugate conjugate_on_load, - detail::complex_conjugate conjugate_on_store, detail::apply_scale_factor scale_applied, - T scale_factor, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, complex_storage storage, - bool wi_working, Idx loc_offset_store_view, Idx loc_offset_load_view, Idx local_imag_offset, - sycl::sub_group& sg, detail::global_data_struct<1>& global_data) { +void sg_bluestein_packed(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddlesView& loc_twiddles, + const T* load_modifier, const T* store_modifier, detail::complex_conjugate conjugate_on_load, + detail::complex_conjugate conjugate_on_store, detail::apply_scale_factor scale_applied, + T scale_factor, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, complex_storage storage, + bool wi_working, Idx loc_view_store_offset, Idx
loc_view_load_offset, Idx local_imag_offset, + detail::global_data_struct<1>& global_data) { + global_data.log_message_global(__func__, "computing forward FFT and applying scaling factor for the backward phase"); sg_dft_compute( priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, store_modifier, loc_twiddles, static_cast(1. / static_cast(factor_sg * factor_wi)), 0, id_of_wi_in_fft, - factor_sg, factor_wi, sg); - - sycl::group_barrier(sg); + factor_sg, factor_wi, global_data); if (wi_working) { + global_data.log_message(__func__, "storing result of the forward phase back to local memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - local_private_strided_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_offset_store_view}}}, factor_wi, - global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + local_private_strided_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_view_store_offset}}}, factor_wi, + detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } else { - detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - detail::strided_view local_real_view{loc_view, factor_sg, loc_offset_store_view}; - detail::strided_view local_imag_view{loc_view, factor_sg, loc_offset_store_view + local_imag_offset}; - copy_wi(global_data, priv_real_view, local_real_view, factor_wi); - copy_wi(global_data, priv_imag_view, local_imag_view, factor_wi); + local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{{factor_sg}, {loc_view_store_offset}}}, + {{{factor_sg}, {loc_view_store_offset + local_imag_offset}}}, factor_wi, + detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } } - sycl::group_barrier(sg); + sycl::group_barrier(global_data.sg); if (wi_working) { + global_data.log_message(__func__, "loading back the result from local memory for the 
backward phase"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - local_private_strided_copy<1, Idx>(loc_view, priv, {{{1}, {loc_offset_load_view}}}, factor_wi, global_data, - detail::transfer_direction::LOCAL_TO_PRIVATE); + local_private_strided_copy<1, Idx>(loc_view, priv, {{{1}, {loc_view_load_offset}}}, factor_wi, + detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } else { - local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{{1}, {loc_offset_load_view}}}, - {{{1}, {loc_offset_load_view + local_imag_offset}}}, factor_wi, global_data, - detail::transfer_direction::LOCAL_TO_PRIVATE); + local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{{1}, {loc_view_load_offset}}}, + {{{1}, {loc_view_load_offset + local_imag_offset}}}, factor_wi, + detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } } - + global_data.log_message(__func__, "computing backward FFT and applying user provided scale value"); sg_dft_compute(priv, priv_scratch, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), - load_modifier, loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, sg); + load_modifier, loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, + global_data); if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + global_data.log_message(__func__, "Applying complex conjugate on the output"); detail::conjugate_inplace(priv, factor_wi); } } diff --git a/src/portfft/common/transfers.hpp b/src/portfft/common/transfers.hpp index ab7dbae1..0f34e66b 100644 --- a/src/portfft/common/transfers.hpp +++ b/src/portfft/common/transfers.hpp @@ -488,29 +488,80 @@ PORTFFT_INLINE void local2global(detail::global_data_struct<1> global_data, Loca global_data, global, local, total_num_elems, global_offset, local_offset); } +/** + * Driver function to copy data between local and 
global memory when data is in PACKED format in both, local + * as well as global memory, when the storage scheme is INTERLEAVED_COMPLEX + * + * @tparam Group Group level taking part in the copy, should be one of level::SUBGROUP or level::WORKGROUP + * @tparam Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or + * transfer_direction::GLOBAL_TO_LOCAL + * @tparam SubgroupSize Subgroup Size + * @tparam LocView Type of view of the local memory + * @tparam T Scalar Type + * @param global_ptr Pointer to the input / output global memory + * @param loc_view Local memory view containing the input + * @param global_offset Offset to be applied to the input / output pointer + * @param local_offset Offset to be applied to local memory view + * @param n_elements_to_copy Number of scalar elements to copy + * @param global_data global_data_struct associated with the kernel launch + */ template PORTFFT_INLINE void local_global_packed_copy(T* global_ptr, LocView& loc_view, IdxGlobal global_offset, Idx local_offset, Idx n_elements_to_copy, detail::global_data_struct<1>& global_data) { + global_data.log_message(__func__, "storage scheme: INTERLEAVED_COMPLEX"); if constexpr (Direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + global_data.log_message(__func__, + "Transferring from global to local memory, number of elements: ", n_elements_to_copy, + " global offset: ", global_offset, " local_offset: ", local_offset); global2local(global_data, global_ptr, loc_view, n_elements_to_copy, global_offset, local_offset); } else { + global_data.log_message(__func__, + "Transferring from global to local memory, number of elements: ", n_elements_to_copy, + " global offset: ", global_offset, " local_offset: ", local_offset); local2global(global_data, loc_view, global_ptr, n_elements_to_copy, local_offset, global_offset); } } +/** + * Driver function to copy data between local and global memory when data is in PACKED format in both, local + * as 
well as global memory, when the storage scheme is SPLIT_COMPLEX + * + * @tparam Group Group level taking part in the copy, should be one of level::SUBGROUP or level::WORKGROUP + * @tparam Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or + * transfer_direction::GLOBAL_TO_LOCAL + * @tparam SubgroupSize Subgroup Size + * @tparam LocView Type of view of the local memory + * @tparam T Scalar Type + * @param global_ptr Pointer to the input / output global memory containing the real part of the data + * @param global_imag_ptr Pointer to the input / output global memory containing the imaginary part of the data + * @param loc_view Local memory view containing the input + * @param global_offset Offset to be applied to the input / output pointer + * @param local_offset Offset to be applied to local memory view + * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is + * stored. + * @param n_elements_to_copy Number of scalar elements to copy + * @param global_data global_data_struct associated with the kernel launch + */ template PORTFFT_INLINE void local_global_packed_copy(T* global_ptr, T* global_imag_ptr, LocView& loc_view, IdxGlobal global_offset, Idx local_offset, Idx local_imag_offset, Idx n_elements_to_copy, detail::global_data_struct<1>& global_data) { + global_data.log_message(__func__, "storage scheme: SPLIT_COMPLEX"); if constexpr (Direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + global_data.log_message(__func__, + "Transferring from global to local memory, number of elements: ", n_elements_to_copy, + " global offset: ", global_offset, " local_offset: ", local_offset); global2local(global_data, global_ptr, loc_view, n_elements_to_copy, global_offset, local_offset); global2local(global_data, global_imag_ptr, loc_view, n_elements_to_copy, global_offset, local_offset + local_imag_offset); } else { + global_data.log_message(__func__, + "Transferring from
global to local memory, number of elements: ", n_elements_to_copy, + " global offset: ", global_offset, " local_offset: ", local_offset); local2global(global_data, loc_view, global_ptr, n_elements_to_copy, local_offset, global_offset); local2global(global_data, loc_view, global_imag_ptr, n_elements_to_copy, @@ -518,6 +569,27 @@ PORTFFT_INLINE void local_global_packed_copy(T* global_ptr, T* global_imag_ptr, } } +/** + * Driver function for copying data between local and global memory when the data layout is arbitrarily + * strided in either or both, local and global memory, when the storage scheme is INTERLEAVED_COMPLEX + * + * @tparam Group Group level taking part in the copy, should be one of level::SUBGROUP or level::WORKGROUP + * @tparam Direction Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or + * transfer_direction::GLOBAL_TO_LOCAL + * @tparam GlobalDim Number of dimension of the md_view to be created for the global memory + * @tparam LocalDim Number of dimension of the md_view to be created for the global memory + * @tparam CopyDims Number of dimensions over which the data will be copied + * @tparam T Scalar Type + * @tparam LocView Type of view created for the local memory + * @param global_ptr Pointer to the input / output global memory + * @param loc_view Local memory view containing the input + * @param strides_global An array specifying the strides for the global memory + * @param strides_local An array specifying the strides for the local memory + * @param offset_global Offset value to be applied to the global memory + * @param offset_local Offset value to be applied to the local memory + * @param copy_lengths number of scalars (for each dimension) of the data to copy + * @param global_data global_data_struct associated with the kernel launch + */ template PORTFFT_INLINE void local_global_strided_copy(T* global_ptr, LocView& loc_view, @@ -525,15 +597,42 @@ PORTFFT_INLINE void local_global_strided_copy(T* 
global_ptr, LocView& loc_view, std::array strides_local, IdxGlobal offset_global, Idx offset_local, std::array copy_lengths, detail::global_data_struct<1> global_data) { + global_data.log_message(__func__, "storage scheme: INTERLEAVED_COMPLEX"); detail::md_view global_md_view{global_ptr, strides_global, offset_global}; detail::md_view local_md_view{loc_view, strides_local, offset_local}; if constexpr (Direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + global_data.log_message(__func__, "transferring strided data from global to local memory"); copy_group(global_data, global_md_view, local_md_view, copy_lengths); } else { + global_data.log_message(__func__, "transferring strided data from local to global memory"); copy_group(global_data, local_md_view, global_md_view, copy_lengths); } } +/** + * Driver function for copying data between local and global memory when the data layout is arbitrarily + * strided in either or both, local and global memory, when the storage scheme is SPLIT_COMPLEX + * + * @tparam Group Group level taking part in the copy, should be one of level::SUBGROUP or level::WORKGROUP + * @tparam Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or + * transfer_direction::GLOBAL_TO_LOCAL + * @tparam GlobalDim Number of dimensions of the md_view to be created for the global memory + * @tparam LocalDim Number of dimensions of the md_view to be created for the local memory + * @tparam CopyDims Number of dimensions over which the data will be copied + * @tparam T Scalar Type + * @tparam LocView Type of view created for the local memory + * @param global_ptr Pointer to the input / output global memory containing the real part of the data + * @param global_imag_ptr Pointer to the input / output global memory containing the imaginary part of the data + * @param loc_view View of the local memory + * @param strides_global An array specifying the strides for the global memory + * @param strides_local An
array specifying the strides for the local memory + * @param offset_global Offset value to be applied to the global memory + * @param local_offset Offset value to be applied to the local memory + * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is + * stored + * @param copy_lengths number of scalars (for each dimension) of the data to copy + * @param global_data global_data_struct associated with the kernel launch + */ template PORTFFT_INLINE void local_global_strided_copy(T* global_ptr, T* global_imag_ptr, LocView& loc_view, @@ -542,24 +641,44 @@ PORTFFT_INLINE void local_global_strided_copy(T* global_ptr, T* global_imag_ptr, Idx local_offset, Idx local_imag_offset, std::array copy_lengths, detail::global_data_struct<1> global_data) { + global_data.log_message(__func__, "storage scheme: SPLIT_COMPLEX"); detail::md_view global_md_real_view{global_ptr, strides_global, offset_global}; detail::md_view global_md_imag_view{global_imag_ptr, strides_global, offset_global}; detail::md_view local_md_real_view{loc_view, strides_local, local_offset}; detail::md_view local_md_imag_view{loc_view, strides_local, local_offset + local_imag_offset}; if constexpr (Direction == detail::transfer_direction::GLOBAL_TO_LOCAL) { + global_data.log_message(__func__, "transferring strided data from global to local memory"); copy_group(global_data, global_md_real_view, local_md_real_view, copy_lengths); copy_group(global_data, global_md_imag_view, local_md_imag_view, copy_lengths); } else { + global_data.log_message(__func__, "transferring strided data from local to global memory"); copy_group(global_data, local_md_real_view, global_md_real_view, copy_lengths); copy_group(global_data, local_md_imag_view, global_md_imag_view, copy_lengths); } } +/** + * Driver function for copying data between local and private memory when the storage scheme is INTERLEAVED_COMPLEX + * This can also be used when directly copying data from private 
memory to global memory + * + * @tparam PtrViewNDim Number of Dimension of the local / global memory view + * @tparam IdxType Integer type of the strides and offset of the local / global memory + * @tparam PtrView View type of the local / global memory + * @tparam T Scalar Type + * @param ptr_view View of the local / global memory taking part in the copy + * @param priv Pointer to the private memory array + * @param ptr_view_strides_offsets An array of 2 arrays containing PtrViewNDim elements of IdxType, containing strides + * and offsets for the strided view to be constructed for the local / global memory + * @param num_elements_to_copy Number of scalar elements to copy + * @param direction direction of copy, should be one of LOCAL_TO_PRIVATE, PRIVATE_TO_LOCAL or PRIVATE_TO_GLOBAL + * @param global_data global data struct associated with the kernel launch + */ template PORTFFT_INLINE void local_private_strided_copy(PtrView& ptr_view, T* priv, so_array ptr_view_strides_offsets, - Idx num_elements_to_copy, detail::global_data_struct<1> global_data, - detail::transfer_direction direction) { + Idx num_elements_to_copy, detail::transfer_direction direction, + detail::global_data_struct<1> global_data) { + global_data.log_message(__func__, "storage scheme: INTERLEAVED_COMPLEX"); detail::strided_view ptr_strided_view{ptr_view, std::get<0>(ptr_view_strides_offsets), std::get<1>(ptr_view_strides_offsets)}; if (direction == detail::transfer_direction::LOCAL_TO_PRIVATE) { @@ -570,12 +689,33 @@ PORTFFT_INLINE void local_private_strided_copy(PtrView& ptr_view, T* priv, } } +/** + * Driver function for copying data between local and private memory when the storage scheme is SPLIT_COMPLEX + * This can also be used when directly copying data from private memory to global memory + * + * @tparam PtrViewNDim Number of Dimension of the local / global memory view + * @tparam IdxType Integer type of the strides and offset of the local / global memory + * @tparam PtrView View type of 
the local / global memory + * @tparam T Scalar Type + * @param ptr_view View of the local / global memory containing the real component of the data + * @param ptr_imag_view View of the local / global memory containing the imaginary component of the data + * @param priv Pointer to the private memory array + * @param ptr_view_strides_offsets An array of 2 arrays containing PtrViewNDim elements of IdxType, containing strides + * and offsets for the strided view to be constructed for the local / global memory containing the real part of the data + * @param ptr_imag_view_strides_offsets An array of 2 arrays containing PtrViewNDim elements of IdxType, containing + * strides and offsets for the strided view to be constructed for the local / global memory containing the imaginary + * part of the data + * @param num_elements_to_copy Number of elements to copy + * @param direction direction of copy, should be one of LOCAL_TO_PRIVATE, PRIVATE_TO_LOCAL or PRIVATE_TO_GLOBAL + * @param global_data global data struct associated with the kernel launch + */ template PORTFFT_INLINE void local_private_strided_copy(PtrView& ptr_view, PtrView& ptr_imag_view, T* priv, so_array ptr_view_strides_offsets, so_array ptr_imag_view_strides_offsets, - Idx num_elements_to_copy, detail::global_data_struct<1> global_data, - detail::transfer_direction direction) { + Idx num_elements_to_copy, detail::transfer_direction direction, + detail::global_data_struct<1> global_data) { + global_data.log_message(__func__, "storage scheme: INTERLEAVED_COMPLEX"); detail::strided_view ptr_strided_real_view{ptr_view, std::get<0>(ptr_view_strides_offsets), std::get<1>(ptr_view_strides_offsets)}; detail::strided_view ptr_strided_imag_view{ptr_imag_view, std::get<0>(ptr_imag_view_strides_offsets), diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 2dd8a1b2..2a6d81de 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ 
b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -224,13 +224,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag local_private_strided_copy<1, Idx>( loc_view, priv, {{{max_num_batches_local_mem}, {fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local}}}, - factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + factor_wi, detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } else { local_private_strided_copy<2, Idx>( loc_view, loc_view, priv, {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}}, {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}}, - factor_wi, global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + factor_wi, detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); } @@ -240,13 +240,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag sg_dft_compute(priv, wi_private_scratch, multiply_on_load, multiply_on_store, conjugate_on_load, conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, loc_twiddles, scaling_factor, modifier_offset, id_of_wi_in_fft, factor_sg, - factor_wi, global_data.sg); + factor_wi, global_data); } else { sg_bluestein_batch_interleaved( priv, wi_private_scratch, loc_view, load_modifier_data, store_modifier_data, loc_twiddles, conjugate_on_load, conjugate_on_store, apply_scale_factor, scaling_factor, id_of_wi_in_fft, factor_sg, factor_wi, storage, working_inner, local_imag_offset, max_num_batches_local_mem, fft_idx_in_local, - global_data.sg, global_data); + global_data); } // Async DMA can start here for the next set of load/store modifiers. 
if (working_inner) { @@ -262,13 +262,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag static_cast(2 * id_of_wi_in_fft); if (storage == complex_storage::INTERLEAVED_COMPLEX) { so_array output_stride_offset{{{static_cast(factor_sg)}, {output_offset}}}; - local_private_strided_copy<1, IdxGlobal>(output, priv, output_stride_offset, factor_wi, global_data, - detail::transfer_direction::PRIVATE_TO_GLOBAL); + local_private_strided_copy<1, IdxGlobal>(output, priv, output_stride_offset, factor_wi, + detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); } else { so_array output_stride_offset{{{static_cast(factor_sg)}, {output_offset / 2}}}; local_private_strided_copy<1, IdxGlobal>(output, output_imag, priv, output_stride_offset, - output_stride_offset, factor_wi, global_data, - detail::transfer_direction::PRIVATE_TO_GLOBAL); + output_stride_offset, factor_wi, + detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); } } } else { @@ -282,13 +282,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag local_private_strided_copy<2, Idx>( loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}}, factor_wi, - global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } else { local_private_strided_copy<2, Idx>( loc_view, loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}}, {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}}, - factor_wi, global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + factor_wi, detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } } } @@ -392,12 +392,12 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (storage == complex_storage::INTERLEAVED_COMPLEX) { local_private_strided_copy<1, Idx>( loc_view, priv, {{{1}, {subgroup_id * 
n_reals_per_sg + subgroup_local_id * n_reals_per_wi}}}, factor_wi, - global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } else { local_private_strided_copy<1, Idx>( loc_view, loc_view, priv, {{{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi}}}, {{{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi + local_imag_offset}}}, factor_wi, - global_data, detail::transfer_direction::LOCAL_TO_PRIVATE); + detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); } @@ -407,23 +407,22 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, loc_twiddles, scaling_factor, static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), - id_of_wi_in_fft, factor_sg, factor_wi, global_data.sg); + id_of_wi_in_fft, factor_sg, factor_wi, global_data); } else { - // Idx loc_view_offset = subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; - // subgroup_id * n_reals_per_sg + id_of_fft_in_sg * n_reals_per_fft + 2 * id_of_wi_in_fft; - // subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; - auto loc_offset_store_view = - storage == complex_storage::INTERLEAVED_COMPLEX - ? subgroup_id * n_reals_per_sg + id_of_fft_in_sg * n_reals_per_fft + 2 * id_of_wi_in_fft - : subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; - auto loc_offset_load_view = storage == complex_storage::INTERLEAVED_COMPLEX - ? 
subgroup_id * n_reals_per_sg + subgroup_local_id * n_reals_per_wi - : subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi; - sg_bluestein(priv, wi_private_scratch, loc_view, loc_twiddles, load_modifier_data, - store_modifier_data, conjugate_on_load, conjugate_on_store, apply_scale_factor, - scaling_factor, id_of_wi_in_fft, factor_sg, factor_wi, storage, working, - loc_offset_store_view, loc_offset_load_view, local_imag_offset, global_data.sg, - global_data); + Idx loc_offset_store_view; + Idx loc_offset_load_view; + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + loc_offset_store_view = + subgroup_id * n_reals_per_sg + id_of_fft_in_sg * n_reals_per_fft + 2 * id_of_wi_in_fft; + loc_offset_load_view = subgroup_id * n_reals_per_sg + subgroup_local_id * n_reals_per_wi; + } else { + loc_offset_store_view = subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; + loc_offset_load_view = subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi; + } + sg_bluestein_packed( + priv, wi_private_scratch, loc_view, loc_twiddles, load_modifier_data, store_modifier_data, + conjugate_on_load, conjugate_on_store, apply_scale_factor, scaling_factor, id_of_wi_in_fft, factor_sg, + factor_wi, storage, working, loc_offset_store_view, loc_offset_load_view, local_imag_offset, global_data); } if (working) { global_data.log_dump_private("data in registers after scaling:", priv, n_reals_per_wi); @@ -440,16 +439,16 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag static_cast(id_of_fft_in_sg * n_reals_per_fft) + static_cast(id_of_wi_in_fft * 2); local_private_strided_copy<1, IdxGlobal>( - output, priv, {{{static_cast(factor_sg)}, {output_offset}}}, factor_wi, global_data, - detail::transfer_direction::PRIVATE_TO_GLOBAL); + output, priv, {{{static_cast(factor_sg)}, {output_offset}}}, factor_wi, + detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); } else { IdxGlobal output_offset = i * static_cast(n_cplx_per_sg) 
+ static_cast(id_of_fft_in_sg * fft_size) + static_cast(id_of_wi_in_fft); local_private_strided_copy<1, IdxGlobal>( output, output_imag, priv, {{{static_cast(factor_sg)}, {output_offset}}}, - {{{static_cast(factor_sg)}, {output_offset}}}, factor_wi, global_data, - detail::transfer_direction::PRIVATE_TO_GLOBAL); + {{{static_cast(factor_sg)}, {output_offset}}}, factor_wi, + detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); } } } else if (is_output_batch_interleaved && algorithm == detail::fft_algorithm::COOLEY_TUKEY) { @@ -459,14 +458,14 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag local_private_strided_copy<2, IdxGlobal>(output, priv, {{{static_cast(factor_sg), n_transforms}, {static_cast(2 * id_of_wi_in_fft), 2 * i}}}, - factor_wi, global_data, - detail::transfer_direction::PRIVATE_TO_GLOBAL); + factor_wi, detail::transfer_direction::PRIVATE_TO_GLOBAL, + global_data); } else { so_array global_stride_offset{ {{static_cast(factor_sg), n_transforms}, {static_cast(id_of_wi_in_fft), i}}}; local_private_strided_copy<2, IdxGlobal>(output, output_imag, priv, global_stride_offset, - global_stride_offset, factor_wi, global_data, - detail::transfer_direction::PRIVATE_TO_GLOBAL); + global_stride_offset, factor_wi, + detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); } } } else { @@ -477,12 +476,12 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag Idx loc_view_offset = subgroup_id * n_reals_per_sg + id_of_fft_in_sg * n_reals_per_fft + 2 * id_of_wi_in_fft; local_private_strided_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_view_offset}}}, factor_wi, - global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } else { Idx loc_view_offset = subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{{factor_sg}, {loc_view_offset}}}, 
{{{factor_sg}, {loc_view_offset + local_imag_offset}}}, factor_wi, - global_data, detail::transfer_direction::PRIVATE_TO_LOCAL); + detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } } sycl::group_barrier(global_data.sg); @@ -570,6 +569,8 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn for (Idx i = 0; i < factor_sg; i++) { for (Idx j = 0; j < factor_wi; j++) { + // Not using sycl::cospi / sycl::sinpi as std::cos/std::sin provides better accuracy in float and double tests + // Also why this was moved to host, this way the tolerance value needs to be bumped up by a smaller value double theta = -2 * M_PI * static_cast(i * j) / static_cast(factor_wi * factor_sg); auto twiddle = std::complex(static_cast(std::cos(theta)), static_cast(std::sin(theta))); host_twiddles[static_cast(j * factor_sg + i)] = twiddle.real(); diff --git a/src/portfft/utils.hpp b/src/portfft/utils.hpp index 2d63b0d5..6b6c54b9 100644 --- a/src/portfft/utils.hpp +++ b/src/portfft/utils.hpp @@ -119,7 +119,8 @@ std::optional factorize_input_impl(IdxGlobal factor_size, F&& check_a * implementations. The function should accept factor size and whether it would be have a BATCH_INTERLEAVED layout or * not as an input, and should return a boolean indicating whether or not the factor size can fit in any of the * implementation. 
- * @return Whether or not a prime sized that does not fit in workitem implementation was encountered + * @return Whether or not a prime sized value that does not fit in workitem implementation was encountered during + * factorization */ template bool factorize_input(IdxGlobal input_size, F&& check_and_select_target_level) { diff --git a/test/unit_test/fft_test_utils.hpp b/test/unit_test/fft_test_utils.hpp index 12790d7e..41c82142 100644 --- a/test/unit_test/fft_test_utils.hpp +++ b/test/unit_test/fft_test_utils.hpp @@ -339,14 +339,6 @@ std::enable_if_t check_fft( host_output_imag.size(), {fft_event}); } queue.wait_and_throw(); - // std::cout << "PRINTING REFERENCE DATA " << std::endl; - // for (auto n : host_reference_output) { - // std::cout << n << " "; - // } - // std::cout << std::endl; - // for (auto n : host_output) { - // std::cout << n << " "; - // } if constexpr (Storage == complex_storage::SPLIT_COMPLEX) { verify_dft(desc, host_reference_output, host_output, tolerance, host_reference_output_imag, host_output_imag); @@ -474,7 +466,7 @@ void run_test(const test_params& params) { auto num_prime_sizes = std::count_if(params.lengths.begin(), params.lengths.end(), [](const std::size_t l) { return detail::factorize(l) == std::size_t(1); }); if (num_prime_sizes > 0) { - tolerance *= 10; + tolerance *= 5; } portfft::detail::dump_host("host_input:", host_input.data(), host_input.size()); portfft::detail::dump_host("host_input_imag:", host_input.data(), host_input.size()); From ae929d33c7622ae1c29806b5ebb5002c9c7fba44 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 14 Mar 2024 08:01:24 +0000 Subject: [PATCH 13/22] remove unused shuffle_transpose function --- src/portfft/common/transpose.hpp | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/src/portfft/common/transpose.hpp b/src/portfft/common/transpose.hpp index b496a739..75775112 100644 --- a/src/portfft/common/transpose.hpp +++ b/src/portfft/common/transpose.hpp @@ -98,29 +98,6 
@@ PORTFFT_INLINE inline void generic_transpose(IdxGlobal N, IdxGlobal M, Idx tile_ } } } - -template -PORTFFT_INLINE void shuffle_transpose(T* priv, T* output, Idx lda, Idx ldb, sycl::sub_group sg) { - Idx sg_local_linear_id = static_cast(sg.get_local_linear_id()); - Idx id_of_thread_in_fft = sg_local_linear_id % lda; - Idx matrix_start_lane_id = (sg_local_linear_id - id_of_thread_in_fft) & (SubgroupSize - 1); - Idx lane_id_relative_to_start = id_of_thread_in_fft & (lda - 1); - - PORTFFT_UNROLL - for (Idx id_of_element_in_wi = 0; id_of_element_in_wi < ldb; id_of_element_in_wi++) { - Idx relative_target_lane_id = ((lane_id_relative_to_start + id_of_element_in_wi) & (ldb - 1)) * (lda / ldb) + - (lane_id_relative_to_start / ldb); - Idx target_lane_id = matrix_start_lane_id + relative_target_lane_id; - Idx store_address = (sg_local_linear_id + id_of_element_in_wi) & (ldb - 1); - Idx target_address = ((ldb - id_of_element_in_wi) + (sg_local_linear_id / (lda / ldb))) & (ldb - 1); - T& real_value = priv[2 * target_address]; - T& complex_value = priv[2 * target_address + 1]; - output[2 * store_address] = sycl::select_from_group(sg, real_value, static_cast(target_lane_id)); - output[2 * store_address + 1] = - sycl::select_from_group(sg, complex_value, static_cast(target_lane_id)); - } -} - } // namespace detail } // namespace portfft #endif From 708893cdce75d288a5cbb8f8e66210c85f9c051f Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 14 Mar 2024 08:03:57 +0000 Subject: [PATCH 14/22] remove unused bluestein header from workgroup_dispatcher.hpp --- src/portfft/dispatcher/workgroup_dispatcher.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index db0860a7..8d9b08c0 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -21,7 +21,6 @@ #ifndef PORTFFT_DISPATCHER_WORKGROUP_DISPATCHER_HPP #define 
PORTFFT_DISPATCHER_WORKGROUP_DISPATCHER_HPP -#include "portfft/common/bluestein.hpp" #include "portfft/common/helpers.hpp" #include "portfft/common/logging.hpp" #include "portfft/common/memory_views.hpp" From 594d2246ddef2814a2e2e602bd1d6ba207d89130 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 14 Mar 2024 08:10:21 +0000 Subject: [PATCH 15/22] remove unused headers from subgroup.hpp --- src/portfft/common/subgroup.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/portfft/common/subgroup.hpp b/src/portfft/common/subgroup.hpp index 27181113..68b2fa43 100644 --- a/src/portfft/common/subgroup.hpp +++ b/src/portfft/common/subgroup.hpp @@ -25,9 +25,7 @@ #include "helpers.hpp" #include "portfft/common/logging.hpp" -#include "portfft/common/memory_views.hpp" #include "portfft/common/transfers.hpp" -#include "portfft/common/transpose.hpp" #include "portfft/defines.hpp" #include "portfft/enums.hpp" #include "twiddle.hpp" From 5f1ab4d6277c4323a927d8a8b45a6575b5372328 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 14 Mar 2024 08:31:49 +0000 Subject: [PATCH 16/22] add missing array header --- src/portfft/defines.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/portfft/defines.hpp b/src/portfft/defines.hpp index 8af51ba0..75afb240 100644 --- a/src/portfft/defines.hpp +++ b/src/portfft/defines.hpp @@ -21,6 +21,7 @@ #ifndef PORTFFT_DEFINES_HPP #define PORTFFT_DEFINES_HPP +#include #include #ifdef PORTFFT_KERNEL_LOG From 99d8cfb4b9bb570f329800c05e8d6f65a5ab7163 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 14 Mar 2024 10:37:12 +0000 Subject: [PATCH 17/22] slightly bump tolerance value for tests to pass on Nvidia --- test/unit_test/fft_test_utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit_test/fft_test_utils.hpp b/test/unit_test/fft_test_utils.hpp index 41c82142..c282499f 100644 --- a/test/unit_test/fft_test_utils.hpp +++ b/test/unit_test/fft_test_utils.hpp @@ -466,7 +466,7 @@ void 
run_test(const test_params& params) { auto num_prime_sizes = std::count_if(params.lengths.begin(), params.lengths.end(), [](const std::size_t l) { return detail::factorize(l) == std::size_t(1); }); if (num_prime_sizes > 0) { - tolerance *= 5; + tolerance *= 7; // Smallest value by which the tolerance needs to be increased } portfft::detail::dump_host("host_input:", host_input.data(), host_input.size()); portfft::detail::dump_host("host_input_imag:", host_input.data(), host_input.size()); From 243c793c13000d71b83057db46ef4fdba097be49 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 14 Mar 2024 10:38:34 +0000 Subject: [PATCH 18/22] format --- test/unit_test/fft_test_utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit_test/fft_test_utils.hpp b/test/unit_test/fft_test_utils.hpp index c282499f..3503412b 100644 --- a/test/unit_test/fft_test_utils.hpp +++ b/test/unit_test/fft_test_utils.hpp @@ -466,7 +466,7 @@ void run_test(const test_params& params) { auto num_prime_sizes = std::count_if(params.lengths.begin(), params.lengths.end(), [](const std::size_t l) { return detail::factorize(l) == std::size_t(1); }); if (num_prime_sizes > 0) { - tolerance *= 7; // Smallest value by which the tolerance needs to be increased + tolerance *= 7; // Smallest value by which the tolerance needs to be increased } portfft::detail::dump_host("host_input:", host_input.data(), host_input.size()); portfft::detail::dump_host("host_input_imag:", host_input.data(), host_input.size()); From 3a7795332125490b38ba401fb4cae4d32a5ce8dd Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Mon, 18 Mar 2024 14:48:57 +0000 Subject: [PATCH 19/22] review comments 1 --- .clang-tidy | 1 - src/portfft/committed_descriptor_impl.hpp | 37 +- src/portfft/common/subgroup.hpp | 641 ------------------ src/portfft/common/subgroup_bluestein.hpp | 283 ++++++++ src/portfft/common/subgroup_ct.hpp | 397 +++++++++++ src/portfft/common/transfers.hpp | 17 +- 
src/portfft/common/workgroup.hpp | 2 +- src/portfft/defines.hpp | 10 +- src/portfft/descriptor_validation.hpp | 2 +- src/portfft/dispatcher/global_dispatcher.hpp | 2 +- .../dispatcher/subgroup_dispatcher.hpp | 87 +-- src/portfft/utils.hpp | 2 +- test/unit_test/fft_test_utils.hpp | 1 - 13 files changed, 768 insertions(+), 714 deletions(-) delete mode 100644 src/portfft/common/subgroup.hpp create mode 100644 src/portfft/common/subgroup_bluestein.hpp create mode 100644 src/portfft/common/subgroup_ct.hpp diff --git a/.clang-tidy b/.clang-tidy index 1fe9e338..0b3225d5 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -14,7 +14,6 @@ Checks: > performance-*, -performance-avoid-endl, readability-*, - -readability-magic-numbers, -readability-function-cognitive-complexity, -readability-identifier-length, -readability-named-parameter, diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp index 597d28f7..0b5caccc 100644 --- a/src/portfft/committed_descriptor_impl.hpp +++ b/src/portfft/committed_descriptor_impl.hpp @@ -30,7 +30,7 @@ #include #include "common/exceptions.hpp" -#include "common/subgroup.hpp" +#include "common/subgroup_ct.hpp" #include "defines.hpp" #include "enums.hpp" #include "specialization_constant.hpp" @@ -148,28 +148,28 @@ class committed_descriptor_impl { std::vector transpose_kernels; std::shared_ptr factors_and_scan; detail::level level; - // The size of DFT transform which will be computed for the given dimension + // The size of DFT transform which will be computed for the given dimension. 
Will be different from the + // committed_length when the Bluestein / Rader algorithms are used std::size_t length; - // The committed length for the particular dimension, will be different from length in the case of bluestein and - // radar fft algorithms + // The committed length (as in the user specified length) for the particular dimension std::size_t committed_length; Idx used_sg_size; Idx num_batches_in_l2; Idx num_factors; - bool is_prime; + detail::fft_algorithm algorithm; dimension_struct(std::vector forward_kernels, std::vector backward_kernels, detail::level level, std::size_t length, std::size_t committed_length, Idx used_sg_size, - bool is_prime) + detail::fft_algorithm algorithm) : forward_kernels(std::move(forward_kernels)), backward_kernels(std::move(backward_kernels)), level(level), length(length), committed_length(committed_length), used_sg_size(used_sg_size), - is_prime(is_prime) { - if (is_prime && level != detail::level::SUBGROUP) { - throw unsupported_configuration("Prime sizes that not fit in the subgroup implementation are not supported"); + algorithm(algorithm) { + if (algorithm == detail::fft_algorithm::BLUESTEIN && level != detail::level::SUBGROUP) { + throw unsupported_configuration("Prime sizes that do not fit in the subgroup implementation are not supported"); } } }; @@ -215,9 +215,9 @@ class committed_descriptor_impl { * set of kernels that need to be JIT compiled. 
* * @tparam SubgroupSize size of the subgroup - * @param fft_size The size of the dft transform - * @return implementation to use for the dimension and a vector of tuples of: implementation to use for a kernel, - * vector of kernel ids, factors + * @param fft_size The size for which kernel needs to be prepared + * @return implementation to use for the dimension and a vector of tuples of: implementation to use for a kernel, the + * size of the fft for which the implementation was prepared and the vector of kernel ids, factors */ template std::tuple prepare_implementation(IdxGlobal fft_size) { @@ -595,6 +595,15 @@ class committed_descriptor_impl { set_spec_constants_driver(top_level, prepared_vec, direction::FORWARD, dimension_num); auto backward_kernels = set_spec_constants_driver(top_level, prepared_vec, direction::BACKWARD, dimension_num); + detail::fft_algorithm algorithm; + if (fft_size == params.lengths[dimension_num]) { + algorithm = detail::fft_algorithm::COOLEY_TUKEY; + } else if (fft_size > params.lengths[dimension_num]) { + algorithm = detail::fft_algorithm::BLUESTEIN; + } else { + throw internal_error("Invalid FFT size encountered while preparing the implementation"); + } + if (forward_kernels.has_value() && backward_kernels.has_value()) { return {forward_kernels.value(), backward_kernels.value(), @@ -602,7 +611,7 @@ class committed_descriptor_impl { fft_size, params.lengths[dimension_num], SubgroupSize, - fft_size != params.lengths[dimension_num]}; + algorithm}; } } } @@ -949,7 +958,7 @@ class committed_descriptor_impl { const auto input_layout = detail::get_layout(params, compute_direction); const auto output_layout = detail::get_layout(params, inv(compute_direction)); - if (dimensions.back().is_prime) { + if (dimensions.back().algorithm == detail::fft_algorithm::BLUESTEIN) { if (input_layout == detail::layout::UNPACKED || output_layout == detail::layout::UNPACKED) { throw unsupported_configuration("Unsupported configuration for prime sized DFTs"); } 
diff --git a/src/portfft/common/subgroup.hpp b/src/portfft/common/subgroup.hpp deleted file mode 100644 index 68b2fa43..00000000 --- a/src/portfft/common/subgroup.hpp +++ /dev/null @@ -1,641 +0,0 @@ -/*************************************************************************** - * - * Copyright (C) Codeplay Software Ltd. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * Codeplay's portFFT - * - **************************************************************************/ - -#ifndef PORTFFT_COMMON_SUBGROUP_HPP -#define PORTFFT_COMMON_SUBGROUP_HPP - -#include - -#include "helpers.hpp" -#include "portfft/common/logging.hpp" -#include "portfft/common/transfers.hpp" -#include "portfft/defines.hpp" -#include "portfft/enums.hpp" -#include "twiddle.hpp" -#include "twiddle_calc.hpp" -#include "workitem.hpp" - -namespace portfft { -namespace detail { - -/* -`sg_dft` calculates a DFT by a subgroup on values that are already loaded into private memory of the workitems in the -subgroup. It needs twiddle factors precalculated by `sg_calc_twiddles`. It handles the first factor by cross subgroup -DFT calling `cross_sg_dispatcher` and the second one by workitem implementation - calling `wi_dft`. It does twiddle -multiplication inbetween, but does not transpose. Transposition is supposed to be done when storing the values back to -the local memory. - -The size of the DFT performed by this function is `N * M` - for the arguments `N` and `M`. 
`N` workitems work jointly on -one DFT, so at most `subgroup_size / N` DFTs can be performed by one subgroup at a time. If `N` does not evenly divide -`subgroup_size`, extra workitems perform dummy computations. However, they must also call `sg_dft`, as it uses group -functions. - -On input, each of the `N` workitems hold `M` consecutive complex input values. On output, each of the workitems holds -complex values that are strided with stride `N` and consecutive workitems have consecutive values. - -`cross_sg_dft` calculates DFT across workitems, with each workitem contributing one complex value as input and output of -the computation. If the size of the subgroup is large enough compared to FFT size, a subgroup can calculate multiple -DFTs at once (the same holds true for `cross_sg_cooley_tukey_dft` and `cross_sg_naive_dft`). It calls either -`cross_sg_cooley_tukey_dft` (for composite sizes) or `cross_sg_naive_dft` (for prime sizes). - -`cross_sg_cooley_tukey_dft` calculates DFT of a composite size across workitems. It calls `cross_sg_dft` for each of the -factors and does transposition and twiddle multiplication inbetween. - -`cross_sg_naive_dft` calculates DFT across workitems using naive DFT algorithm. -*/ - -// forward declaration -template -PORTFFT_INLINE void cross_sg_dft(T& real, T& imag, Idx fft_size, Idx stride, sycl::sub_group& sg); - -/** - * Calculates DFT using naive algorithm by using workitems of one subgroup. - * Each workitem holds one input and one output complex value. 
- * - * @tparam T type of the scalar to work on - * @param[in,out] real real component of the input/output complex value for one - * workitem - * @param[in,out] imag imaginary component of the input/output complex value for - * one workitem - * @param fft_size size of the DFT transform - * @param stride Stride between workitems working on consecutive values of one - * DFT - * @param sg subgroup - */ -template -PORTFFT_INLINE void cross_sg_naive_dft(T& real, T& imag, Idx fft_size, Idx stride, sycl::sub_group& sg) { - if (fft_size == 2 && (stride & (stride - 1)) == 0) { - Idx local_id = static_cast(sg.get_local_linear_id()); - Idx idx_out = (local_id / stride) % 2; - - T multi_re = (idx_out & 1) ? T(-1) : T(1); - T res_real = real * multi_re; - T res_imag = imag * multi_re; - - res_real += sycl::permute_group_by_xor(sg, real, static_cast(stride)); - res_imag += sycl::permute_group_by_xor(sg, imag, static_cast(stride)); - - real = res_real; - imag = res_imag; - } else { - Idx local_id = static_cast(sg.get_local_linear_id()); - Idx idx_out = (local_id / stride) % fft_size; - Idx fft_start = local_id - idx_out * stride; - - T res_real = 0; - T res_imag = 0; - - // IGC doesn't unroll this loop and generates a warning when called from workgroup impl. - PORTFFT_UNROLL - for (Idx idx_in = 0; idx_in < fft_size; idx_in++) { - T multi_re = twiddle::Re[fft_size][idx_in * idx_out % fft_size]; - T multi_im = twiddle::Im[fft_size][idx_in * idx_out % fft_size]; - - Idx source_wi_id = fft_start + idx_in * stride; - - T cur_real = sycl::select_from_group(sg, real, static_cast(source_wi_id)); - T cur_imag = sycl::select_from_group(sg, imag, static_cast(source_wi_id)); - - // multiply cur and multi - T tmp_real; - T tmp_imag; - detail::multiply_complex(cur_real, cur_imag, multi_re, multi_im, tmp_real, tmp_imag); - res_real += tmp_real; - res_imag += tmp_imag; - } - - real = res_real; - imag = res_imag; - } -} - -/** - * Transposes values held by workitems of a subgroup. 
Transposes rectangles of - * size N*M. Each of the rectangles can be strided. - * - * @tparam T type of the scalar to work on - * @param[in,out] real real component of the input/output complex value for one - * workitem - * @param[in,out] imag imaginary component of the input/output complex value for - * one workitem - * @param factor_n inner - contiguous size on input, outer size on output - * @param factor_m outer size on input, inner - contiguous size on output - * @param stride Stride between consecutive values of one rectangle - * @param sg subgroup - */ -template -PORTFFT_INLINE void cross_sg_transpose(T& real, T& imag, Idx factor_n, Idx factor_m, Idx stride, sycl::sub_group& sg) { - Idx local_id = static_cast(sg.get_local_linear_id()); - Idx index_in_outer_dft = (local_id / stride) % (factor_n * factor_m); - Idx k = index_in_outer_dft % factor_n; // index in the contiguous factor/fft - Idx n = index_in_outer_dft / factor_n; // index of the contiguous factor/fft - Idx fft_start = local_id - index_in_outer_dft * stride; - Idx source_wi_id = fft_start + stride * (k * factor_m + n); - real = sycl::select_from_group(sg, real, static_cast(source_wi_id)); - imag = sycl::select_from_group(sg, imag, static_cast(source_wi_id)); -} - -/** - * Calculates DFT using Cooley-Tukey FFT algorithm. Size of the problem is N*M. - * Each workitem holds one input and one output complex value. 
- * - * @tparam SubgroupSize Size of subgroup in kernel - * @tparam RecursionLevel level of recursion in SG dft - * @tparam T type of the scalar to work on - * @param[in,out] real real component of the input/output complex value for one - * workitem - * @param[in,out] imag imaginary component of the input/output complex value for - * one workitem - * @param factor_n the first factor of the problem size - * @param factor_m the second factor of the problem size - * @param stride Stride between workitems working on consecutive values of one - * DFT - * @param sg subgroup - */ -template -PORTFFT_INLINE void cross_sg_cooley_tukey_dft(T& real, T& imag, Idx factor_n, Idx factor_m, Idx stride, - sycl::sub_group& sg) { - Idx local_id = static_cast(sg.get_local_linear_id()); - Idx index_in_outer_dft = (local_id / stride) % (factor_n * factor_m); - Idx k = index_in_outer_dft % factor_n; // index in the contiguous factor/fft - Idx n = index_in_outer_dft / factor_n; // index of the contiguous factor/fft - - // factor N - cross_sg_dft(real, imag, factor_n, factor_m * stride, sg); - // transpose - cross_sg_transpose(real, imag, factor_n, factor_m, stride, sg); - T multi_re = twiddle::Re[factor_n * factor_m][k * n]; - T multi_im = twiddle::Im[factor_n * factor_m][k * n]; - detail::multiply_complex(real, imag, multi_re, multi_im, real, imag); - // factor M - cross_sg_dft(real, imag, factor_m, factor_n * stride, sg); -} - -/** - * Calculates DFT using FFT algorithm. Each workitem holds one input and one - * output complex value. 
- * - * @tparam SubgroupSize Size of subgroup in kernel - * @tparam RecursionLevel level of recursion in SG dft - * @tparam T type of the scalar to work on - * @param[in,out] real real component of the input/output complex value for one - * workitem - * @param[in,out] imag imaginary component of the input/output complex value for - * one workitem - * @param fft_size Size of the DFT - * @param stride Stride between workitems working on consecutive values of one - * DFT - * @param sg subgroup - */ -template -PORTFFT_INLINE void cross_sg_dft(T& real, T& imag, Idx fft_size, Idx stride, sycl::sub_group& sg) { - constexpr Idx MaxRecursionLevel = detail::int_log2(SubgroupSize); - if constexpr (RecursionLevel < MaxRecursionLevel) { - const Idx f0 = detail::factorize(fft_size); - if (f0 >= 2 && fft_size / f0 >= 2) { - cross_sg_cooley_tukey_dft(real, imag, fft_size / f0, f0, stride, sg); - } else { - cross_sg_naive_dft(real, imag, fft_size, stride, sg); - } - } -} - -/** - * Factorizes a number into two factors, so that one of them will maximal below - or equal to subgroup size. - * @tparam T type of the number to factorize - * @param N the number to factorize - * @param sg_size subgroup size - * @return the factor below or equal to subgroup size - */ -template -PORTFFT_INLINE constexpr T factorize_sg(T N, Idx sg_size) { - if constexpr (PORTFFT_SLOW_SG_SHUFFLES) { - return 1; - } else { - for (T i = static_cast(sg_size); i > 1; i--) { - if (N % i == 0) { - return i; - } - } - return 1; - } -} - -/** - * Checks whether a problem can be solved with sub-group implementation - * without reg spilling. 
- * @tparam Scalar type of the real scalar used for the computation - * @param N Size of the problem, in complex values - * @param sg_size Size of the sub-group - * @return true if the problem fits in the registers - */ -template -constexpr bool fits_in_sg(IdxGlobal N, Idx sg_size) { - IdxGlobal factor_sg = factorize_sg(N, sg_size); - IdxGlobal factor_wi = N / factor_sg; - return fits_in_wi(factor_wi); -} - -}; // namespace detail - -/** - * Calculates FFT of size N*M using workitems in a subgroup. Works in place. The - * end result needs to be transposed when storing it to the local memory! - * - * @tparam SubgroupSize Size of subgroup in kernel - * @tparam T type of the scalar used for computations - * @param inout pointer to private memory where the input/output data is - * @param sg subgroup - * @param factor_wi number of elements per workitem - * @param factor_sg number of workitems in a subgroup that work on one FFT - * @param sg_twiddles twiddle factors to use - calculated by sg_calc_twiddles in - * commit - * @param private_scratch Scratch memory for wi implementation - */ -template -PORTFFT_INLINE void sg_dft(T* inout, sycl::sub_group& sg, Idx factor_wi, Idx factor_sg, const T* sg_twiddles, - T* private_scratch) { - Idx idx_of_wi_in_fft = static_cast(sg.get_local_linear_id()) % factor_sg; - // IGC doesn't unroll this loop and generates a warning when called from workgroup impl. 
- PORTFFT_UNROLL - for (Idx idx_of_element_in_wi = 0; idx_of_element_in_wi < factor_wi; idx_of_element_in_wi++) { - T& real = inout[2 * idx_of_element_in_wi]; - T& imag = inout[2 * idx_of_element_in_wi + 1]; - - if (factor_sg > 1) { - detail::cross_sg_dft(real, imag, factor_sg, 1, sg); - if (idx_of_element_in_wi > 0) { - T twiddle_real = sg_twiddles[idx_of_element_in_wi * factor_sg + idx_of_wi_in_fft]; - T twiddle_imag = sg_twiddles[(idx_of_element_in_wi + factor_wi) * factor_sg + idx_of_wi_in_fft]; - detail::multiply_complex(real, imag, twiddle_real, twiddle_imag, real, imag); - } - } - }; - wi_dft<0>(inout, inout, factor_wi, 1, 1, private_scratch); -} - -/** - * Calculates a twiddle factor for subgroup implementation. - * - * @tparam T type of the scalar used for computations - * @param factor_sg number of workitems in a subgroup that work on one FFT - * @param factor_wi number of elements per workitem - * @param n index of the twiddle to calculate in the direction of factor_sg - * @param k index of the twiddle to calculate in the direction of factor_wi - * @param sg_twiddles destination into which to store the twiddles - */ -template -void sg_calc_twiddles(Idx factor_sg, Idx factor_wi, Idx n, Idx k, T* sg_twiddles) { - std::complex twiddle = detail::calculate_twiddle(n * k, factor_sg * factor_wi); - sg_twiddles[k * factor_sg + n] = twiddle.real(); - sg_twiddles[(k + factor_wi) * factor_sg + n] = twiddle.imag(); -} - -/** - * Function to copy data between local and global memory as required by the subgroup level Bluestein algorithm, - * when the data in both local and global memory is in packed format,when the storage scheme is INTERLEAVED_COMPLEX - * - * @tparam SubgroupSize Subgroup size - * @tparam Direction Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or - * transfer_direction::GLOBAL_TO_LOCAL - * @tparam TIn Global memory Type - * @tparam LocView Type of the view constructed for local memory - * @param global_ptr 
global memory pointer - * @param loc_view View of the local memory - * @param committed_size Size of the DFT as committed, also the number of complex elements in each transform present in - * global memory - * @param fft_size The padded DFT size, also the number of elements of complex elements in each transform that resides - * in local memory - * @param global_ptr_offset Offset to be applied to the global memory pointer - * @param loc_offset Offset to be applied to the local memory view - * @param n_ffts_in_sg Number of ffts that can be calculated by a single subgroup - * @param transform_id Id of the transform in the kernel - * @param n_transforms Total number of transforms in the kernel - * @param global_data global_data_struct associated with the kernel launch - */ -template -PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( - TIn global_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, IdxGlobal global_ptr_offset, Idx loc_offset, - Idx n_ffts_in_sg, IdxGlobal transform_id, IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { - PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg && ((i + transform_id) < n_transforms); i++) { - local_global_packed_copy( - global_ptr, loc_view, global_ptr_offset + static_cast(2 * i * committed_size), - 2 * i * fft_size + loc_offset, 2 * committed_size, global_data); - } -} - -/** - * Function to copy data between local and global memory as required by the subgroup level Bluestein algorithm, - * when the data in both local and global memory is in packed format,when the storage scheme is SPLIT_COMPLEX - * - * @tparam SubgroupSize Subgroup size - * @tparam Direction Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or - * transfer_direction::GLOBAL_TO_LOCAL - * @tparam TIn Global memory Type - * @tparam LocView Type of the view constructed for local memory - * @param global_ptr global memory pointer containing the real part of the data - * @param 
global_imag_ptr global memory pointer containing the imaginary part of the data - * @param loc_view View of the local memory - * @param committed_size Size of the DFT as committed, also the number of complex elements in each transform present in - * global memory - * @param fft_size The padded DFT size, also the number of elements of complex elements in each transform that resides - * in local memory - * @param global_ptr_offset Offset to be applied to the global memory pointer - * @param loc_offset Offset to be applied to the local memory view - * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is - * stored - * @param n_ffts_in_sg Number of ffts that can be calculated by a single subgroup - * @param transform_id Id of the transform in the kernel - * @param n_transforms Total number of transforms in the kernel - * @param global_data global_data_struct associated with the kernel launch - */ -template -PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( - TIn global_ptr, TIn global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, - IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, IdxGlobal transform_id, - IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { - PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg && (i + transform_id < n_transforms); i++) { - local_global_packed_copy( - global_ptr, global_imag_ptr, loc_view, static_cast(i * committed_size) + global_ptr_offset, - i * fft_size + loc_offset, local_imag_offset, committed_size, global_data); - } -} - -/** - * Performs all the computations to be done in the private memory for the subgroup level FFT Implementation - * - * @tparam SubgroupSize Subgroup Size - * @tparam T Scalar Type - * @tparam LocView View of the local memory - * @param priv private memory array on which the computations will be done - * @param private_scratch Scratch private memory to be passed to 
the wi_dft as a part of sg_dft - * @param apply_load_modifier Whether or not modifiers need to be applied before the fft computation - * @param apply_store_modifier Whether or not the modifiers need to be applied after the fft computation - * @param conjugate_on_load Whether or not conjugation of the input is to be done before the fft computation - * @param conjugate_on_store Whether or not conjugation of the input is to be done after the fft computation - * @param scale_factor_applied Whether or not scale factor is applied - * @param load_modifier_data Global memory pointer containing the load modifier data, assumed aligned to at least - * sycl::vec - * @param store_modifier_data Global memory pointer containing the store modifier data, assumed aligned to at least - * sycl::vec - * @param twiddles_loc_view View of the local memory containing the twiddles - * @param scale_factor Value of the scale factor - * @param modifier_start_offset offset to be applied to the load/store modifier pointers - * @param id_of_wi_in_fft workitem id withing the fft - * @param factor_sg Number of workitems participating for one transform - * @param factor_wi Number of complex elements per workitem for each transform - * @param global_data global_data_struct associated with the kernel launch - */ -template -PORTFFT_INLINE void sg_dft_compute(T* priv, T* private_scratch, detail::elementwise_multiply apply_load_modifier, - detail::elementwise_multiply apply_store_modifier, - detail::complex_conjugate conjugate_on_load, - detail::complex_conjugate conjugate_on_store, - detail::apply_scale_factor scale_factor_applied, const T* load_modifier_data, - const T* store_modifier_data, LocView& twiddles_loc_view, T scale_factor, - IdxGlobal modifier_start_offset, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, - detail::global_data_struct<1>& global_data) { - using vec2_t = sycl::vec; - vec2_t modifier_vec; - if (conjugate_on_load == detail::complex_conjugate::APPLIED) { - 
global_data.log_message(__func__, "Applying complex conjugate before computation of the FFT"); - detail::conjugate_inplace(priv, factor_wi); - } - if (apply_load_modifier == detail::elementwise_multiply::APPLIED) { - global_data.log_message(__func__, "Applying load modifiers"); - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - modifier_vec = *reinterpret_cast( - &load_modifier_data[modifier_start_offset + 2 * factor_wi * id_of_wi_in_fft + 2 * j]); - detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], - priv[2 * j + 1]); - } - } - sg_dft(priv, global_data.sg, factor_wi, factor_sg, twiddles_loc_view, private_scratch); - - if (conjugate_on_store == detail::complex_conjugate::APPLIED) { - global_data.log_message(__func__, "Applying complex conjugate after computation of the FFT"); - detail::conjugate_inplace(priv, factor_wi); - } - - if (apply_store_modifier == detail::elementwise_multiply::APPLIED) { - global_data.log_message(__func__, "Applying store modifiers"); - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - modifier_vec = *reinterpret_cast( - &store_modifier_data[modifier_start_offset + 2 * j * factor_sg + 2 * id_of_wi_in_fft]); - detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], - priv[2 * j + 1]); - } - } - - if (scale_factor_applied == detail::apply_scale_factor::APPLIED) { - global_data.log_message(__func__, "Applying scale factor"); - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - priv[2 * j] *= scale_factor; - priv[2 * j + 1] *= scale_factor; - } - } -} - -/** - * Implements the Subgroup level Bluestein algorithm with an addition trip to local memory, when the layout of the data - * in local memory is in BATCH_INTERLEAVED format - * - * @tparam SubgroupSize Subgroup Size - * @tparam T Scalar Type - * @tparam LocTwiddlesView Type of view of the local memory containing the twiddles - * @tparam LocView Type of view of the local 
memory which stores the data - * @param priv private memory array on which the computations will be done - * @param private_scratch Scratch private memory to be passed to the wi_dft as a part of sg_dft - * @param loc_view view of the local memory to store the data - * @param load_modifier Global memory pointer containing the load modifier data, assumed aligned to at least - * sycl::vec - * @param store_modifier Global memory pointer containing the store modifier data, assumed aligned to at least - * sycl::vec - * @param twiddles_loc view of the local memory containing the twiddles - * @param conjugate_on_load Whether or not conjugation of the input is to be done before the fft computation - * @param conjugate_on_store Whether or not conjugation of the input is to be done after the fft computation - * @param scale_applied Whether or not scale factor is applied - * @param scale_factor Value of the scaling factor - * @param id_of_wi_in_fft Id of the workitem in the FFT - * @param factor_sg Number of workitems participating for one transform - * @param factor_wi Number of complex elements per workitem for each transform - * @param storage storage scheme of complex values in local memory, SPLIT_COMPLEX or INTERLEAVED_COMPLEX - * @param wi_working Whether or not the workitem participates in the data transfers - * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is - * stored - * @param max_num_batches_local_mem Maximum number of transforms that can be stored in local memory - * @param fft_idx_in_local Id of the transform in local memory - * @param global_data global_data_struct associated with kernel launch - */ -template -PORTFFT_INLINE void sg_bluestein_batch_interleaved( - T* priv, T* priv_scratch, LocView& loc_view, const T* load_modifier, const T* store_modifier, - LocTwiddlesView& twiddles_loc, detail::complex_conjugate conjugate_on_load, - detail::complex_conjugate conjugate_on_store, 
detail::apply_scale_factor scale_applied, T scale_factor, - Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, complex_storage storage, bool wi_working, Idx local_imag_offset, - Idx max_num_batches_local_mem, Idx fft_idx_in_local, detail::global_data_struct<1>& global_data) { - global_data.log_message_global(__func__, "computing forward FFT and applying scaling factor for the backward phase"); - sg_dft_compute( - priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, - conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, - store_modifier, twiddles_loc, static_cast(1. / (static_cast(factor_sg * factor_wi))), 0, id_of_wi_in_fft, - factor_sg, factor_wi, global_data); - - // TODO: Currently local memory is being used to load the data back in natural order for the backward phase, as result - // of sg_dft is transposed However, the Ideal way to this is using shuffles. Implement a batched matrix transpose to - // transpose a matrix stored in the private memory of workitems of a subgroup using shuffles only This we way can even - // the 2 sg_bluestein functions that we have today - if (wi_working) { - global_data.log_message(__func__, "storing result of the forward phase back to local memory"); - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - local_private_strided_copy<2, Idx>( - loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}}, - factor_wi, detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); - } else { - local_private_strided_copy<2, Idx>( - loc_view, loc_view, priv, {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}}, - {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}}, - factor_wi, detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); - } - } - - sycl::group_barrier(global_data.sg); - if (wi_working) { - if (storage 
== complex_storage::INTERLEAVED_COMPLEX) { - global_data.log_message(__func__, "loading back the result from local memory for the backward phase"); - const Idx fft_element = 2 * id_of_wi_in_fft * factor_wi; - local_private_strided_copy<1, Idx>( - loc_view, priv, - {{{max_num_batches_local_mem}, {fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local}}}, factor_wi, - detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); - } else { - local_private_strided_copy<2, Idx>( - loc_view, loc_view, priv, {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}}, - {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}}, - factor_wi, detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); - } - } - global_data.log_message(__func__, "computing backward FFT and applying user provided scale value"); - sg_dft_compute(priv, priv_scratch, detail::elementwise_multiply::NOT_APPLIED, - detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, - detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), - load_modifier, twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, - global_data); - - if (conjugate_on_store == detail::complex_conjugate::APPLIED) { - global_data.log_message(__func__, "Applying complex conjugate on the output"); - detail::conjugate_inplace(priv, factor_wi); - } -} - -/** - * - * Implements the Subgroup level Bluestein algorithm with an addition trip to local memory, when the layout of the data - * in local memory is in BATCH_INTERLEAVED format - * - * @tparam SubgroupSize Subgroup Size - * @tparam T Scalar Type - * @tparam LocTwiddlesView Type of view of the local memory containing the twiddles - * @tparam LocView Type of view of the local memory which stores the data - * @param priv private memory array on which the computations will be done - * @param private_scratch Scratch private memory to be passed to the wi_dft as a 
part of sg_dft - * @param loc_view view of the local memory to store the data - * @param load_modifier Global memory pointer containing the load modifier data, assumed aligned to at least - * sycl::vec - * @param store_modifier Global memory pointer containing the store modifier data, assumed aligned to at least - * sycl::vec - * @param loc_twiddles view of the local memory containing the twiddles - * @param conjugate_on_load Whether or not conjugation of the input is to be done before the fft computation - * @param conjugate_on_store Whether or not conjugation of the input is to be done after the fft computation - * @param scale_applied Whether or not scale factor is applied - * @param scale_factor Value of the scaling factor - * @param id_of_wi_in_fft Id of the workitem in the FFT - * @param factor_sg Number of workitems participating for one transform - * @param factor_wi Number of complex elements per workitem for each transform - * @param storage storage scheme of complex values in local memory, SPLIT_COMPLEX or INTERLEAVED_COMPLEX - * @param wi_working Whether or not the workitem participates in the data transfers - * @param loc_view_store_offset Offset to be applied to local memory view when storing the data back to local memory - * after forward fft phase - * @param loc_view_load_offset offset to be applied to local memory view when loading the data back to local memory for - * backward fft phase - * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is - * stored - * @param global_data global_data_struct associated with kernel launch - */ -template -void sg_bluestein_packed(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddlesView& loc_twiddles, - const T* load_modifier, const T* store_modifier, detail::complex_conjugate conjugate_on_load, - detail::complex_conjugate conjugate_on_store, detail::apply_scale_factor scale_applied, - T scale_factor, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, 
complex_storage storage, - bool wi_working, Idx loc_view_store_offset, Idx loc_view_load_offset, Idx local_imag_offset, - detail::global_data_struct<1>& global_data) { - global_data.log_message_global(__func__, "computing forward FFT and applying scaling factor for the backward phase"); - sg_dft_compute( - priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, - conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, - store_modifier, loc_twiddles, static_cast(1. / static_cast(factor_sg * factor_wi)), 0, id_of_wi_in_fft, - factor_sg, factor_wi, global_data); - - if (wi_working) { - global_data.log_message(__func__, "storing result of the forward phase back to local memory"); - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - local_private_strided_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_view_store_offset}}}, factor_wi, - detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); - } else { - local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{{factor_sg}, {loc_view_store_offset}}}, - {{{factor_sg}, {loc_view_store_offset + local_imag_offset}}}, factor_wi, - detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); - } - } - - sycl::group_barrier(global_data.sg); - - if (wi_working) { - global_data.log_message(__func__, "loading back the result from local memory for the backward phase"); - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - local_private_strided_copy<1, Idx>(loc_view, priv, {{{1}, {loc_view_load_offset}}}, factor_wi, - detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); - } else { - local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{{1}, {loc_view_load_offset}}}, - {{{1}, {loc_view_load_offset + local_imag_offset}}}, factor_wi, - detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); - } - } - global_data.log_message(__func__, "computing backward FFT and applying user provided scale value"); - 
sg_dft_compute(priv, priv_scratch, detail::elementwise_multiply::NOT_APPLIED, - detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, - detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), - load_modifier, loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, - global_data); - if (conjugate_on_store == detail::complex_conjugate::APPLIED) { - global_data.log_message(__func__, "Applying complex conjugate on the output"); - detail::conjugate_inplace(priv, factor_wi); - } -} - -}; // namespace portfft - -#endif diff --git a/src/portfft/common/subgroup_bluestein.hpp b/src/portfft/common/subgroup_bluestein.hpp new file mode 100644 index 00000000..1c8be538 --- /dev/null +++ b/src/portfft/common/subgroup_bluestein.hpp @@ -0,0 +1,283 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * Codeplay's portFFT + * + **************************************************************************/ + +#ifndef PORTFFT_COMMON_SUBGROUP_BLUESTEIN_HPP +#define PORTFFT_COMMON_SUBGROUP_BLUESTEIN_HPP + +#include "helpers.hpp" +#include "portfft/common/logging.hpp" +#include "portfft/common/subgroup_ct.hpp" +#include "portfft/common/transfers.hpp" +#include "portfft/defines.hpp" +#include "portfft/enums.hpp" + +namespace portfft { + +/** + * Function to copy data between local and global memory as required by the subgroup level Bluestein algorithm, + * when the data in both local and global memory is in packed format, when the storage scheme is INTERLEAVED_COMPLEX + * + * @tparam SubgroupSize Subgroup size + * @tparam Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or + * transfer_direction::GLOBAL_TO_LOCAL + * @tparam TIn Global memory Type + * @tparam LocView Type of the view constructed for local memory + * @param global_ptr global memory pointer + * @param loc_view View of the local memory + * @param committed_size Size of the DFT as committed, also the number of complex elements in each transform present in + * global memory + * @param fft_size The padded DFT size, also the number of complex elements in each transform that reside + * in local memory + * @param global_ptr_offset Offset to be applied to the global memory pointer + * @param loc_offset Offset to be applied to the local memory view + * @param n_ffts_in_sg Number of ffts that can be calculated by a single subgroup + * @param transform_id Id of the transform in the kernel + * @param n_transforms Total number of transforms in the kernel + * @param global_data global_data_struct associated with the kernel launch + */ +template +PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( + TIn global_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, IdxGlobal global_ptr_offset, Idx loc_offset, + Idx n_ffts_in_sg,
IdxGlobal transform_id, IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { + PORTFFT_UNROLL + for (Idx i = 0; i < n_ffts_in_sg && ((i + transform_id) < n_transforms); i++) { + local_global_packed_copy( + global_ptr, loc_view, global_ptr_offset + static_cast(2 * i * committed_size), + 2 * i * fft_size + loc_offset, 2 * committed_size, global_data); + } +} + +/** + * Function to copy data between local and global memory as required by the subgroup level Bluestein algorithm, + * when the data in both local and global memory is in packed format, when the storage scheme is SPLIT_COMPLEX + * + * @tparam SubgroupSize Subgroup size + * @tparam Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or + * transfer_direction::GLOBAL_TO_LOCAL + * @tparam TIn Global memory Type + * @tparam LocView Type of the view constructed for local memory + * @param global_ptr global memory pointer containing the real part of the data + * @param global_imag_ptr global memory pointer containing the imaginary part of the data + * @param loc_view View of the local memory + * @param committed_size Size of the DFT as committed, also the number of complex elements in each transform present in + * global memory + * @param fft_size The padded DFT size, also the number of complex elements in each transform that reside + * in local memory + * @param global_ptr_offset Offset to be applied to the global memory pointer + * @param loc_offset Offset to be applied to the local memory view + * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is + * stored + * @param n_ffts_in_sg Number of ffts that can be calculated by a single subgroup + * @param transform_id Id of the transform in the kernel + * @param n_transforms Total number of transforms in the kernel + * @param global_data global_data_struct associated with the kernel launch + */ +template +PORTFFT_INLINE void
subgroup_impl_bluestein_local_global_packed_copy( + TIn global_ptr, TIn global_imag_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, + IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, IdxGlobal transform_id, + IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { + PORTFFT_UNROLL + for (Idx i = 0; i < n_ffts_in_sg && (i + transform_id < n_transforms); i++) { + local_global_packed_copy( + global_ptr, global_imag_ptr, loc_view, static_cast(i * committed_size) + global_ptr_offset, + i * fft_size + loc_offset, local_imag_offset, committed_size, global_data); + } +} + +/** + * Implements the Subgroup level Bluestein algorithm when the layout of the data + * in local memory is in BATCH_INTERLEAVED format + * + * @tparam SubgroupSize Subgroup Size + * @tparam T Scalar Type + * @tparam LocTwiddlesView Type of view of the local memory containing the twiddles + * @tparam LocView Type of view of the local memory which stores the data + * @param priv private memory array on which the computations will be done + * @param priv_scratch Scratch private memory to be passed to the wi_dft as a part of sg_dft + * @param loc_view view of the local memory to store the data + * @param load_modifier Global memory pointer containing the load modifier data, assumed aligned to at least + * sycl::vec + * @param store_modifier Global memory pointer containing the store modifier data, assumed aligned to at least + * sycl::vec + * @param twiddles_loc view of the local memory containing the twiddles + * @param conjugate_on_load Whether or not conjugation of the input is to be done before the fft computation + * @param conjugate_on_store Whether or not conjugation of the input is to be done after the fft computation + * @param scale_applied Whether or not scale factor is applied + * @param scale_factor Value of the scaling factor + * @param id_of_wi_in_fft Id of the workitem in the FFT + * @param factor_sg Number of workitems
participating for one transform + * @param factor_wi Number of complex elements per workitem for each transform + * @param storage storage scheme of complex values in local memory, SPLIT_COMPLEX or INTERLEAVED_COMPLEX + * @param wi_working Whether or not the workitem participates in the data transfers + * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is + * stored + * @param max_num_batches_local_mem Maximum number of transforms that can be stored in local memory + * @param fft_idx_in_local Id of the transform in local memory + * @param global_data global_data_struct associated with kernel launch + */ +template +PORTFFT_INLINE void sg_bluestein_batch_interleaved( + T* priv, T* priv_scratch, LocView& loc_view, const T* load_modifier, const T* store_modifier, + LocTwiddlesView& twiddles_loc, detail::complex_conjugate conjugate_on_load, + detail::complex_conjugate conjugate_on_store, detail::apply_scale_factor scale_applied, T scale_factor, + Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, complex_storage storage, bool wi_working, Idx local_imag_offset, + Idx max_num_batches_local_mem, Idx fft_idx_in_local, detail::global_data_struct<1>& global_data) { + global_data.log_message_global(__func__, "computing forward FFT and applying scaling factor for the backward phase"); + sg_cooley_tukey( + priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, + conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, + store_modifier, twiddles_loc, static_cast(1. / (static_cast(factor_sg * factor_wi))), 0, id_of_wi_in_fft, + factor_sg, factor_wi, global_data); + + // TODO: Currently local memory is being used to load the data back in natural order for the backward phase, as the + // result of sg_dft is transposed. However, the ideal way to do this is using shuffles.
Implement a batched matrix + // transpose to transpose a matrix stored in the private memory of workitems of a subgroup using shuffles only. This + // way we can even avoid the 2 sg_bluestein functions that we have today + if (wi_working) { + global_data.log_message(__func__, "storing result of the forward phase back to local memory"); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + local_private_strided_copy<2, Idx>( + loc_view, priv, {{factor_sg, max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}, + factor_wi, detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); + } else { + local_private_strided_copy<2, Idx>( + loc_view, loc_view, priv, {{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}, + {{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}, factor_wi, + detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); + } + } + + sycl::group_barrier(global_data.sg); + if (wi_working) { + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + global_data.log_message(__func__, "loading back the result from local memory for the backward phase"); + const Idx fft_element = 2 * id_of_wi_in_fft * factor_wi; + local_private_strided_copy<1, Idx>( + loc_view, priv, + {{max_num_batches_local_mem}, {fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local}}, factor_wi, + detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); + } else { + local_private_strided_copy<2, Idx>( + loc_view, loc_view, priv, {{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}, + {{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}, + factor_wi, detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); + } + } + global_data.log_message(__func__, "computing backward FFT and applying user provided scale value"); + sg_cooley_tukey(priv, priv_scratch, detail::elementwise_multiply::NOT_APPLIED, +
detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, + detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), + load_modifier, twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, + global_data); + + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + global_data.log_message(__func__, "Applying complex conjugate on the output"); + detail::conjugate_inplace(priv, factor_wi); + } +} + +/** + * + * Implements the Subgroup level Bluestein algorithm when the layout of the data + * in local memory is in PACKED format + * + * @tparam SubgroupSize Subgroup Size + * @tparam T Scalar Type + * @tparam LocTwiddlesView Type of view of the local memory containing the twiddles + * @tparam LocView Type of view of the local memory which stores the data + * @param priv private memory array on which the computations will be done + * @param priv_scratch Scratch private memory to be passed to the wi_dft as a part of sg_dft + * @param loc_view view of the local memory to store the data + * @param load_modifier Global memory pointer containing the load modifier data, assumed aligned to at least + * sycl::vec + * @param store_modifier Global memory pointer containing the store modifier data, assumed aligned to at least + * sycl::vec + * @param loc_twiddles view of the local memory containing the twiddles + * @param conjugate_on_load Whether or not conjugation of the input is to be done before the fft computation + * @param conjugate_on_store Whether or not conjugation of the output is to be done after the fft computation + * @param scale_applied Whether or not scale factor is applied + * @param scale_factor Value of the scaling factor + * @param id_of_wi_in_fft Id of the workitem in the FFT + * @param factor_sg Number of workitems participating for one transform + * @param factor_wi Number of complex elements per workitem for each transform + * @param storage storage scheme of complex values in local
memory, SPLIT_COMPLEX or INTERLEAVED_COMPLEX + * @param wi_working Whether or not the workitem participates in the data transfers + * @param loc_view_store_offset Offset to be applied to local memory view when storing the data back to local memory + * after forward fft phase + * @param loc_view_load_offset offset to be applied to local memory view when loading the data back from local memory + * for backward fft phase + * @param local_imag_offset Number of elements in local memory after which the imaginary component of the values is + * stored + * @param global_data global_data_struct associated with kernel launch + */ +template +void sg_bluestein_packed(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddlesView& loc_twiddles, + const T* load_modifier, const T* store_modifier, detail::complex_conjugate conjugate_on_load, + detail::complex_conjugate conjugate_on_store, detail::apply_scale_factor scale_applied, + T scale_factor, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, complex_storage storage, + bool wi_working, Idx loc_view_store_offset, Idx loc_view_load_offset, Idx local_imag_offset, + detail::global_data_struct<1>& global_data) { + global_data.log_message_global(__func__, "computing forward FFT and applying scaling factor for the backward phase"); + sg_cooley_tukey( + priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, + conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, + store_modifier, loc_twiddles, static_cast(1.
/ static_cast(factor_sg * factor_wi)), 0, id_of_wi_in_fft, + factor_sg, factor_wi, global_data); + + if (wi_working) { + global_data.log_message(__func__, "storing result of the forward phase back to local memory"); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + local_private_strided_copy<1, Idx>(loc_view, priv, {{factor_sg}, {loc_view_store_offset}}, factor_wi, + detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); + } else { + local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{factor_sg}, {loc_view_store_offset}}, + {{factor_sg}, {loc_view_store_offset + local_imag_offset}}, factor_wi, + detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); + } + } + + sycl::group_barrier(global_data.sg); + + if (wi_working) { + global_data.log_message(__func__, "loading back the result from local memory for the backward phase"); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + local_private_strided_copy<1, Idx>(loc_view, priv, {{1}, {loc_view_load_offset}}, factor_wi, + detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); + } else { + local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{1}, {loc_view_load_offset}}, + {{1}, {loc_view_load_offset + local_imag_offset}}, factor_wi, + detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); + } + } + global_data.log_message(__func__, "computing backward FFT and applying user provided scale value"); + sg_cooley_tukey(priv, priv_scratch, detail::elementwise_multiply::NOT_APPLIED, + detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, + detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), + load_modifier, loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, + global_data); + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + global_data.log_message(__func__, "Applying complex conjugate on the output"); + detail::conjugate_inplace(priv, factor_wi); + } +} +} // namespace portfft + +#endif \ No 
newline at end of file diff --git a/src/portfft/common/subgroup_ct.hpp b/src/portfft/common/subgroup_ct.hpp new file mode 100644 index 00000000..831f8586 --- /dev/null +++ b/src/portfft/common/subgroup_ct.hpp @@ -0,0 +1,397 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Codeplay's portFFT + * + **************************************************************************/ + +#ifndef PORTFFT_COMMON_SUBGROUP_HPP +#define PORTFFT_COMMON_SUBGROUP_HPP + +#include + +#include "helpers.hpp" +#include "portfft/common/logging.hpp" +#include "portfft/common/transfers.hpp" +#include "portfft/defines.hpp" +#include "portfft/enums.hpp" +#include "twiddle.hpp" +#include "twiddle_calc.hpp" +#include "workitem.hpp" + +namespace portfft { +namespace detail { + +/* +`sg_dft` calculates a DFT by a subgroup on values that are already loaded into private memory of the workitems in the +subgroup. It needs twiddle factors precalculated by `sg_calc_twiddles`. It handles the first factor by cross subgroup +DFT calling `cross_sg_dispatcher` and the second one by workitem implementation - calling `wi_dft`. It does twiddle +multiplication inbetween, but does not transpose. Transposition is supposed to be done when storing the values back to +the local memory. + +The size of the DFT performed by this function is `N * M` - for the arguments `N` and `M`. 
`N` workitems work jointly on +one DFT, so at most `subgroup_size / N` DFTs can be performed by one subgroup at a time. If `N` does not evenly divide +`subgroup_size`, extra workitems perform dummy computations. However, they must also call `sg_dft`, as it uses group +functions. + +On input, each of the `N` workitems hold `M` consecutive complex input values. On output, each of the workitems holds +complex values that are strided with stride `N` and consecutive workitems have consecutive values. + +`cross_sg_dft` calculates DFT across workitems, with each workitem contributing one complex value as input and output of +the computation. If the size of the subgroup is large enough compared to FFT size, a subgroup can calculate multiple +DFTs at once (the same holds true for `cross_sg_cooley_tukey_dft` and `cross_sg_naive_dft`). It calls either +`cross_sg_cooley_tukey_dft` (for composite sizes) or `cross_sg_naive_dft` (for prime sizes). + +`cross_sg_cooley_tukey_dft` calculates DFT of a composite size across workitems. It calls `cross_sg_dft` for each of the +factors and does transposition and twiddle multiplication inbetween. + +`cross_sg_naive_dft` calculates DFT across workitems using naive DFT algorithm. +*/ + +// forward declaration +template +PORTFFT_INLINE void cross_sg_dft(T& real, T& imag, Idx fft_size, Idx stride, sycl::sub_group& sg); + +/** + * Calculates DFT using naive algorithm by using workitems of one subgroup. + * Each workitem holds one input and one output complex value. 
+ * + * @tparam T type of the scalar to work on + * @param[in,out] real real component of the input/output complex value for one + * workitem + * @param[in,out] imag imaginary component of the input/output complex value for + * one workitem + * @param fft_size size of the DFT transform + * @param stride Stride between workitems working on consecutive values of one + * DFT + * @param sg subgroup + */ +template +PORTFFT_INLINE void cross_sg_naive_dft(T& real, T& imag, Idx fft_size, Idx stride, sycl::sub_group& sg) { + if (fft_size == 2 && (stride & (stride - 1)) == 0) { + Idx local_id = static_cast(sg.get_local_linear_id()); + Idx idx_out = (local_id / stride) % 2; + + T multi_re = (idx_out & 1) ? T(-1) : T(1); + T res_real = real * multi_re; + T res_imag = imag * multi_re; + + res_real += sycl::permute_group_by_xor(sg, real, static_cast(stride)); + res_imag += sycl::permute_group_by_xor(sg, imag, static_cast(stride)); + + real = res_real; + imag = res_imag; + } else { + Idx local_id = static_cast(sg.get_local_linear_id()); + Idx idx_out = (local_id / stride) % fft_size; + Idx fft_start = local_id - idx_out * stride; + + T res_real = 0; + T res_imag = 0; + + // IGC doesn't unroll this loop and generates a warning when called from workgroup impl. + PORTFFT_UNROLL + for (Idx idx_in = 0; idx_in < fft_size; idx_in++) { + T multi_re = twiddle::Re[fft_size][idx_in * idx_out % fft_size]; + T multi_im = twiddle::Im[fft_size][idx_in * idx_out % fft_size]; + + Idx source_wi_id = fft_start + idx_in * stride; + + T cur_real = sycl::select_from_group(sg, real, static_cast(source_wi_id)); + T cur_imag = sycl::select_from_group(sg, imag, static_cast(source_wi_id)); + + // multiply cur and multi + T tmp_real; + T tmp_imag; + detail::multiply_complex(cur_real, cur_imag, multi_re, multi_im, tmp_real, tmp_imag); + res_real += tmp_real; + res_imag += tmp_imag; + } + + real = res_real; + imag = res_imag; + } +} + +/** + * Transposes values held by workitems of a subgroup. 
Transposes rectangles of + * size N*M. Each of the rectangles can be strided. + * + * @tparam T type of the scalar to work on + * @param[in,out] real real component of the input/output complex value for one + * workitem + * @param[in,out] imag imaginary component of the input/output complex value for + * one workitem + * @param factor_n inner - contiguous size on input, outer size on output + * @param factor_m outer size on input, inner - contiguous size on output + * @param stride Stride between consecutive values of one rectangle + * @param sg subgroup + */ +template +PORTFFT_INLINE void cross_sg_transpose(T& real, T& imag, Idx factor_n, Idx factor_m, Idx stride, sycl::sub_group& sg) { + Idx local_id = static_cast(sg.get_local_linear_id()); + Idx index_in_outer_dft = (local_id / stride) % (factor_n * factor_m); + Idx k = index_in_outer_dft % factor_n; // index in the contiguous factor/fft + Idx n = index_in_outer_dft / factor_n; // index of the contiguous factor/fft + Idx fft_start = local_id - index_in_outer_dft * stride; + Idx source_wi_id = fft_start + stride * (k * factor_m + n); + real = sycl::select_from_group(sg, real, static_cast(source_wi_id)); + imag = sycl::select_from_group(sg, imag, static_cast(source_wi_id)); +} + +/** + * Calculates DFT using Cooley-Tukey FFT algorithm. Size of the problem is N*M. + * Each workitem holds one input and one output complex value. 
+ * + * @tparam SubgroupSize Size of subgroup in kernel + * @tparam RecursionLevel level of recursion in SG dft + * @tparam T type of the scalar to work on + * @param[in,out] real real component of the input/output complex value for one + * workitem + * @param[in,out] imag imaginary component of the input/output complex value for + * one workitem + * @param factor_n the first factor of the problem size + * @param factor_m the second factor of the problem size + * @param stride Stride between workitems working on consecutive values of one + * DFT + * @param sg subgroup + */ +template +PORTFFT_INLINE void cross_sg_cooley_tukey_dft(T& real, T& imag, Idx factor_n, Idx factor_m, Idx stride, + sycl::sub_group& sg) { + Idx local_id = static_cast(sg.get_local_linear_id()); + Idx index_in_outer_dft = (local_id / stride) % (factor_n * factor_m); + Idx k = index_in_outer_dft % factor_n; // index in the contiguous factor/fft + Idx n = index_in_outer_dft / factor_n; // index of the contiguous factor/fft + + // factor N + cross_sg_dft(real, imag, factor_n, factor_m * stride, sg); + // transpose + cross_sg_transpose(real, imag, factor_n, factor_m, stride, sg); + T multi_re = twiddle::Re[factor_n * factor_m][k * n]; + T multi_im = twiddle::Im[factor_n * factor_m][k * n]; + detail::multiply_complex(real, imag, multi_re, multi_im, real, imag); + // factor M + cross_sg_dft(real, imag, factor_m, factor_n * stride, sg); +} + +/** + * Calculates DFT using FFT algorithm. Each workitem holds one input and one + * output complex value. 
+ * + * @tparam SubgroupSize Size of subgroup in kernel + * @tparam RecursionLevel level of recursion in SG dft + * @tparam T type of the scalar to work on + * @param[in,out] real real component of the input/output complex value for one + * workitem + * @param[in,out] imag imaginary component of the input/output complex value for + * one workitem + * @param fft_size Size of the DFT + * @param stride Stride between workitems working on consecutive values of one + * DFT + * @param sg subgroup + */ +template +PORTFFT_INLINE void cross_sg_dft(T& real, T& imag, Idx fft_size, Idx stride, sycl::sub_group& sg) { + constexpr Idx MaxRecursionLevel = detail::int_log2(SubgroupSize); + if constexpr (RecursionLevel < MaxRecursionLevel) { + const Idx f0 = detail::factorize(fft_size); + if (f0 >= 2 && fft_size / f0 >= 2) { + cross_sg_cooley_tukey_dft(real, imag, fft_size / f0, f0, stride, sg); + } else { + cross_sg_naive_dft(real, imag, fft_size, stride, sg); + } + } +} + +/** + * Factorizes a number into two factors, so that one of them is the largest factor below + * or equal to the subgroup size. + * @tparam T type of the number to factorize + * @param N the number to factorize + * @param sg_size subgroup size + * @return the factor below or equal to subgroup size + */ +template +PORTFFT_INLINE constexpr T factorize_sg(T N, Idx sg_size) { + if constexpr (PORTFFT_SLOW_SG_SHUFFLES) { + return 1; + } else { + for (T i = static_cast(sg_size); i > 1; i--) { + if (N % i == 0) { + return i; + } + } + return 1; + } +} + +/** + * Checks whether a problem can be solved with sub-group implementation + * without reg spilling.
+ * @tparam Scalar type of the real scalar used for the computation + * @param N Size of the problem, in complex values + * @param sg_size Size of the sub-group + * @return true if the problem fits in the registers + */ +template +constexpr bool fits_in_sg(IdxGlobal N, Idx sg_size) { + IdxGlobal factor_sg = factorize_sg(N, sg_size); + IdxGlobal factor_wi = N / factor_sg; + return fits_in_wi(factor_wi); +} + +}; // namespace detail + +/** + * Calculates FFT of size N*M using workitems in a subgroup. Works in place. The + * end result needs to be transposed when storing it to the local memory! + * + * @tparam SubgroupSize Size of subgroup in kernel + * @tparam T type of the scalar used for computations + * @param inout pointer to private memory where the input/output data is + * @param sg subgroup + * @param factor_wi number of elements per workitem + * @param factor_sg number of workitems in a subgroup that work on one FFT + * @param sg_twiddles twiddle factors to use - calculated by sg_calc_twiddles in + * commit + * @param private_scratch Scratch memory for wi implementation + */ +template +PORTFFT_INLINE void sg_dft(T* inout, sycl::sub_group& sg, Idx factor_wi, Idx factor_sg, const T* sg_twiddles, + T* private_scratch) { + Idx idx_of_wi_in_fft = static_cast(sg.get_local_linear_id()) % factor_sg; + // IGC doesn't unroll this loop and generates a warning when called from workgroup impl. 
+ PORTFFT_UNROLL + for (Idx idx_of_element_in_wi = 0; idx_of_element_in_wi < factor_wi; idx_of_element_in_wi++) { + T& real = inout[2 * idx_of_element_in_wi]; + T& imag = inout[2 * idx_of_element_in_wi + 1]; + + if (factor_sg > 1) { + detail::cross_sg_dft(real, imag, factor_sg, 1, sg); + if (idx_of_element_in_wi > 0) { + T twiddle_real = sg_twiddles[idx_of_element_in_wi * factor_sg + idx_of_wi_in_fft]; + T twiddle_imag = sg_twiddles[(idx_of_element_in_wi + factor_wi) * factor_sg + idx_of_wi_in_fft]; + detail::multiply_complex(real, imag, twiddle_real, twiddle_imag, real, imag); + } + } + }; + wi_dft<0>(inout, inout, factor_wi, 1, 1, private_scratch); +} + +/** + * Calculates a twiddle factor for subgroup implementation. + * + * @tparam T type of the scalar used for computations + * @param factor_sg number of workitems in a subgroup that work on one FFT + * @param factor_wi number of elements per workitem + * @param n index of the twiddle to calculate in the direction of factor_sg + * @param k index of the twiddle to calculate in the direction of factor_wi + * @param sg_twiddles destination into which to store the twiddles + */ +template +void sg_calc_twiddles(Idx factor_sg, Idx factor_wi, Idx n, Idx k, T* sg_twiddles) { + std::complex twiddle = detail::calculate_twiddle(n * k, factor_sg * factor_wi); + sg_twiddles[k * factor_sg + n] = twiddle.real(); + sg_twiddles[(k + factor_wi) * factor_sg + n] = twiddle.imag(); +} + +/** + * Performs the following sequence of operations as required for subgroup level cooley tukey implementation - + * Taking conjugate of the input + * Applying the load modifiers + * call to sg_dft + * Applying the store modifiers + * Taking conjugate of the output + * + * @tparam SubgroupSize Subgroup Size + * @tparam T Scalar Type + * @tparam LocView View of the local memory + * @param priv private memory array on which the computations will be done + * @param private_scratch Scratch private memory to be passed to the wi_dft as a part of sg_dft 
+ * @param apply_load_modifier Whether or not modifiers need to be applied before the fft computation + * @param apply_store_modifier Whether or not the modifiers need to be applied after the fft computation + * @param conjugate_on_load Whether or not conjugation of the input is to be done before the fft computation + * @param conjugate_on_store Whether or not conjugation of the output is to be done after the fft computation + * @param scale_factor_applied Whether or not scale factor is applied + * @param load_modifier_data Global memory pointer containing the load modifier data, assumed aligned to at least + * sycl::vec + * @param store_modifier_data Global memory pointer containing the store modifier data, assumed aligned to at least + * sycl::vec + * @param twiddles_loc_view View of the local memory containing the twiddles + * @param scale_factor Value of the scale factor + * @param modifier_start_offset offset to be applied to the load/store modifier pointers + * @param id_of_wi_in_fft workitem id within the fft + * @param factor_sg Number of workitems participating for one transform + * @param factor_wi Number of complex elements per workitem for each transform + * @param global_data global_data_struct associated with the kernel launch + */ +template +PORTFFT_INLINE void sg_cooley_tukey(T* priv, T* private_scratch, detail::elementwise_multiply apply_load_modifier, + detail::elementwise_multiply apply_store_modifier, + detail::complex_conjugate conjugate_on_load, + detail::complex_conjugate conjugate_on_store, + detail::apply_scale_factor scale_factor_applied, const T* load_modifier_data, + const T* store_modifier_data, LocView& twiddles_loc_view, T scale_factor, + IdxGlobal modifier_start_offset, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, + detail::global_data_struct<1>& global_data) { + using vec2_t = sycl::vec; + vec2_t modifier_vec; + if (conjugate_on_load == detail::complex_conjugate::APPLIED) { + global_data.log_message(__func__, "Applying complex
conjugate before computation of the FFT"); + detail::conjugate_inplace(priv, factor_wi); + } + if (apply_load_modifier == detail::elementwise_multiply::APPLIED) { + global_data.log_message(__func__, "Applying load modifiers"); + PORTFFT_UNROLL + for (Idx j = 0; j < factor_wi; j++) { + modifier_vec = *reinterpret_cast( + &load_modifier_data[modifier_start_offset + 2 * factor_wi * id_of_wi_in_fft + 2 * j]); + detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], + priv[2 * j + 1]); + } + } + sg_dft(priv, global_data.sg, factor_wi, factor_sg, twiddles_loc_view, private_scratch); + + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + global_data.log_message(__func__, "Applying complex conjugate after computation of the FFT"); + detail::conjugate_inplace(priv, factor_wi); + } + + if (apply_store_modifier == detail::elementwise_multiply::APPLIED) { + global_data.log_message(__func__, "Applying store modifiers"); + PORTFFT_UNROLL + for (Idx j = 0; j < factor_wi; j++) { + modifier_vec = *reinterpret_cast( + &store_modifier_data[modifier_start_offset + 2 * j * factor_sg + 2 * id_of_wi_in_fft]); + detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], + priv[2 * j + 1]); + } + } + + if (scale_factor_applied == detail::apply_scale_factor::APPLIED) { + global_data.log_message(__func__, "Applying scale factor"); + PORTFFT_UNROLL + for (Idx j = 0; j < factor_wi; j++) { + priv[2 * j] *= scale_factor; + priv[2 * j + 1] *= scale_factor; + } + } +} + +}; // namespace portfft + +#endif diff --git a/src/portfft/common/transfers.hpp b/src/portfft/common/transfers.hpp index 0f34e66b..0efdeda2 100644 --- a/src/portfft/common/transfers.hpp +++ b/src/portfft/common/transfers.hpp @@ -675,12 +675,11 @@ PORTFFT_INLINE void local_global_strided_copy(T* global_ptr, T* global_imag_ptr, */ template PORTFFT_INLINE void local_private_strided_copy(PtrView& ptr_view, T* priv, - so_array 
ptr_view_strides_offsets, + stride_offset_struct ptr_view_strides_offsets, Idx num_elements_to_copy, detail::transfer_direction direction, detail::global_data_struct<1> global_data) { global_data.log_message(__func__, "storage scheme: INTERLEAVED_COMPLEX"); - detail::strided_view ptr_strided_view{ptr_view, std::get<0>(ptr_view_strides_offsets), - std::get<1>(ptr_view_strides_offsets)}; + detail::strided_view ptr_strided_view{ptr_view, ptr_view_strides_offsets.strides, ptr_view_strides_offsets.offsets}; if (direction == detail::transfer_direction::LOCAL_TO_PRIVATE) { copy_wi<2>(global_data, ptr_strided_view, priv, num_elements_to_copy); } else if (direction == detail::transfer_direction::PRIVATE_TO_LOCAL || @@ -711,15 +710,15 @@ PORTFFT_INLINE void local_private_strided_copy(PtrView& ptr_view, T* priv, */ template PORTFFT_INLINE void local_private_strided_copy(PtrView& ptr_view, PtrView& ptr_imag_view, T* priv, - so_array ptr_view_strides_offsets, - so_array ptr_imag_view_strides_offsets, + stride_offset_struct ptr_view_strides_offsets, + stride_offset_struct ptr_imag_view_strides_offsets, Idx num_elements_to_copy, detail::transfer_direction direction, detail::global_data_struct<1> global_data) { global_data.log_message(__func__, "storage scheme: INTERLEAVED_COMPLEX"); - detail::strided_view ptr_strided_real_view{ptr_view, std::get<0>(ptr_view_strides_offsets), - std::get<1>(ptr_view_strides_offsets)}; - detail::strided_view ptr_strided_imag_view{ptr_imag_view, std::get<0>(ptr_imag_view_strides_offsets), - std::get<1>(ptr_imag_view_strides_offsets)}; + detail::strided_view ptr_strided_real_view{ptr_view, ptr_view_strides_offsets.strides, + ptr_view_strides_offsets.offsets}; + detail::strided_view ptr_strided_imag_view{ptr_imag_view, ptr_imag_view_strides_offsets.strides, + ptr_imag_view_strides_offsets.offsets}; detail::strided_view priv_strided_real_view{priv, 2}; detail::strided_view priv_strided_imag_view{priv, 2, 1}; if (direction == 
detail::transfer_direction::LOCAL_TO_PRIVATE) { diff --git a/src/portfft/common/workgroup.hpp b/src/portfft/common/workgroup.hpp index 25038527..ad84cf47 100644 --- a/src/portfft/common/workgroup.hpp +++ b/src/portfft/common/workgroup.hpp @@ -26,7 +26,7 @@ #include "portfft/defines.hpp" #include "portfft/enums.hpp" #include "portfft/traits.hpp" -#include "subgroup.hpp" +#include "subgroup_ct.hpp" namespace portfft { diff --git a/src/portfft/defines.hpp b/src/portfft/defines.hpp index 75afb240..15f087b1 100644 --- a/src/portfft/defines.hpp +++ b/src/portfft/defines.hpp @@ -54,8 +54,14 @@ using IdxGlobal = std::int64_t; * @tparam Type Type of elements * @tparam N Number of elements in each of the two arrays */ -template -using so_array = std::array, 2>; +template +struct stride_offset_struct { + std::array strides; + std::array offsets; + __attribute__((always_inline)) inline constexpr stride_offset_struct(const std::array strides, + const std::array offsets) + : strides(strides), offsets(offsets) {} +}; } // namespace portfft diff --git a/src/portfft/descriptor_validation.hpp b/src/portfft/descriptor_validation.hpp index 9c4e421a..b192b3a4 100644 --- a/src/portfft/descriptor_validation.hpp +++ b/src/portfft/descriptor_validation.hpp @@ -24,7 +24,7 @@ #include #include "common/exceptions.hpp" -#include "common/subgroup.hpp" +#include "common/subgroup_ct.hpp" #include "enums.hpp" #include "utils.hpp" diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index b4ba5a76..eebd3c27 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -26,7 +26,7 @@ #include #include "portfft/common/global.hpp" -#include "portfft/common/subgroup.hpp" +#include "portfft/common/subgroup_ct.hpp" #include "portfft/defines.hpp" #include "portfft/enums.hpp" #include "portfft/specialization_constant.hpp" diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp 
b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 2a6d81de..44e2c040 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -25,7 +25,8 @@ #include "portfft/common/helpers.hpp" #include "portfft/common/logging.hpp" #include "portfft/common/memory_views.hpp" -#include "portfft/common/subgroup.hpp" +#include "portfft/common/subgroup_bluestein.hpp" +#include "portfft/common/subgroup_ct.hpp" #include "portfft/common/transfers.hpp" #include "portfft/defines.hpp" #include "portfft/descriptor.hpp" @@ -223,13 +224,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag const Idx fft_element = 2 * id_of_wi_in_fft * factor_wi; local_private_strided_copy<1, Idx>( loc_view, priv, - {{{max_num_batches_local_mem}, {fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local}}}, + {{max_num_batches_local_mem}, {fft_element * max_num_batches_local_mem + 2 * fft_idx_in_local}}, factor_wi, detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } else { local_private_strided_copy<2, Idx>( loc_view, loc_view, priv, - {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}}, - {{{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}}, + {{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local}}, + {{1, max_num_batches_local_mem}, {id_of_wi_in_fft * factor_wi, fft_idx_in_local + local_imag_offset}}, factor_wi, detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); @@ -237,10 +238,10 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag IdxGlobal modifier_offset = static_cast(n_reals_per_fft) * (i + static_cast(fft_idx_in_local)); if (algorithm == detail::fft_algorithm::COOLEY_TUKEY) { - sg_dft_compute(priv, wi_private_scratch, multiply_on_load, multiply_on_store, 
conjugate_on_load, - conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, - loc_twiddles, scaling_factor, modifier_offset, id_of_wi_in_fft, factor_sg, - factor_wi, global_data); + sg_cooley_tukey(priv, wi_private_scratch, multiply_on_load, multiply_on_store, + conjugate_on_load, conjugate_on_store, apply_scale_factor, load_modifier_data, + store_modifier_data, loc_twiddles, scaling_factor, modifier_offset, + id_of_wi_in_fft, factor_sg, factor_wi, global_data); } else { sg_bluestein_batch_interleaved( priv, wi_private_scratch, loc_view, load_modifier_data, store_modifier_data, loc_twiddles, @@ -261,11 +262,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag static_cast(2 * fft_size) + static_cast(2 * id_of_wi_in_fft); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - so_array output_stride_offset{{{static_cast(factor_sg)}, {output_offset}}}; + stride_offset_struct output_stride_offset{{static_cast(factor_sg)}, + {output_offset}}; local_private_strided_copy<1, IdxGlobal>(output, priv, output_stride_offset, factor_wi, detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); } else { - so_array output_stride_offset{{{static_cast(factor_sg)}, {output_offset / 2}}}; + stride_offset_struct output_stride_offset{{static_cast(factor_sg)}, + {output_offset / 2}}; local_private_strided_copy<1, IdxGlobal>(output, output_imag, priv, output_stride_offset, output_stride_offset, factor_wi, detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); @@ -280,14 +283,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag // Store back to local memory only if (storage == complex_storage::INTERLEAVED_COMPLEX) { local_private_strided_copy<2, Idx>( - loc_view, priv, - {{{factor_sg, max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}}, factor_wi, - detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); + loc_view, priv, {{factor_sg, 
max_num_batches_local_mem}, {2 * id_of_wi_in_fft, 2 * fft_idx_in_local}}, + factor_wi, detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } else { local_private_strided_copy<2, Idx>( loc_view, loc_view, priv, - {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}}, - {{{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}}, + {{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local}}, + {{factor_sg, max_num_batches_local_mem}, {id_of_wi_in_fft, fft_idx_in_local + local_imag_offset}}, factor_wi, detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } } @@ -390,24 +392,24 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (working) { global_data.log_message_global(__func__, "loading non-transposed data from local to private memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { - local_private_strided_copy<1, Idx>( - loc_view, priv, {{{1}, {subgroup_id * n_reals_per_sg + subgroup_local_id * n_reals_per_wi}}}, factor_wi, - detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); + local_private_strided_copy<1, Idx>(loc_view, priv, + {{1}, {subgroup_id * n_reals_per_sg + subgroup_local_id * n_reals_per_wi}}, + factor_wi, detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } else { local_private_strided_copy<1, Idx>( - loc_view, loc_view, priv, {{{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi}}}, - {{{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi + local_imag_offset}}}, factor_wi, + loc_view, loc_view, priv, {{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi}}, + {{1}, {subgroup_id * n_cplx_per_sg + subgroup_local_id * factor_wi + local_imag_offset}}, factor_wi, detail::transfer_direction::LOCAL_TO_PRIVATE, global_data); } global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); } sycl::group_barrier(global_data.sg); if (algorithm == 
detail::fft_algorithm::COOLEY_TUKEY) { - sg_dft_compute(priv, wi_private_scratch, multiply_on_load, multiply_on_store, conjugate_on_load, - conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, - loc_twiddles, scaling_factor, - static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), - id_of_wi_in_fft, factor_sg, factor_wi, global_data); + sg_cooley_tukey(priv, wi_private_scratch, multiply_on_load, multiply_on_store, conjugate_on_load, + conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, + loc_twiddles, scaling_factor, + static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), + id_of_wi_in_fft, factor_sg, factor_wi, global_data); } else { Idx loc_offset_store_view; Idx loc_offset_load_view; @@ -438,17 +440,17 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag IdxGlobal output_offset = i * static_cast(n_reals_per_sg) + static_cast(id_of_fft_in_sg * n_reals_per_fft) + static_cast(id_of_wi_in_fft * 2); - local_private_strided_copy<1, IdxGlobal>( - output, priv, {{{static_cast(factor_sg)}, {output_offset}}}, factor_wi, - detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); + local_private_strided_copy<1, IdxGlobal>(output, priv, + {{static_cast(factor_sg)}, {output_offset}}, factor_wi, + detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); } else { IdxGlobal output_offset = i * static_cast(n_cplx_per_sg) + static_cast(id_of_fft_in_sg * fft_size) + static_cast(id_of_wi_in_fft); - local_private_strided_copy<1, IdxGlobal>( - output, output_imag, priv, {{{static_cast(factor_sg)}, {output_offset}}}, - {{{static_cast(factor_sg)}, {output_offset}}}, factor_wi, - detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); + local_private_strided_copy<1, IdxGlobal>(output, output_imag, priv, + {{static_cast(factor_sg)}, {output_offset}}, + {{static_cast(factor_sg)}, {output_offset}}, factor_wi, + detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); } } } 
else if (is_output_batch_interleaved && algorithm == detail::fft_algorithm::COOLEY_TUKEY) { @@ -456,13 +458,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "Storing data from private to Global with batch interleaved layout"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { local_private_strided_copy<2, IdxGlobal>(output, priv, - {{{static_cast(factor_sg), n_transforms}, - {static_cast(2 * id_of_wi_in_fft), 2 * i}}}, + {{static_cast(factor_sg), n_transforms}, + {static_cast(2 * id_of_wi_in_fft), 2 * i}}, factor_wi, detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); } else { - so_array global_stride_offset{ - {{static_cast(factor_sg), n_transforms}, {static_cast(id_of_wi_in_fft), i}}}; + stride_offset_struct global_stride_offset{{static_cast(factor_sg), n_transforms}, + {static_cast(id_of_wi_in_fft), i}}; local_private_strided_copy<2, IdxGlobal>(output, output_imag, priv, global_stride_offset, global_stride_offset, factor_wi, detail::transfer_direction::PRIVATE_TO_GLOBAL, global_data); @@ -475,12 +477,12 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (storage == complex_storage::INTERLEAVED_COMPLEX) { Idx loc_view_offset = subgroup_id * n_reals_per_sg + id_of_fft_in_sg * n_reals_per_fft + 2 * id_of_wi_in_fft; - local_private_strided_copy<1, Idx>(loc_view, priv, {{{factor_sg}, {loc_view_offset}}}, factor_wi, + local_private_strided_copy<1, Idx>(loc_view, priv, {{factor_sg}, {loc_view_offset}}, factor_wi, detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } else { Idx loc_view_offset = subgroup_id * n_cplx_per_sg + id_of_fft_in_sg * fft_size + id_of_wi_in_fft; - local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{{factor_sg}, {loc_view_offset}}}, - {{{factor_sg}, {loc_view_offset + local_imag_offset}}}, factor_wi, + local_private_strided_copy<1, Idx>(loc_view, loc_view, priv, {{factor_sg}, {loc_view_offset}}, + 
{{factor_sg}, {loc_view_offset + local_imag_offset}}, factor_wi, detail::transfer_direction::PRIVATE_TO_LOCAL, global_data); } } @@ -535,7 +537,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag loc_view_offset, loc_view_imag_offset, n_ffts_worked_on_by_sg, i, n_transforms, global_data); } } else { - // TODO: Blustein Strided Copy + // TODO: Bluestein Strided Copy } } sycl::group_barrier(global_data.sg); @@ -555,8 +557,9 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn Idx factor_wi = kernel_data.factors[0]; Idx factor_sg = kernel_data.factors[1]; std::size_t twiddles_alloc_size = [&]() { - if (dimension_data.is_prime) { + if (dimension_data.algorithm == detail::fft_algorithm::BLUESTEIN) { // sg twiddles + load_modifiers + store_modifiers + // NOLINTNEXTLINE return 6 * dimension_data.length; } return 2 * dimension_data.length; @@ -577,7 +580,7 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn host_twiddles[static_cast((j + factor_wi) * factor_sg + i)] = twiddle.imag(); } } - if (dimension_data.is_prime) { + if (dimension_data.algorithm == detail::fft_algorithm::BLUESTEIN) { detail::populate_bluestein_input_modifiers(host_twiddles.data() + 2 * factor_sg * factor_wi, dimension_data.committed_length, dimension_data.length); detail::populate_fft_chirp_signal(host_twiddles.data() + 4 * factor_sg * factor_wi, diff --git a/src/portfft/utils.hpp b/src/portfft/utils.hpp index 6b6c54b9..a0855882 100644 --- a/src/portfft/utils.hpp +++ b/src/portfft/utils.hpp @@ -260,7 +260,7 @@ detail::layout get_layout(const Descriptor& desc, direction dir) { * @return The padded input size for which the FFT transform will run */ inline IdxGlobal get_bluestein_padded_size(IdxGlobal input_size) { - return static_cast(std::pow(2, ceil(log(static_cast(2 * input_size)) / log(2.0)))); + return static_cast(std::pow(2, ceil(std::log2(2 * input_size)))); } } // namespace detail diff --git a/test/unit_test/fft_test_utils.hpp 
b/test/unit_test/fft_test_utils.hpp index 3503412b..eac9574c 100644 --- a/test/unit_test/fft_test_utils.hpp +++ b/test/unit_test/fft_test_utils.hpp @@ -273,7 +273,6 @@ std::enable_if_t check_fft( const std::vector& host_reference_output, const std::vector& host_input_imag, std::vector& host_output_imag, const std::vector& host_reference_output_imag, double tolerance) { - // std::cout << "I AM IN CHECK FFT USM " << std::endl; auto committed_descriptor = desc.commit(queue); const bool is_oop = desc.placement == placement::OUT_OF_PLACE; From cfd2ab8e9a3a662db941574d909e62117709c847 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Mon, 18 Mar 2024 14:51:51 +0000 Subject: [PATCH 20/22] updated doxygens --- src/portfft/common/transfers.hpp | 4 ++-- src/portfft/defines.hpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/portfft/common/transfers.hpp b/src/portfft/common/transfers.hpp index 0efdeda2..c5cc6bd6 100644 --- a/src/portfft/common/transfers.hpp +++ b/src/portfft/common/transfers.hpp @@ -699,9 +699,9 @@ PORTFFT_INLINE void local_private_strided_copy(PtrView& ptr_view, T* priv, * @param ptr_view View of the local / global memory containing the real component of the data * @param ptr_imag_view View of the local / global memory containing the imaginary component of the data * @param priv Pointer to the private memory array - * @param ptr_view_strides_offsets An array of 2 arrays containing PtrViewNDim elements of IdxType, containing strides + * @param ptr_view_strides_offsets Struct containing strides * and offsets for the strided view to be constructed for the local / global memory containing the real part of the data - * @param ptr_imag_view_strides_offsets An array of 2 arrays containing PtrViewNDim elements of IdxType, containing + * @param ptr_imag_view_strides_offsets Struct containing * strides and offsets for the strided view to be constructed for the local / global memory containing the imaginary * part of the data * @param 
num_elements_to_copy Number of elements to copy diff --git a/src/portfft/defines.hpp b/src/portfft/defines.hpp index 15f087b1..5397b248 100644 --- a/src/portfft/defines.hpp +++ b/src/portfft/defines.hpp @@ -50,7 +50,7 @@ using Idx = std::int32_t; using IdxGlobal = std::int64_t; /** - * An array of 2 arrays containing N elements of Type, containing strides (s) and offset (o) for a view + * Struct containing the strides and offsets for a view * @tparam Type Type of elements * @tparam N Number of elements in each of the two arrays */ From d0e705ded8a70fe098693ba541d7e9b954197038 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Mon, 18 Mar 2024 15:43:55 +0000 Subject: [PATCH 21/22] addressed missed comments from first round of review --- src/portfft/common/subgroup_bluestein.hpp | 12 ++++++++---- src/portfft/common/subgroup_ct.hpp | 1 + 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/portfft/common/subgroup_bluestein.hpp b/src/portfft/common/subgroup_bluestein.hpp index 1c8be538..3466410b 100644 --- a/src/portfft/common/subgroup_bluestein.hpp +++ b/src/portfft/common/subgroup_bluestein.hpp @@ -31,8 +31,10 @@ namespace portfft { /** - * Function to copy data between local and global memory as required by the subgroup level Bluestein algorithm, - * when the data in both local and global memory is in packed format,when the storage scheme is INTERLEAVED_COMPLEX + * Function to copy data between local and global memory as required by the subgroup level bluestein algorithm + * when the data in both local and global memory is in packed format,when the storage scheme is INTERLEAVED_COMPLEX. + * The number of complex elements per transform in the global memory will be equal to the committed_length, and the + * number of complex elements per transform in local memory will be equal to the padded length. 
* * @tparam SubgroupSize Subgroup size * @tparam Direction Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or @@ -57,7 +59,7 @@ PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( TIn global_ptr, LocView& loc_view, Idx committed_size, Idx fft_size, IdxGlobal global_ptr_offset, Idx loc_offset, Idx n_ffts_in_sg, IdxGlobal transform_id, IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg && ((i + transform_id) < n_transforms); i++) { + for (Idx i = 0; i < n_ffts_in_sg && i + transform_id < n_transforms; i++) { local_global_packed_copy( global_ptr, loc_view, global_ptr_offset + static_cast(2 * i * committed_size), 2 * i * fft_size + loc_offset, 2 * committed_size, global_data); @@ -67,6 +69,8 @@ PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( /** * Function to copy data between local and global memory as required by the subgroup level Bluestein algorithm, * when the data in both local and global memory is in packed format,when the storage scheme is SPLIT_COMPLEX + * The number of complex elements per transform in the global memory will be equal to the committed_length, and the + * number of complex elements per transform in local memory will be equal to the padded length. 
* * @tparam SubgroupSize Subgroup size * @tparam Direction Direction Direction of the copy, expected to be either transfer_direction::LOCAL_TO_GLOBAL or @@ -95,7 +99,7 @@ PORTFFT_INLINE void subgroup_impl_bluestein_local_global_packed_copy( IdxGlobal global_ptr_offset, Idx loc_offset, Idx local_imag_offset, Idx n_ffts_in_sg, IdxGlobal transform_id, IdxGlobal n_transforms, detail::global_data_struct<1>& global_data) { PORTFFT_UNROLL - for (Idx i = 0; i < n_ffts_in_sg && (i + transform_id < n_transforms); i++) { + for (Idx i = 0; i < n_ffts_in_sg && i + transform_id < n_transforms; i++) { local_global_packed_copy( global_ptr, global_imag_ptr, loc_view, static_cast(i * committed_size) + global_ptr_offset, i * fft_size + loc_offset, local_imag_offset, committed_size, global_data); diff --git a/src/portfft/common/subgroup_ct.hpp b/src/portfft/common/subgroup_ct.hpp index 831f8586..291b1f54 100644 --- a/src/portfft/common/subgroup_ct.hpp +++ b/src/portfft/common/subgroup_ct.hpp @@ -316,6 +316,7 @@ void sg_calc_twiddles(Idx factor_sg, Idx factor_wi, Idx n, Idx k, T* sg_twiddles * call to sg_dft * Applying the store modifiers * Taking conjugate of the output + * Applying the scaling factor * * @tparam SubgroupSize Subgroup Size * @tparam T Scalar Type From 823b84fb8140045bbf61a30f4b26d1b78b1f65e4 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 19 Mar 2024 16:34:20 +0000 Subject: [PATCH 22/22] prevent OOB read from global memory --- src/portfft/common/subgroup_bluestein.hpp | 8 ++--- src/portfft/common/subgroup_ct.hpp | 35 +++++++++++-------- .../dispatcher/subgroup_dispatcher.hpp | 4 +-- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/portfft/common/subgroup_bluestein.hpp b/src/portfft/common/subgroup_bluestein.hpp index 3466410b..9ea3eeec 100644 --- a/src/portfft/common/subgroup_bluestein.hpp +++ b/src/portfft/common/subgroup_bluestein.hpp @@ -149,7 +149,7 @@ PORTFFT_INLINE void sg_bluestein_batch_interleaved( priv, priv_scratch, 
detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, store_modifier, twiddles_loc, static_cast(1. / (static_cast(factor_sg * factor_wi))), 0, id_of_wi_in_fft, - factor_sg, factor_wi, global_data); + factor_sg, factor_wi, wi_working, global_data); // TODO: Currently local memory is being used to load the data back in natural order for the backward phase, as the // result of sg_dft is transposed. However, the ideal way to this is using shuffles. Implement a batched matrix @@ -190,7 +190,7 @@ PORTFFT_INLINE void sg_bluestein_batch_interleaved( detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), load_modifier, twiddles_loc, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, - global_data); + wi_working, global_data); if (conjugate_on_store == detail::complex_conjugate::APPLIED) { global_data.log_message(__func__, "Applying complex conjugate on the output"); @@ -244,7 +244,7 @@ void sg_bluestein_packed(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddle priv, priv_scratch, detail::elementwise_multiply::APPLIED, detail::elementwise_multiply::APPLIED, conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, detail::apply_scale_factor::APPLIED, load_modifier, store_modifier, loc_twiddles, static_cast(1. 
/ static_cast(factor_sg * factor_wi)), 0, id_of_wi_in_fft, - factor_sg, factor_wi, global_data); + factor_sg, factor_wi, wi_working, global_data); if (wi_working) { global_data.log_message(__func__, "storing result of the forward phase back to local memory"); @@ -276,7 +276,7 @@ void sg_bluestein_packed(T* priv, T* priv_scratch, LocView& loc_view, LocTwiddle detail::elementwise_multiply::APPLIED, detail::complex_conjugate::APPLIED, detail::complex_conjugate::APPLIED, scale_applied, static_cast(nullptr), load_modifier, loc_twiddles, scale_factor, 0, id_of_wi_in_fft, factor_sg, factor_wi, - global_data); + wi_working, global_data); if (conjugate_on_store == detail::complex_conjugate::APPLIED) { global_data.log_message(__func__, "Applying complex conjugate on the output"); detail::conjugate_inplace(priv, factor_wi); diff --git a/src/portfft/common/subgroup_ct.hpp b/src/portfft/common/subgroup_ct.hpp index 291b1f54..52855b04 100644 --- a/src/portfft/common/subgroup_ct.hpp +++ b/src/portfft/common/subgroup_ct.hpp @@ -338,6 +338,7 @@ void sg_calc_twiddles(Idx factor_sg, Idx factor_wi, Idx n, Idx k, T* sg_twiddles * @param id_of_wi_in_fft workitem id withing the fft * @param factor_sg Number of workitems participating for one transform * @param factor_wi Number of complex elements per workitem for each transform + * @param wi_working Whether or not the workitem participates in the data transfers * @param global_data global_data_struct associated with the kernel launch */ template @@ -348,7 +349,7 @@ PORTFFT_INLINE void sg_cooley_tukey(T* priv, T* private_scratch, detail::element detail::apply_scale_factor scale_factor_applied, const T* load_modifier_data, const T* store_modifier_data, LocView& twiddles_loc_view, T scale_factor, IdxGlobal modifier_start_offset, Idx id_of_wi_in_fft, Idx factor_sg, Idx factor_wi, - detail::global_data_struct<1>& global_data) { + bool wi_working, detail::global_data_struct<1>& global_data) { using vec2_t = sycl::vec; vec2_t modifier_vec; if 
(conjugate_on_load == detail::complex_conjugate::APPLIED) { @@ -356,13 +357,15 @@ PORTFFT_INLINE void sg_cooley_tukey(T* priv, T* private_scratch, detail::element detail::conjugate_inplace(priv, factor_wi); } if (apply_load_modifier == detail::elementwise_multiply::APPLIED) { - global_data.log_message(__func__, "Applying load modifiers"); - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - modifier_vec = *reinterpret_cast( - &load_modifier_data[modifier_start_offset + 2 * factor_wi * id_of_wi_in_fft + 2 * j]); - detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], - priv[2 * j + 1]); + if (wi_working) { + global_data.log_message(__func__, "Applying load modifiers"); + PORTFFT_UNROLL + for (Idx j = 0; j < factor_wi; j++) { + modifier_vec = *reinterpret_cast( + &load_modifier_data[modifier_start_offset + 2 * factor_wi * id_of_wi_in_fft + 2 * j]); + detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], + priv[2 * j + 1]); + } } } sg_dft(priv, global_data.sg, factor_wi, factor_sg, twiddles_loc_view, private_scratch); @@ -373,13 +376,15 @@ PORTFFT_INLINE void sg_cooley_tukey(T* priv, T* private_scratch, detail::element } if (apply_store_modifier == detail::elementwise_multiply::APPLIED) { - global_data.log_message(__func__, "Applying store modifiers"); - PORTFFT_UNROLL - for (Idx j = 0; j < factor_wi; j++) { - modifier_vec = *reinterpret_cast( - &store_modifier_data[modifier_start_offset + 2 * j * factor_sg + 2 * id_of_wi_in_fft]); - detail::multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], - priv[2 * j + 1]); + if (wi_working) { + global_data.log_message(__func__, "Applying store modifiers"); + PORTFFT_UNROLL + for (Idx j = 0; j < factor_wi; j++) { + modifier_vec = *reinterpret_cast( + &store_modifier_data[modifier_start_offset + 2 * j * factor_sg + 2 * id_of_wi_in_fft]); + detail::multiply_complex(priv[2 * j], priv[2 * j + 
1], modifier_vec[0], modifier_vec[1], priv[2 * j], + priv[2 * j + 1]); + } } } diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 44e2c040..756380d9 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -241,7 +241,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag sg_cooley_tukey(priv, wi_private_scratch, multiply_on_load, multiply_on_store, conjugate_on_load, conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, loc_twiddles, scaling_factor, modifier_offset, - id_of_wi_in_fft, factor_sg, factor_wi, global_data); + id_of_wi_in_fft, factor_sg, factor_wi, working_inner, global_data); } else { sg_bluestein_batch_interleaved( priv, wi_private_scratch, loc_view, load_modifier_data, store_modifier_data, loc_twiddles, @@ -409,7 +409,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag conjugate_on_store, apply_scale_factor, load_modifier_data, store_modifier_data, loc_twiddles, scaling_factor, static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), - id_of_wi_in_fft, factor_sg, factor_wi, global_data); + id_of_wi_in_fft, factor_sg, factor_wi, working, global_data); } else { Idx loc_offset_store_view; Idx loc_offset_load_view;