Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 56 additions & 115 deletions src/portfft/committed_descriptor_impl.hpp

Large diffs are not rendered by default.

39 changes: 17 additions & 22 deletions src/portfft/common/global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors,
* Device function responsible for calling the corresponding sub-implementation
*
* @tparam Scalar Scalar type
* @tparam LayoutIn Input layout
* @tparam LayoutOut Output layout
* @tparam SubgroupSize Subgroup size
* @param input input pointer
* @param output output pointer
Expand All @@ -134,7 +132,7 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors,
* @param global_data global data
* @param kh kernel handler
*/
template <typename Scalar, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize>
template <typename Scalar, Idx SubgroupSize>
PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Scalar* input_imag, Scalar* output_imag,
const Scalar* implementation_twiddles, const Scalar* store_modifier_data,
Scalar* input_loc, Scalar* twiddles_loc, Scalar* store_modifier_loc,
Expand All @@ -156,16 +154,16 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc
batch_size, global_data, kh, static_cast<const Scalar*>(nullptr),
store_modifier_data, static_cast<Scalar*>(nullptr), store_modifier_loc);
} else if (level == detail::level::SUBGROUP) {
subgroup_impl<SubgroupSize, LayoutIn, LayoutOut, Scalar>(
input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset,
output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data,
kh, static_cast<const Scalar*>(nullptr), store_modifier_data, static_cast<Scalar*>(nullptr),
store_modifier_loc);
subgroup_impl<SubgroupSize, Scalar>(input + outer_batch_offset, output + outer_batch_offset,
input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc,
twiddles_loc, batch_size, implementation_twiddles, global_data, kh,
static_cast<const Scalar*>(nullptr), store_modifier_data,
static_cast<Scalar*>(nullptr), store_modifier_loc);
} else if (level == detail::level::WORKGROUP) {
workgroup_impl<SubgroupSize, LayoutIn, LayoutOut, Scalar>(
input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset,
output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data,
kh, static_cast<Scalar*>(nullptr), store_modifier_data);
workgroup_impl<SubgroupSize, Scalar>(input + outer_batch_offset, output + outer_batch_offset,
input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc,
twiddles_loc, batch_size, implementation_twiddles, global_data, kh,
static_cast<Scalar*>(nullptr), store_modifier_data);
}
sycl::group_barrier(global_data.it.get_group());
}
Expand Down Expand Up @@ -277,8 +275,6 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
* Prepares the launch of fft compute at a particular level
* @tparam Scalar Scalar type
* @tparam Domain Domain of FFT
* @tparam LayoutIn Input layout
* @tparam LayoutOut output layout
* @tparam SubgroupSize subgroup size
* @tparam TIn input type
* @param kd_struct associated kernel data struct with the factor
Expand All @@ -304,8 +300,7 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
* @param queue queue
* @return vector events, one for each batch in l2
*/
template <typename Scalar, domain Domain, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize,
typename TIn>
template <typename Scalar, domain Domain, Idx SubgroupSize, typename TIn>
std::vector<sycl::event> compute_level(
const typename committed_descriptor_impl<Scalar, Domain>::kernel_data_struct& kd_struct, const TIn& input,
Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr,
Expand Down Expand Up @@ -380,7 +375,7 @@ std::vector<sycl::event> compute_level(
#endif
PORTFFT_LOG_TRACE("Launching kernel for global implementation with global_size", global_range, "local_size",
local_range);
cgh.parallel_for<global_kernel<Scalar, Domain, Mem, LayoutIn, LayoutOut, SubgroupSize>>(
cgh.parallel_for<global_kernel<Scalar, Domain, Mem, SubgroupSize>>(
sycl::nd_range<1>(sycl::range<1>(static_cast<std::size_t>(global_range)),
sycl::range<1>(static_cast<std::size_t>(local_range))),
[=
Expand All @@ -394,11 +389,11 @@ std::vector<sycl::event> compute_level(
s, global_logging_config,
#endif
it};
dispatch_level<Scalar, LayoutIn, LayoutOut, SubgroupSize>(
&in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset,
offset_output_imag, subimpl_twiddles, multipliers_between_factors, &loc_for_input[0],
&loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, inner_batches, inclusive_scan, batch_size,
global_data, kh);
dispatch_level<Scalar, SubgroupSize>(&in_acc_or_usm[0] + input_batch_offset, offset_output,
&in_imag_acc_or_usm[0] + input_batch_offset, offset_output_imag,
subimpl_twiddles, multipliers_between_factors, &loc_for_input[0],
&loc_for_twiddles[0], &loc_for_modifier[0], factors_triple,
inner_batches, inclusive_scan, batch_size, global_data, kh);
});
}));
}
Expand Down
17 changes: 8 additions & 9 deletions src/portfft/common/workgroup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ namespace detail {
/**
* Calculate all dfts in one dimension of the data stored in local memory.
*
* @tparam LayoutIn Input Layout
* @tparam SubgroupSize Size of the subgroup
* @tparam LocalT The type of the local view
* @tparam T Scalar type
Expand All @@ -73,7 +72,7 @@ namespace detail {
* @param stride_within_dft Stride between elements of each DFT - also the number of the DFTs in the inner dimension
* @param ndfts_in_outer_dimension Number of DFTs in outer dimension
* @param storage complex storage: interleaved or split
* @param layout_in Input Layout
* @param input_layout the layout of the input data of the transforms
* @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
* @param multiply_on_store Whether the input data is multiplied with some data array after fft computation.
* @param apply_scale_factor Whether or not the scale factor is applied
Expand All @@ -86,7 +85,7 @@ __attribute__((always_inline)) inline void dimension_dft(
LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor, Idx max_num_batches_in_local_mem,
Idx batch_num_in_local, const T* load_modifier_data, const T* store_modifier_data, IdxGlobal batch_num_in_kernel,
Idx dft_size, Idx stride_within_dft, Idx ndfts_in_outer_dimension, complex_storage storage,
detail::layout layout_in, detail::elementwise_multiply multiply_on_load,
detail::layout input_layout, detail::elementwise_multiply multiply_on_load,
detail::elementwise_multiply multiply_on_store, detail::apply_scale_factor apply_scale_factor,
detail::complex_conjugate conjugate_on_load, detail::complex_conjugate conjugate_on_store,
global_data_struct<1> global_data) {
Expand Down Expand Up @@ -149,7 +148,7 @@ __attribute__((always_inline)) inline void dimension_dft(
working = working && static_cast<Idx>(global_data.sg.get_local_linear_id()) < max_working_tid_in_sg;
}
if (working) {
if (layout_in == detail::layout::BATCH_INTERLEAVED) {
if (input_layout == detail::layout::BATCH_INTERLEAVED) {
global_data.log_message_global(__func__, "loading transposed data from local to private memory");
if (storage == complex_storage::INTERLEAVED_COMPLEX) {
detail::strided_view local_view{
Expand Down Expand Up @@ -249,7 +248,7 @@ __attribute__((always_inline)) inline void dimension_dft(
}
}
global_data.log_dump_private("data in registers after computation:", priv, 2 * fact_wi);
if (layout_in == detail::layout::BATCH_INTERLEAVED) {
if (input_layout == detail::layout::BATCH_INTERLEAVED) {
global_data.log_message_global(__func__, "storing transposed data from private to local memory");
if (storage == complex_storage::INTERLEAVED_COMPLEX) {
detail::strided_view local_view{
Expand Down Expand Up @@ -313,7 +312,7 @@ __attribute__((always_inline)) inline void dimension_dft(
* @param N Smaller factor of the problem size
* @param M Larger factor of the problem size
* @param storage complex storage: interleaved or split
* @param layout_in Whether or not the input is transposed
* @param input_layout the layout of the input data of the transforms
* @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
* @param multiply_on_store Whether the input data is multiplied with some data array after fft computation.
* @param apply_scale_factor Whether or not the scale factor is applied
Expand All @@ -325,7 +324,7 @@ template <Idx SubgroupSize, typename LocalT, typename T>
PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor,
Idx max_num_batches_in_local_mem, Idx batch_num_in_local, IdxGlobal batch_num_in_kernel,
const T* load_modifier_data, const T* store_modifier_data, Idx fft_size, Idx N, Idx M,
complex_storage storage, detail::layout layout_in,
complex_storage storage, detail::layout input_layout,
detail::elementwise_multiply multiply_on_load,
detail::elementwise_multiply multiply_on_store,
detail::apply_scale_factor apply_scale_factor, detail::complex_conjugate conjugate_on_load,
Expand All @@ -336,14 +335,14 @@ PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T
// column-wise DFTs
detail::dimension_dft<SubgroupSize, LocalT, T>(
loc, loc_twiddles + (2 * M), nullptr, 1, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data,
store_modifier_data, batch_num_in_kernel, N, M, 1, storage, layout_in, multiply_on_load,
store_modifier_data, batch_num_in_kernel, N, M, 1, storage, input_layout, multiply_on_load,
detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, conjugate_on_load,
detail::complex_conjugate::NOT_APPLIED, global_data);
sycl::group_barrier(global_data.it.get_group());
// row-wise DFTs, including twiddle multiplications and scaling
detail::dimension_dft<SubgroupSize, LocalT, T>(
loc, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, batch_num_in_local,
load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, storage, layout_in,
load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, storage, input_layout,
detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale_factor,
detail::complex_conjugate::NOT_APPLIED, conjugate_on_store, global_data);
global_data.log_message_global(__func__, "exited");
Expand Down
Loading