Skip to content

Commit 4700d8b

Browse files
committed
feat: add FFTConvolution class
This will allow to reuse arrays for better performance.
1 parent 160f4cb commit 4700d8b

File tree

4 files changed

+67
-43
lines changed

4 files changed

+67
-43
lines changed

benchmark/convolution.js

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
'use strict';
22

3-
const { fftConvolution, directConvolution } = require('..');
3+
const { FFTConvolution, directConvolution } = require('..');
44

55
const tests = {
66
data: [128, 512, 2048, 4096, 16384, 65536, 262144, 1048576],
@@ -11,6 +11,11 @@ function test(dataLength, kernelLength) {
1111
const data = Array.from({ length: dataLength }, Math.random);
1212
const kernel = Array.from({ length: kernelLength }, Math.random);
1313

14+
const fft = new FFTConvolution(dataLength, kernel);
15+
const fftConvolution = (data) => {
16+
return fft.convolute(data);
17+
};
18+
1419
const fftResult = measure(data, kernel, fftConvolution);
1520
const directResult = measure(data, kernel, directConvolution);
1621

ml-convolution.d.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,11 @@ declare module 'ml-convolution' {
1414
kernel: ArrayLike<number>,
1515
borderType?: BorderType
1616
): number[];
17+
export class FFTConvolution {
18+
public constructor(size: number, kernel: ArrayLike<number>);
19+
public convolute(
20+
input: ArrayLike<number>,
21+
borderType?: BorderType
22+
): number[];
23+
}
1724
}

src/fftConvolution.js

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,53 +3,66 @@ import nextPOT from 'next-power-of-two';
33

44
import { checkKernel } from './utils';
55

6-
export default function fftConvolution(input, kernel, borderType = 'CONSTANT') {
7-
checkKernel(kernel);
8-
switch (borderType) {
9-
case 'CONSTANT': {
10-
return fftConvolutionImpl(input, kernel, false);
11-
}
12-
case 'CUT': {
13-
return fftConvolutionImpl(input, kernel, true);
14-
}
15-
default: {
16-
throw new Error(`unexpected border type: ${borderType}`);
6+
export class FFTConvolution {
7+
constructor(size, kernel) {
8+
if (!Number.isInteger(size) || size < 1) {
9+
throw new TypeError('size must be a positive integer');
1710
}
11+
checkKernel(kernel);
12+
this.kernelOffset = (kernel.length - 1) / 2;
13+
this.doubleOffset = 2 * this.kernelOffset;
14+
const resultLength = size + this.doubleOffset;
15+
this.fftLength = nextPOT(Math.max(resultLength, 2));
16+
this.fft = new FFT(this.fftLength);
17+
kernel = kernel.slice().reverse();
18+
const { output: fftKernel, input: result } = createPaddedFFt(
19+
kernel,
20+
this.fft,
21+
this.fftLength
22+
);
23+
this.fftKernel = fftKernel;
24+
this.ifftOutput = this.fft.createComplexArray();
25+
this.result = result;
1826
}
19-
}
2027

21-
function fftConvolutionImpl(input, kernel, cutBorder) {
22-
const kernelOffset = (kernel.length - 1) / 2;
23-
const doubleOffset = 2 * kernelOffset;
24-
const resultLength = input.length + doubleOffset;
25-
const fftLength = nextPOT(resultLength);
28+
convolute(input, borderType = 'CONSTANT') {
29+
// if (input.length) // TODO CHECK SIZE
30+
const { output: fftInput } = createPaddedFFt(
31+
input,
32+
this.fft,
33+
this.fftLength
34+
);
2635

27-
const fft = new FFT(fftLength);
28-
29-
kernel = kernel.slice().reverse();
30-
const { output: fftKernel, input: result } = createPaddedFFt(
31-
kernel,
32-
fft,
33-
fftLength
34-
);
35-
const { output: fftInput } = createPaddedFFt(input, fft, fftLength);
36+
for (var i = 0; i < fftInput.length; i += 2) {
37+
const tmp =
38+
fftInput[i] * this.fftKernel[i] -
39+
fftInput[i + 1] * this.fftKernel[i + 1];
40+
fftInput[i + 1] =
41+
fftInput[i] * this.fftKernel[i + 1] +
42+
fftInput[i + 1] * this.fftKernel[i];
43+
fftInput[i] = tmp;
44+
}
3645

37-
for (var i = 0; i < fftInput.length; i += 2) {
38-
const tmp = fftInput[i] * fftKernel[i] - fftInput[i + 1] * fftKernel[i + 1];
39-
fftInput[i + 1] =
40-
fftInput[i] * fftKernel[i + 1] + fftInput[i + 1] * fftKernel[i];
41-
fftInput[i] = tmp;
42-
}
43-
const inverse = fftKernel;
44-
fft.inverseTransform(inverse, fftInput);
45-
const r = fft.fromComplexArray(inverse, result);
46-
if (cutBorder) {
47-
return r.slice(doubleOffset, input.length);
48-
} else {
49-
return r.slice(kernelOffset, kernelOffset + input.length);
46+
this.fft.inverseTransform(this.ifftOutput, fftInput);
47+
const r = this.fft.fromComplexArray(this.ifftOutput, this.result);
48+
switch (borderType) {
49+
case 'CONSTANT': {
50+
return r.slice(this.kernelOffset, this.kernelOffset + input.length);
51+
}
52+
case 'CUT': {
53+
return r.slice(this.doubleOffset, input.length);
54+
}
55+
default: {
56+
throw new Error(`unexpected border type: ${borderType}`);
57+
}
58+
}
5059
}
5160
}
5261

62+
export function fftConvolution(input, kernel, borderType = 'CONSTANT') {
63+
return new FFTConvolution(input.length, kernel).convolute(input, borderType);
64+
}
65+
5366
function createPaddedFFt(data, fft, length) {
5467
const input = [];
5568
let i = 0;

src/index.js

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import directConvolution from './directConvolution';
2-
import fftConvolution from './fftConvolution';
1+
export { default as directConvolution } from './directConvolution';
2+
export * from './fftConvolution';
33

4-
export { directConvolution, fftConvolution };
54
export const BorderType = {
65
CONSTANT: 'CONSTANT',
76
CUT: 'CUT'

0 commit comments

Comments
 (0)