Permalink
Cannot retrieve contributors at this time
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1012 lines (907 sloc)
37.6 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** @addtogroup dft
 *  @{
 */
/*
  Copyright (C) 2016 D Levin (https://www.kfrlib.com)
  This file is part of KFR

  KFR is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  KFR is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with KFR.

  If GPL is not suitable for your project, you must purchase a commercial license to use KFR.
  Buying a commercial license is mandatory as soon as you develop commercial activities without
  disclosing the source code of your own applications.
  See https://www.kfrlib.com for details.
 */
#pragma once | |
#include <memory>
#include <utility>
#include <vector>

#include "../base/complex.hpp"
#include "../base/constants.hpp"
#include "../base/memory.hpp"
#include "../base/read_write.hpp"
#include "../base/small_buffer.hpp"
#include "../base/vec.hpp"
#include "../cometa/string.hpp"
#include "bitrev.hpp"
#include "ft.hpp"
#pragma clang diagnostic push | |
#if CMT_HAS_WARNING("-Wshadow") | |
#pragma clang diagnostic ignored "-Wshadow" | |
#endif | |
namespace kfr | |
{ | |
template <typename T> | |
struct dft_stage | |
{ | |
size_t stage_size = 0; | |
size_t data_size = 0; | |
size_t temp_size = 0; | |
u8* data = nullptr; | |
size_t repeats = 1; | |
size_t out_offset = 0; | |
const char* name; | |
bool recursion = false; | |
void initialize(size_t size) { do_initialize(size); } | |
KFR_INTRIN void execute(complex<T>* out, const complex<T>* in, u8* temp) { do_execute(out, in, temp); } | |
virtual ~dft_stage() {} | |
protected: | |
virtual void do_initialize(size_t) {} | |
virtual void do_execute(complex<T>*, const complex<T>*, u8* temp) = 0; | |
}; | |
#pragma clang diagnostic push | |
#if CMT_HAS_WARNING("-Wassume") | |
#pragma clang diagnostic ignored "-Wassume" | |
#endif | |
namespace internal | |
{ | |
template <size_t width, bool inverse, typename T> | |
KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, cfalse_t /*split_format*/, cbool_t<inverse>, | |
cvec<T, width> w, cvec<T, width> tw) | |
{ | |
cvec<T, width> b1 = w * dupeven(tw); | |
w = swap<2>(w); | |
if (inverse) | |
tw = -(tw); | |
w = subadd(b1, w * dupodd(tw)); | |
return w; | |
} | |
template <size_t width, bool use_br2, bool inverse, bool aligned, typename T> | |
KFR_SINTRIN void radix4_body(size_t N, csize_t<width>, cfalse_t, cfalse_t, cfalse_t, cbool_t<use_br2>, | |
cbool_t<inverse>, cbool_t<aligned>, complex<T>* out, const complex<T>* in, | |
const complex<T>* twiddle) | |
{ | |
const size_t N4 = N / 4; | |
cvec<T, width> w1, w2, w3; | |
cvec<T, width> sum02, sum13, diff02, diff13; | |
cvec<T, width> a0, a1, a2, a3; | |
a0 = cread<width, aligned>(in + 0); | |
a2 = cread<width, aligned>(in + N4 * 2); | |
sum02 = a0 + a2; | |
a1 = cread<width, aligned>(in + N4); | |
a3 = cread<width, aligned>(in + N4 * 3); | |
sum13 = a1 + a3; | |
cwrite<width, aligned>(out, sum02 + sum13); | |
w2 = sum02 - sum13; | |
cwrite<width, aligned>( | |
out + N4 * (use_br2 ? 1 : 2), | |
radix4_apply_twiddle(csize<width>, cfalse, cbool<inverse>, w2, cread<width, true>(twiddle + width))); | |
diff02 = a0 - a2; | |
diff13 = a1 - a3; | |
if (inverse) | |
{ | |
diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T())); | |
diff13 = swap<2>(diff13); | |
} | |
else | |
{ | |
diff13 = swap<2>(diff13); | |
diff13 = (diff13 ^ broadcast<width * 2, T>(T(), -T())); | |
} | |
w1 = diff02 + diff13; | |
cwrite<width, aligned>( | |
out + N4 * (use_br2 ? 2 : 1), | |
radix4_apply_twiddle(csize<width>, cfalse, cbool<inverse>, w1, cread<width, true>(twiddle + 0))); | |
w3 = diff02 - diff13; | |
cwrite<width, aligned>(out + N4 * 3, radix4_apply_twiddle(csize<width>, cfalse, cbool<inverse>, w3, | |
cread<width, true>(twiddle + width * 2))); | |
} | |
template <size_t width, bool inverse, typename T> | |
KFR_SINTRIN cvec<T, width> radix4_apply_twiddle(csize_t<width>, ctrue_t /*split_format*/, cbool_t<inverse>, | |
cvec<T, width> w, cvec<T, width> tw) | |
{ | |
vec<T, width> re1, im1, twre, twim; | |
split(w, re1, im1); | |
split(tw, twre, twim); | |
const vec<T, width> b1re = re1 * twre; | |
const vec<T, width> b1im = im1 * twre; | |
if (inverse) | |
w = concat(b1re + im1 * twim, b1im - re1 * twim); | |
else | |
w = concat(b1re - im1 * twim, b1im + re1 * twim); | |
return w; | |
} | |
template <size_t width, bool splitout, bool splitin, bool use_br2, bool inverse, bool aligned, typename T> | |
KFR_SINTRIN void radix4_body(size_t N, csize_t<width>, ctrue_t, cbool_t<splitout>, cbool_t<splitin>, | |
cbool_t<use_br2>, cbool_t<inverse>, cbool_t<aligned>, complex<T>* out, | |
const complex<T>* in, const complex<T>* twiddle) | |
{ | |
const size_t N4 = N / 4; | |
cvec<T, width> w1, w2, w3; | |
constexpr bool read_split = !splitin && splitout; | |
constexpr bool write_split = splitin && !splitout; | |
vec<T, width> re0, im0, re1, im1, re2, im2, re3, im3; | |
split(cread_split<width, aligned, read_split>(in + N4 * 0), re0, im0); | |
split(cread_split<width, aligned, read_split>(in + N4 * 1), re1, im1); | |
split(cread_split<width, aligned, read_split>(in + N4 * 2), re2, im2); | |
split(cread_split<width, aligned, read_split>(in + N4 * 3), re3, im3); | |
const vec<T, width> sum02re = re0 + re2; | |
const vec<T, width> sum02im = im0 + im2; | |
const vec<T, width> sum13re = re1 + re3; | |
const vec<T, width> sum13im = im1 + im3; | |
cwrite_split<width, aligned, write_split>(out, concat(sum02re + sum13re, sum02im + sum13im)); | |
w2 = concat(sum02re - sum13re, sum02im - sum13im); | |
cwrite_split<width, aligned, write_split>( | |
out + N4 * (use_br2 ? 1 : 2), | |
radix4_apply_twiddle(csize<width>, ctrue, cbool<inverse>, w2, cread<width, true>(twiddle + width))); | |
const vec<T, width> diff02re = re0 - re2; | |
const vec<T, width> diff02im = im0 - im2; | |
const vec<T, width> diff13re = re1 - re3; | |
const vec<T, width> diff13im = im1 - im3; | |
(inverse ? w1 : w3) = concat(diff02re - diff13im, diff02im + diff13re); | |
(inverse ? w3 : w1) = concat(diff02re + diff13im, diff02im - diff13re); | |
cwrite_split<width, aligned, write_split>( | |
out + N4 * (use_br2 ? 2 : 1), | |
radix4_apply_twiddle(csize<width>, ctrue, cbool<inverse>, w1, cread<width, true>(twiddle + 0))); | |
cwrite_split<width, aligned, write_split>(out + N4 * 3, | |
radix4_apply_twiddle(csize<width>, ctrue, cbool<inverse>, w3, | |
cread<width, true>(twiddle + width * 2))); | |
} | |
template <typename T> | |
CMT_NOINLINE cvec<T, 1> calculate_twiddle(size_t n, size_t size) | |
{ | |
if (n == 0) | |
{ | |
return make_vector(static_cast<T>(1), static_cast<T>(0)); | |
} | |
else if (n == size / 4) | |
{ | |
return make_vector(static_cast<T>(0), static_cast<T>(-1)); | |
} | |
else if (n == size / 2) | |
{ | |
return make_vector(static_cast<T>(-1), static_cast<T>(0)); | |
} | |
else if (n == size * 3 / 4) | |
{ | |
return make_vector(static_cast<T>(0), static_cast<T>(1)); | |
} | |
else | |
{ | |
fbase kth = c_pi<fbase, 2> * (n / static_cast<fbase>(size)); | |
fbase tcos = +kfr::cos(kth); | |
fbase tsin = -kfr::sin(kth); | |
return make_vector(static_cast<T>(tcos), static_cast<T>(tsin)); | |
} | |
} | |
template <typename T, size_t width> | |
KFR_SINTRIN void initialize_twiddles_impl(complex<T>*& twiddle, size_t nn, size_t nnstep, size_t size, | |
bool split_format) | |
{ | |
vec<T, 2 * width> result = T(); | |
CMT_LOOP_UNROLL | |
for (size_t i = 0; i < width; i++) | |
{ | |
const cvec<T, 1> r = calculate_twiddle<T>(nn + nnstep * i, size); | |
result(i * 2) = r[0]; | |
result(i * 2 + 1) = r[1]; | |
} | |
if (split_format) | |
ref_cast<cvec<T, width>>(twiddle[0]) = splitpairs(result); | |
else | |
ref_cast<cvec<T, width>>(twiddle[0]) = result; | |
twiddle += width; | |
} | |
template <typename T, size_t width> | |
CMT_NOINLINE void initialize_twiddles(complex<T>*& twiddle, size_t stage_size, size_t size, bool split_format) | |
{ | |
size_t nnstep = size / stage_size; | |
CMT_LOOP_NOUNROLL | |
for (size_t n = 0; n < stage_size / 4; n += width) | |
{ | |
initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 1, nnstep * 1, size, split_format); | |
initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 2, nnstep * 2, size, split_format); | |
initialize_twiddles_impl<T, width>(twiddle, n * nnstep * 3, nnstep * 3, size, split_format); | |
} | |
} | |
template <typename T> | |
KFR_SINTRIN void prefetch_one(const complex<T>* in) | |
{ | |
#ifdef CMT_ARCH_X86 | |
__builtin_prefetch(ptr_cast<void>(in), 0, _MM_HINT_T0); | |
#else | |
__builtin_prefetch(ptr_cast<void>(in)); | |
#endif | |
} | |
template <typename T> | |
KFR_SINTRIN void prefetch_four(size_t stride, const complex<T>* in) | |
{ | |
#ifdef CMT_ARCH_X86 | |
__builtin_prefetch(ptr_cast<void>(in), 0, _MM_HINT_T0); | |
__builtin_prefetch(ptr_cast<void>(in + stride), 0, _MM_HINT_T0); | |
__builtin_prefetch(ptr_cast<void>(in + stride * 2), 0, _MM_HINT_T0); | |
__builtin_prefetch(ptr_cast<void>(in + stride * 3), 0, _MM_HINT_T0); | |
#else | |
__builtin_prefetch(ptr_cast<void>(in)); | |
__builtin_prefetch(ptr_cast<void>(in + stride)); | |
__builtin_prefetch(ptr_cast<void>(in + stride * 2)); | |
__builtin_prefetch(ptr_cast<void>(in + stride * 3)); | |
#endif | |
} | |
template <typename Ntype, size_t width, bool splitout, bool splitin, bool prefetch, bool use_br2, | |
bool inverse, bool aligned, typename T> | |
KFR_SINTRIN cfalse_t radix4_pass(Ntype N, size_t blocks, csize_t<width>, cbool_t<splitout>, cbool_t<splitin>, | |
cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, | |
complex<T>* out, const complex<T>* in, const complex<T>*& twiddle) | |
{ | |
constexpr static size_t prefetch_offset = width * 8; | |
const auto N4 = N / csize<4>; | |
const auto N43 = N4 * csize<3>; | |
CMT_ASSUME(blocks > 0); | |
CMT_ASSUME(N > 0); | |
CMT_ASSUME(N4 > 0); | |
CMT_LOOP_NOUNROLL for (size_t b = 0; b < blocks; b++) | |
{ | |
#pragma clang loop unroll_count(2) | |
for (size_t n2 = 0; n2 < N4; n2 += width) | |
{ | |
if (prefetch) | |
prefetch_four(N4, in + prefetch_offset); | |
radix4_body(N, csize<width>, cbool < splitout || splitin >, cbool<splitout>, cbool<splitin>, | |
cbool<use_br2>, cbool<inverse>, cbool<aligned>, out, in, twiddle + n2 * 3); | |
in += width; | |
out += width; | |
} | |
in += N43; | |
out += N43; | |
} | |
twiddle += N43; | |
return {}; | |
} | |
template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> | |
KFR_SINTRIN ctrue_t radix4_pass(csize_t<32>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t, | |
cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, | |
complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) | |
{ | |
CMT_ASSUME(blocks > 0); | |
constexpr static size_t prefetch_offset = 32 * 4; | |
for (size_t b = 0; b < blocks; b++) | |
{ | |
if (prefetch) | |
prefetch_four(csize<64>, out + prefetch_offset); | |
cvec<T, 4> w0, w1, w2, w3, w4, w5, w6, w7; | |
split(cread<8, aligned>(out + 0), w0, w1); | |
split(cread<8, aligned>(out + 8), w2, w3); | |
split(cread<8, aligned>(out + 16), w4, w5); | |
split(cread<8, aligned>(out + 24), w6, w7); | |
butterfly8<4, inverse>(w0, w1, w2, w3, w4, w5, w6, w7); | |
w1 = cmul(w1, fixed_twiddle<T, 4, 32, 0, 1, inverse>); | |
w2 = cmul(w2, fixed_twiddle<T, 4, 32, 0, 2, inverse>); | |
w3 = cmul(w3, fixed_twiddle<T, 4, 32, 0, 3, inverse>); | |
w4 = cmul(w4, fixed_twiddle<T, 4, 32, 0, 4, inverse>); | |
w5 = cmul(w5, fixed_twiddle<T, 4, 32, 0, 5, inverse>); | |
w6 = cmul(w6, fixed_twiddle<T, 4, 32, 0, 6, inverse>); | |
w7 = cmul(w7, fixed_twiddle<T, 4, 32, 0, 7, inverse>); | |
cvec<T, 8> z0, z1, z2, z3; | |
transpose4x8(w0, w1, w2, w3, w4, w5, w6, w7, z0, z1, z2, z3); | |
butterfly4<8, inverse>(cfalse, z0, z1, z2, z3, z0, z1, z2, z3); | |
cwrite<32, aligned>(out, bitreverse<2>(concat(z0, z1, z2, z3))); | |
out += 32; | |
} | |
return {}; | |
} | |
template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> | |
KFR_SINTRIN ctrue_t radix4_pass(csize_t<8>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t, | |
cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, | |
complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) | |
{ | |
CMT_ASSUME(blocks > 0); | |
constexpr static size_t prefetch_offset = width * 16; | |
for (size_t b = 0; b < blocks; b += 2) | |
{ | |
if (prefetch) | |
prefetch_one(out + prefetch_offset); | |
cvec<T, 8> vlo = cread<8, aligned>(out + 0); | |
cvec<T, 8> vhi = cread<8, aligned>(out + 8); | |
butterfly8<inverse>(vlo); | |
butterfly8<inverse>(vhi); | |
vlo = permutegroups<(2), 0, 4, 2, 6, 1, 5, 3, 7>(vlo); | |
vhi = permutegroups<(2), 0, 4, 2, 6, 1, 5, 3, 7>(vhi); | |
cwrite<8, aligned>(out, vlo); | |
cwrite<8, aligned>(out + 8, vhi); | |
out += 16; | |
} | |
return {}; | |
} | |
template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> | |
KFR_SINTRIN ctrue_t radix4_pass(csize_t<16>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t, | |
cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, | |
complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) | |
{ | |
CMT_ASSUME(blocks > 0); | |
constexpr static size_t prefetch_offset = width * 4; | |
#pragma clang loop unroll_count(2) | |
for (size_t b = 0; b < blocks; b += 2) | |
{ | |
if (prefetch) | |
prefetch_one(out + prefetch_offset); | |
cvec<T, 16> vlo = cread<16, aligned>(out); | |
cvec<T, 16> vhi = cread<16, aligned>(out + 16); | |
butterfly4<4, inverse>(vlo); | |
butterfly4<4, inverse>(vhi); | |
apply_twiddles4<0, 4, 4, inverse>(vlo); | |
apply_twiddles4<0, 4, 4, inverse>(vhi); | |
vlo = digitreverse4<2>(vlo); | |
vhi = digitreverse4<2>(vhi); | |
butterfly4<4, inverse>(vlo); | |
butterfly4<4, inverse>(vhi); | |
use_br2 ? cbitreverse_write(out, vlo) : cdigitreverse4_write(out, vlo); | |
use_br2 ? cbitreverse_write(out + 16, vhi) : cdigitreverse4_write(out + 16, vhi); | |
out += 32; | |
} | |
return {}; | |
} | |
template <size_t width, bool prefetch, bool use_br2, bool inverse, bool aligned, typename T> | |
KFR_SINTRIN ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfalse_t, cfalse_t, | |
cbool_t<use_br2>, cbool_t<prefetch>, cbool_t<inverse>, cbool_t<aligned>, | |
complex<T>* out, const complex<T>*, const complex<T>*& /*twiddle*/) | |
{ | |
constexpr static size_t prefetch_offset = width * 4; | |
CMT_ASSUME(blocks > 0); | |
CMT_LOOP_NOUNROLL | |
for (size_t b = 0; b < blocks; b += 4) | |
{ | |
if (prefetch) | |
prefetch_one(out + prefetch_offset); | |
cvec<T, 16> v16 = cdigitreverse4_read<16, aligned>(out); | |
butterfly4<4, inverse>(v16); | |
cdigitreverse4_write<aligned>(out, v16); | |
out += 4 * 4; | |
} | |
return {}; | |
} | |
template <typename T, bool splitin, bool is_even, bool inverse> | |
struct fft_stage_impl : dft_stage<T> | |
{ | |
fft_stage_impl(size_t stage_size) | |
{ | |
this->stage_size = stage_size; | |
this->repeats = 4; | |
this->recursion = true; | |
this->data_size = align_up(sizeof(complex<T>) * stage_size / 4 * 3, native_cache_alignment); | |
} | |
protected: | |
constexpr static bool prefetch = true; | |
constexpr static bool aligned = false; | |
constexpr static size_t width = vector_width<T, cpu_t::native>; | |
virtual void do_initialize(size_t size) override final | |
{ | |
complex<T>* twiddle = ptr_cast<complex<T>>(this->data); | |
initialize_twiddles<T, width>(twiddle, this->stage_size, size, true); | |
} | |
virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final | |
{ | |
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); | |
if (splitin) | |
in = out; | |
const size_t stage_size = this->stage_size; | |
CMT_ASSUME(stage_size >= 2048); | |
CMT_ASSUME(stage_size % 2048 == 0); | |
radix4_pass(stage_size, 1, csize<width>, ctrue, cbool<splitin>, cbool<!is_even>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, in, twiddle); | |
} | |
}; | |
template <typename T, bool splitin, size_t size, bool inverse> | |
struct fft_final_stage_impl : dft_stage<T> | |
{ | |
fft_final_stage_impl(size_t) | |
{ | |
this->stage_size = size; | |
this->out_offset = size; | |
this->repeats = 4; | |
this->recursion = true; | |
this->data_size = align_up(sizeof(complex<T>) * size * 3 / 2, native_cache_alignment); | |
} | |
protected: | |
constexpr static size_t width = vector_width<T, cpu_t::native>; | |
constexpr static bool is_even = cometa::is_even(ilog2(size)); | |
constexpr static bool use_br2 = !is_even; | |
constexpr static bool aligned = false; | |
constexpr static bool prefetch = splitin; | |
virtual void do_initialize(size_t total_size) override final | |
{ | |
complex<T>* twiddle = ptr_cast<complex<T>>(this->data); | |
size_t stage_size = this->stage_size; | |
while (stage_size > 4) | |
{ | |
initialize_twiddles<T, width>(twiddle, stage_size, total_size, true); | |
stage_size /= 4; | |
} | |
} | |
virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final | |
{ | |
constexpr bool is_double = sizeof(T) == 8; | |
constexpr size_t final_size = is_even ? (is_double ? 4 : 16) : (is_double ? 8 : 32); | |
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); | |
final_pass(csize<final_size>, out, in, twiddle); | |
} | |
KFR_INTRIN void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) | |
{ | |
radix4_pass(512, 1, csize<width>, ctrue, cbool<splitin>, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, in, twiddle); | |
radix4_pass(128, 4, csize<width>, ctrue, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(32, 16, csize<width>, cfalse, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(csize<8>, 64, csize<width>, cfalse, cfalse, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, out, twiddle); | |
} | |
KFR_INTRIN void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) | |
{ | |
radix4_pass(512, 1, csize<width>, ctrue, cbool<splitin>, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, in, twiddle); | |
radix4_pass(128, 4, csize<width>, cfalse, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(csize<32>, 16, csize<width>, cfalse, cfalse, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, out, twiddle); | |
} | |
KFR_INTRIN void final_pass(csize_t<4>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) | |
{ | |
radix4_pass(1024, 1, csize<width>, ctrue, cbool<splitin>, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, in, twiddle); | |
radix4_pass(256, 4, csize<width>, ctrue, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(64, 16, csize<width>, ctrue, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(16, 64, csize<width>, cfalse, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(csize<4>, 256, csize<width>, cfalse, cfalse, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, out, twiddle); | |
} | |
KFR_INTRIN void final_pass(csize_t<16>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) | |
{ | |
radix4_pass(1024, 1, csize<width>, ctrue, cbool<splitin>, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, in, twiddle); | |
radix4_pass(256, 4, csize<width>, ctrue, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(64, 16, csize<width>, cfalse, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(csize<16>, 64, csize<width>, cfalse, cfalse, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, out, twiddle); | |
} | |
}; | |
template <typename T, bool is_even> | |
struct fft_reorder_stage_impl : dft_stage<T> | |
{ | |
fft_reorder_stage_impl(size_t stage_size) | |
{ | |
this->stage_size = stage_size; | |
log2n = ilog2(stage_size); | |
this->data_size = 0; | |
} | |
protected: | |
size_t log2n; | |
virtual void do_initialize(size_t) override final {} | |
virtual void do_execute(complex<T>* out, const complex<T>*, u8* /*temp*/) override final | |
{ | |
fft_reorder(out, log2n, cbool<!is_even>); | |
} | |
}; | |
template <typename T, size_t log2n, bool inverse> | |
struct fft_specialization; | |
template <typename T, bool inverse> | |
struct fft_specialization<T, 1, inverse> : dft_stage<T> | |
{ | |
fft_specialization(size_t) {} | |
protected: | |
constexpr static bool aligned = false; | |
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final | |
{ | |
cvec<T, 1> a0, a1; | |
split(cread<2, aligned>(in), a0, a1); | |
cwrite<2, aligned>(out, concat(a0 + a1, a0 - a1)); | |
} | |
}; | |
template <typename T, bool inverse> | |
struct fft_specialization<T, 2, inverse> : dft_stage<T> | |
{ | |
fft_specialization(size_t) {} | |
protected: | |
constexpr static bool aligned = false; | |
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final | |
{ | |
cvec<T, 1> a0, a1, a2, a3; | |
split(cread<4>(in), a0, a1, a2, a3); | |
butterfly(cbool<inverse>, a0, a1, a2, a3, a0, a1, a2, a3); | |
cwrite<4>(out, concat(a0, a1, a2, a3)); | |
} | |
}; | |
template <typename T, bool inverse> | |
struct fft_specialization<T, 3, inverse> : dft_stage<T> | |
{ | |
fft_specialization(size_t) {} | |
protected: | |
constexpr static bool aligned = false; | |
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final | |
{ | |
cvec<T, 8> v8 = cread<8, aligned>(in); | |
butterfly8<inverse>(v8); | |
cwrite<8, aligned>(out, v8); | |
} | |
}; | |
template <typename T, bool inverse> | |
struct fft_specialization<T, 4, inverse> : dft_stage<T> | |
{ | |
fft_specialization(size_t) {} | |
protected: | |
constexpr static bool aligned = false; | |
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final | |
{ | |
cvec<T, 16> v16 = cread<16, aligned>(in); | |
butterfly16<inverse>(v16); | |
cwrite<16, aligned>(out, v16); | |
} | |
}; | |
template <typename T, bool inverse> | |
struct fft_specialization<T, 5, inverse> : dft_stage<T> | |
{ | |
fft_specialization(size_t) {} | |
protected: | |
constexpr static bool aligned = false; | |
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final | |
{ | |
cvec<T, 32> v32 = cread<32, aligned>(in); | |
butterfly32<inverse>(v32); | |
cwrite<32, aligned>(out, v32); | |
} | |
}; | |
template <typename T, bool inverse> | |
struct fft_specialization<T, 6, inverse> : dft_stage<T> | |
{ | |
fft_specialization(size_t) {} | |
protected: | |
constexpr static bool aligned = false; | |
virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final | |
{ | |
butterfly64(cbool<inverse>, cbool<aligned>, out, in); | |
} | |
}; | |
template <typename T, bool inverse> | |
struct fft_specialization<T, 7, inverse> : dft_stage<T> | |
{ | |
fft_specialization(size_t) | |
{ | |
this->stage_size = 128; | |
this->data_size = align_up(sizeof(complex<T>) * 128 * 3 / 2, native_cache_alignment); | |
} | |
protected: | |
constexpr static bool aligned = false; | |
constexpr static size_t width = vector_width<T, cpu_t::native>; | |
constexpr static bool use_br2 = true; | |
constexpr static bool prefetch = false; | |
constexpr static bool is_double = sizeof(T) == 8; | |
constexpr static size_t final_size = is_double ? 8 : 32; | |
constexpr static size_t split_format = final_size == 8; | |
virtual void do_initialize(size_t total_size) override final | |
{ | |
complex<T>* twiddle = ptr_cast<complex<T>>(this->data); | |
initialize_twiddles<T, width>(twiddle, 128, total_size, split_format); | |
initialize_twiddles<T, width>(twiddle, 32, total_size, split_format); | |
initialize_twiddles<T, width>(twiddle, 8, total_size, split_format); | |
} | |
virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final | |
{ | |
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); | |
final_pass(csize<final_size>, out, in, twiddle); | |
fft_reorder(out, csize<7>); | |
} | |
KFR_INTRIN void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) | |
{ | |
radix4_pass(128, 1, csize<width>, ctrue, cfalse, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, in, twiddle); | |
radix4_pass(32, 4, csize<width>, cfalse, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(csize<8>, 16, csize<width>, cfalse, cfalse, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, out, twiddle); | |
} | |
KFR_INTRIN void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) | |
{ | |
radix4_pass(128, 1, csize<width>, cfalse, cfalse, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, in, twiddle); | |
radix4_pass(csize<32>, 4, csize<width>, cfalse, cfalse, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, out, twiddle); | |
} | |
}; | |
template <bool inverse> | |
struct fft_specialization<float, 8, inverse> : dft_stage<float> | |
{ | |
fft_specialization(size_t) { this->temp_size = sizeof(complex<float>) * 256; } | |
protected: | |
virtual void do_execute(complex<float>* out, const complex<float>* in, u8* temp) override final | |
{ | |
complex<float>* scratch = ptr_cast<complex<float>>(temp); | |
if (out == in) | |
{ | |
butterfly16_multi_flip<0, inverse>(scratch, out); | |
butterfly16_multi_flip<1, inverse>(scratch, out); | |
butterfly16_multi_flip<2, inverse>(scratch, out); | |
butterfly16_multi_flip<3, inverse>(scratch, out); | |
butterfly16_multi_natural<0, inverse>(out, scratch); | |
butterfly16_multi_natural<1, inverse>(out, scratch); | |
butterfly16_multi_natural<2, inverse>(out, scratch); | |
butterfly16_multi_natural<3, inverse>(out, scratch); | |
} | |
else | |
{ | |
butterfly16_multi_flip<0, inverse>(out, in); | |
butterfly16_multi_flip<1, inverse>(out, in); | |
butterfly16_multi_flip<2, inverse>(out, in); | |
butterfly16_multi_flip<3, inverse>(out, in); | |
butterfly16_multi_natural<0, inverse>(out, out); | |
butterfly16_multi_natural<1, inverse>(out, out); | |
butterfly16_multi_natural<2, inverse>(out, out); | |
butterfly16_multi_natural<3, inverse>(out, out); | |
} | |
} | |
}; | |
template <bool inverse> | |
struct fft_specialization<double, 8, inverse> : dft_stage<double> | |
{ | |
using T = double; | |
fft_specialization(size_t) | |
{ | |
this->stage_size = 256; | |
this->data_size = align_up(sizeof(complex<T>) * 256 * 3 / 2, native_cache_alignment); | |
} | |
protected: | |
constexpr static bool aligned = false; | |
constexpr static size_t width = vector_width<T, cpu_t::native>; | |
constexpr static bool use_br2 = false; | |
constexpr static bool prefetch = false; | |
constexpr static size_t split_format = true; | |
virtual void do_initialize(size_t total_size) override final | |
{ | |
complex<T>* twiddle = ptr_cast<complex<T>>(this->data); | |
initialize_twiddles<T, width>(twiddle, 256, total_size, split_format); | |
initialize_twiddles<T, width>(twiddle, 64, total_size, split_format); | |
initialize_twiddles<T, width>(twiddle, 16, total_size, split_format); | |
} | |
virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final | |
{ | |
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); | |
final_pass(csize<4>, out, in, twiddle); | |
fft_reorder(out, csize<8>); | |
} | |
KFR_INTRIN void final_pass(csize_t<4>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle) | |
{ | |
radix4_pass(256, 1, csize<width>, ctrue, cfalse, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, in, twiddle); | |
radix4_pass(64, 4, csize<width>, ctrue, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(16, 16, csize<width>, cfalse, ctrue, cbool<use_br2>, cbool<prefetch>, cbool<inverse>, | |
cbool<aligned>, out, out, twiddle); | |
radix4_pass(csize<4>, 64, csize<width>, cfalse, cfalse, cbool<use_br2>, cbool<prefetch>, | |
cbool<inverse>, cbool<aligned>, out, out, twiddle); | |
} | |
}; | |
template <typename T, bool splitin, bool is_even> | |
struct fft_stage_impl_t | |
{ | |
template <bool inverse> | |
using type = internal::fft_stage_impl<T, splitin, is_even, inverse>; | |
}; | |
template <typename T, bool splitin, size_t size> | |
struct fft_final_stage_impl_t | |
{ | |
template <bool inverse> | |
using type = internal::fft_final_stage_impl<T, splitin, size, inverse>; | |
}; | |
template <typename T, bool is_even> | |
struct fft_reorder_stage_impl_t | |
{ | |
template <bool> | |
using type = internal::fft_reorder_stage_impl<T, is_even>; | |
}; | |
template <typename T, size_t log2n, bool aligned> | |
struct fft_specialization_t | |
{ | |
template <bool inverse> | |
using type = internal::fft_specialization<T, log2n, inverse>; | |
}; | |
} | |
namespace dft_type | |
{ | |
constexpr cbools_t<true, true> both{}; | |
constexpr cbools_t<true, false> direct{}; | |
constexpr cbools_t<false, true> inverse{}; | |
} | |
template <typename T> | |
struct dft_plan | |
{ | |
using dft_stage_ptr = std::unique_ptr<dft_stage<T>>; | |
size_t size; | |
size_t temp_size; | |
template <bool direct = true, bool inverse = true> | |
dft_plan(size_t size, cbools_t<direct, inverse> type = dft_type::both) | |
: size(size), temp_size(0), data_size(0) | |
{ | |
if (is_poweroftwo(size)) | |
{ | |
const size_t log2n = ilog2(size); | |
cswitch(csizes<1, 2, 3, 4, 5, 6, 7, 8>, log2n, | |
[&](auto log2n) { | |
add_stage<internal::fft_specialization_t<T, val_of(decltype(log2n)()), | |
false>::template type>(size, type); | |
}, | |
[&]() { | |
cswitch(cfalse_true, is_even(log2n), [&](auto is_even) { | |
make_fft(size, type, is_even, ctrue); | |
add_stage<internal::fft_reorder_stage_impl_t< | |
T, val_of(decltype(is_even)())>::template type>(size, type); | |
}); | |
}); | |
initialize(type); | |
} | |
} | |
KFR_INTRIN void execute(complex<T>* out, const complex<T>* in, u8* temp, bool inverse = false) const | |
{ | |
if (inverse) | |
execute_dft(ctrue, out, in, temp); | |
else | |
execute_dft(cfalse, out, in, temp); | |
} | |
template <bool inverse> | |
KFR_INTRIN void execute(complex<T>* out, const complex<T>* in, u8* temp, cbool_t<inverse> inv) const | |
{ | |
execute_dft(inv, out, in, temp); | |
} | |
template <size_t Tag1, size_t Tag2, size_t Tag3> | |
KFR_INTRIN void execute(univector<complex<T>, Tag1>& out, const univector<complex<T>, Tag2>& in, | |
univector<u8, Tag3>& temp, bool inverse = false) const | |
{ | |
if (inverse) | |
execute_dft(ctrue, out.data(), in.data(), temp.data()); | |
else | |
execute_dft(cfalse, out.data(), in.data(), temp.data()); | |
} | |
template <bool inverse, size_t Tag1, size_t Tag2, size_t Tag3> | |
KFR_INTRIN void execute(univector<complex<T>, Tag1>& out, const univector<complex<T>, Tag2>& in, | |
univector<u8, Tag3>& temp, cbool_t<inverse> inv) const | |
{ | |
execute_dft(inv, out.data(), in.data(), temp.data()); | |
} | |
private: | |
autofree<u8> data; | |
size_t data_size; | |
std::vector<dft_stage_ptr> stages[2]; | |
template <template <bool inverse> class Stage> | |
void add_stage(size_t stage_size, cbools_t<true, true>) | |
{ | |
dft_stage<T>* direct_stage = new Stage<false>(stage_size); | |
direct_stage->name = type_name<decltype(*direct_stage)>(); | |
dft_stage<T>* inverse_stage = new Stage<true>(stage_size); | |
inverse_stage->name = type_name<decltype(*inverse_stage)>(); | |
this->data_size += direct_stage->data_size; | |
this->temp_size += direct_stage->temp_size; | |
stages[0].push_back(dft_stage_ptr(direct_stage)); | |
stages[1].push_back(dft_stage_ptr(inverse_stage)); | |
} | |
template <template <bool inverse> class Stage> | |
void add_stage(size_t stage_size, cbools_t<true, false>) | |
{ | |
dft_stage<T>* direct_stage = new Stage<false>(stage_size); | |
direct_stage->name = type_name<decltype(*direct_stage)>(); | |
this->data_size += direct_stage->data_size; | |
this->temp_size += direct_stage->temp_size; | |
stages[0].push_back(dft_stage_ptr(direct_stage)); | |
} | |
template <template <bool inverse> class Stage> | |
void add_stage(size_t stage_size, cbools_t<false, true>) | |
{ | |
dft_stage<T>* inverse_stage = new Stage<true>(stage_size); | |
inverse_stage->name = type_name<decltype(*inverse_stage)>(); | |
this->data_size += inverse_stage->data_size; | |
this->temp_size += inverse_stage->temp_size; | |
stages[1].push_back(dft_stage_ptr(inverse_stage)); | |
} | |
template <bool direct, bool inverse, bool is_even, bool first> | |
void make_fft(size_t stage_size, cbools_t<direct, inverse> type, cbool_t<is_even>, cbool_t<first>) | |
{ | |
constexpr size_t final_size = is_even ? 1024 : 512; | |
using fft_stage_impl_t = internal::fft_stage_impl_t<T, !first, is_even>; | |
using fft_final_stage_impl_t = internal::fft_final_stage_impl_t<T, !first, final_size>; | |
if (stage_size >= 2048) | |
{ | |
add_stage<fft_stage_impl_t::template type>(stage_size, type); | |
make_fft(stage_size / 4, cbools<direct, inverse>, cbool<is_even>, cfalse); | |
} | |
else | |
{ | |
add_stage<fft_final_stage_impl_t::template type>(final_size, type); | |
} | |
} | |
template <bool direct, bool inverse> | |
void initialize(cbools_t<direct, inverse>) | |
{ | |
data = autofree<u8>(data_size); | |
if (direct) | |
{ | |
size_t offset = 0; | |
for (dft_stage_ptr& stage : stages[0]) | |
{ | |
stage->data = data.data() + offset; | |
stage->initialize(this->size); | |
offset += stage->data_size; | |
} | |
} | |
if (inverse) | |
{ | |
size_t offset = 0; | |
for (dft_stage_ptr& stage : stages[1]) | |
{ | |
stage->data = data.data() + offset; | |
if (!direct) | |
stage->initialize(this->size); | |
offset += stage->data_size; | |
} | |
} | |
} | |
template <bool inverse> | |
KFR_INTRIN void execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const | |
{ | |
size_t stack[32] = { 0 }; | |
const size_t count = stages[inverse].size(); | |
for (size_t depth = 0; depth < count;) | |
{ | |
if (stages[inverse][depth]->recursion) | |
{ | |
complex<T>* rout = out; | |
const complex<T>* rin = in; | |
size_t rdepth = depth; | |
size_t maxdepth = depth; | |
do | |
{ | |
if (stack[rdepth] == stages[inverse][rdepth]->repeats) | |
{ | |
stack[rdepth] = 0; | |
rdepth--; | |
} | |
else | |
{ | |
stages[inverse][rdepth]->execute(rout, rin, temp); | |
rout += stages[inverse][rdepth]->out_offset; | |
rin = rout; | |
stack[rdepth]++; | |
if (rdepth < count - 1 && stages[inverse][rdepth + 1]->recursion) | |
rdepth++; | |
else | |
maxdepth = rdepth; | |
} | |
} while (rdepth != depth); | |
depth = maxdepth + 1; | |
} | |
else | |
{ | |
stages[inverse][depth]->execute(out, in, temp); | |
depth++; | |
} | |
in = out; | |
} | |
} | |
}; | |
} | |
#pragma clang diagnostic pop |