Skip to content

Commit

Permalink
add complex support to convolve_filter
Browse files Browse the repository at this point in the history
  • Loading branch information
slarew committed Feb 20, 2020
1 parent 8a969d0 commit f3bc9bb
Show file tree
Hide file tree
Showing 6 changed files with 356 additions and 58 deletions.
45 changes: 45 additions & 0 deletions docs/docs/convolution.md
Expand Up @@ -68,3 +68,48 @@ reverb_RR.apply(tmp4, audio[1]);
audio[0] = tmp1 + tmp2;
audio[1] = tmp3 + tmp4;
```
# Implementation Details
The convolution filter efficiently computes the convolution of two signals.
The efficiency is achieved by employing the FFT and the circular convolution
theorem. The algorithm is a variant of the [overlap-add
method](https://en.wikipedia.org/wiki/Overlap%E2%80%93add_method). It works on
a fixed block size \(B\) for arbitrarily long input signals. Thus, the
convolution of a streaming input signal with a long FIR filter \(h[n]\) (where
the length of \(h[n]\) may exceed the block size \(B\)) is computed with a
fixed complexity \(O(B \log B)\).
More formally, the convolution filter computes \(y[n] = (x * h)[n]\) by
partitioning the input \(x\) and filter \(h\) into blocks and applies the
overlap-add method. Let \(x[n]\) be an input signal of arbitrary length. Often,
\(x[n]\) is a streaming input with unknown length. Let \(h[n]\) be an FIR
filter with \(M\) taps. The convolution filter works on a fixed block size
\(B=2^b\).
First, the input and filter are windowed and shifted to the origin to give the
\(k\)-th block input \(x_k[n] = x[n + kB] , n=\{0,1,\ldots,B-1\},\forall
k\in\mathbb{Z}\) and \(j\)-th block filter \(h_j[n] = h[n + jB] ,
n=\{0,1,\ldots,B-1\},j=\{0,1,\ldots,\lfloor M/B \rfloor\}\). The convolution
\(y_{k,j}[n] = (x_k * h_j)[n]\) is efficiently computed with length \(2B\) FFTs
as
\[
y_{k,j}[n] = \mathrm{IFFT}(\mathrm{FFT}(x_k[n])\cdot\mathrm{FFT}(h_j[n]))
.
\]
The overlap-add method sums the "overlap" from the previous block with the current block.
To complete the \(k\)-th block, the contribution of all blocks of the filter
are summed together to give
\[ y_{k}[n] = \sum_j y_{k-j,j}[n] . \]
The final convolution is then the sum of the shifted blocks
\[ y[n] = \sum_k y_{k}[n - kB] . \]
Note that \(y_k[n]\) is of length \(2B\) so its second half overlaps and adds
into the first half of the \(y_{k+1}[n]\) block.
## Maximum efficiency criterion
To avoid excess computation or maximize throughput, the convolution filter
should be given input samples in multiples of the block size \(B\). Otherwise,
the FFT of a block is computed twice as many times as would be necessary and
hence throughput is reduced.
2 changes: 2 additions & 0 deletions examples/CMakeLists.txt
Expand Up @@ -50,4 +50,6 @@ if (ENABLE_DFT)
target_link_libraries(dft kfr_multidft)
target_compile_definitions(dft PRIVATE -DKFR_DFT_MULTI=1)
endif ()
add_executable(ccv ccv.cpp)
target_link_libraries(ccv kfr kfr_dft use_arch)
endif ()
71 changes: 71 additions & 0 deletions examples/ccv.cpp
@@ -0,0 +1,71 @@
/*
* ccv, part of KFR (https://www.kfr.dev)
* Copyright (C) 2019 D Levin
* See LICENSE.txt for details
*/

// Complex convolution filter examples

#define CMT_BASETYPE_F32

#include <chrono>
#include <kfr/base.hpp>
#include <kfr/dft.hpp>
#include <kfr/dsp.hpp>

using namespace kfr;

int main()
{
println(library_version());

// low-pass filter
univector<fbase, 1023> taps127;
expression_pointer<fbase> kaiser = to_pointer(window_kaiser(taps127.size(), 3.0));
fir_lowpass(taps127, 0.2, kaiser, true);

// Create filters.
size_t const block_size = 256;
convolve_filter<complex<fbase>> conv_filter_complex(univector<complex<fbase>>(make_complex(taps127, zeros())),
block_size);
convolve_filter<fbase> conv_filter_real(taps127, block_size);

// Create noise to filter.
auto const size = 1024 * 100 + 33; // not a multiple of block_size
univector<complex<fbase>> cnoise =
make_complex(truncate(gen_random_range(random_bit_generator{ 1, 2, 3, 4 }, -1.f, +1.f), size),
truncate(gen_random_range(random_bit_generator{ 3, 4, 9, 8 }, -1.f, +1.f), size));
univector<fbase> noise =
truncate(gen_random_range(random_bit_generator{ 3, 4, 9, 8 }, -1.f, +1.f), size);

// Filter results.
univector<complex<fbase>> filtered_cnoise_ccv(size), filtered_cnoise_fir(size);
univector<fbase> filtered_noise_ccv(size), filtered_noise_fir(size);

// Complex filtering (time and compare).
auto tic = std::chrono::high_resolution_clock::now();
conv_filter_complex.apply(filtered_cnoise_ccv, cnoise);
auto toc = std::chrono::high_resolution_clock::now();
auto const ccv_time_complex = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic);
tic = toc;
filtered_cnoise_fir = kfr::fir(cnoise, taps127);
toc = std::chrono::high_resolution_clock::now();
auto const fir_time_complex = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic);
auto const cdiff = rms(cabs(filtered_cnoise_fir - filtered_cnoise_ccv));

// Real filtering (time and compare).
tic = std::chrono::high_resolution_clock::now();
conv_filter_real.apply(filtered_noise_ccv, noise);
toc = std::chrono::high_resolution_clock::now();
auto const ccv_time_real = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic);
tic = toc;
filtered_noise_fir = kfr::fir(noise, taps127);
toc = std::chrono::high_resolution_clock::now();
auto const fir_time_real = std::chrono::duration_cast<std::chrono::duration<float>>(toc - tic);
auto const diff = rms(filtered_noise_fir - filtered_noise_ccv);

println("complex: convolution_filter ", ccv_time_complex.count(), " fir ", fir_time_complex.count(), " diff=", cdiff);
println("real: convolution_filter ", ccv_time_real.count(), " fir ", fir_time_real.count(), " diff=", diff);

return 0;
}
38 changes: 29 additions & 9 deletions include/kfr/dft/convolution.hpp
Expand Up @@ -84,6 +84,9 @@ class convolve_filter : public filter<T>
explicit convolve_filter(size_t size, size_t block_size = 1024);
explicit convolve_filter(const univector_ref<const T>& data, size_t block_size = 1024);
void set_data(const univector_ref<const T>& data);
void reset() final;
/// Apply filter to multiples of returned block size for optimal processing efficiency.
size_t input_block_size() const { return block_size; }

protected:
void process_expression(T* dest, const expression_pointer<T>& src, size_t size) final
Expand All @@ -93,19 +96,36 @@ class convolve_filter : public filter<T>
}
void process_buffer(T* output, const T* input, size_t size) final;

const size_t size;
using ST = subtype<T>;
static constexpr auto real_fft = !std::is_same<T, complex<ST>>::value;
using plan_t = std::conditional_t<real_fft, dft_plan_real<T>, dft_plan<ST>>;

// Length of filter data.
size_t data_size;
// Size of block to process.
const size_t block_size;
const dft_plan_real<T> fft;
// FFT plan for circular convolution.
const plan_t fft;
// Temp storage for FFT.
univector<u8> temp;
std::vector<univector<complex<T>>> segments;
std::vector<univector<complex<T>>> ir_segments;
size_t input_position;
// History of input segments after fwd DFT. History is circular relative to position below.
std::vector<univector<complex<ST>>> segments;
// Index into segments of current block.
size_t position;
// Blocks of filter/data after fwd DFT.
std::vector<univector<complex<ST>>> ir_segments;
// Saved input for current block.
univector<T> saved_input;
univector<complex<T>> premul;
univector<complex<T>> cscratch;
univector<T> scratch;
// Index into saved_input for next input to begin.
size_t input_position;
// Pre-multiplied products of input history and delayed filter blocks.
univector<complex<ST>> premul;
// Scratch buffer for product of filter and input for processing by reverse DFT.
univector<complex<ST>> cscratch;
// Scratch buffers for input and output of fwd and rev DFTs.
univector<T> scratch1, scratch2;
// Overlap saved from previous block to add into current block.
univector<T> overlap;
size_t position;
};
} // namespace CMT_ARCH_NAME

Expand Down

0 comments on commit f3bc9bb

Please sign in to comment.