@@ -3,53 +3,66 @@ import nextPOT from 'next-power-of-two';
3
3
4
4
import { checkKernel } from './utils' ;
5
5
6
- export default function fftConvolution ( input , kernel , borderType = 'CONSTANT' ) {
7
- checkKernel ( kernel ) ;
8
- switch ( borderType ) {
9
- case 'CONSTANT' : {
10
- return fftConvolutionImpl ( input , kernel , false ) ;
11
- }
12
- case 'CUT' : {
13
- return fftConvolutionImpl ( input , kernel , true ) ;
14
- }
15
- default : {
16
- throw new Error ( `unexpected border type: ${ borderType } ` ) ;
6
+ export class FFTConvolution {
7
+ constructor ( size , kernel ) {
8
+ if ( ! Number . isInteger ( size ) || size < 1 ) {
9
+ throw new TypeError ( 'size must be a positive integer' ) ;
17
10
}
11
+ checkKernel ( kernel ) ;
12
+ this . kernelOffset = ( kernel . length - 1 ) / 2 ;
13
+ this . doubleOffset = 2 * this . kernelOffset ;
14
+ const resultLength = size + this . doubleOffset ;
15
+ this . fftLength = nextPOT ( Math . max ( resultLength , 2 ) ) ;
16
+ this . fft = new FFT ( this . fftLength ) ;
17
+ kernel = kernel . slice ( ) . reverse ( ) ;
18
+ const { output : fftKernel , input : result } = createPaddedFFt (
19
+ kernel ,
20
+ this . fft ,
21
+ this . fftLength
22
+ ) ;
23
+ this . fftKernel = fftKernel ;
24
+ this . ifftOutput = this . fft . createComplexArray ( ) ;
25
+ this . result = result ;
18
26
}
19
- }
20
27
21
- function fftConvolutionImpl ( input , kernel , cutBorder ) {
22
- const kernelOffset = ( kernel . length - 1 ) / 2 ;
23
- const doubleOffset = 2 * kernelOffset ;
24
- const resultLength = input . length + doubleOffset ;
25
- const fftLength = nextPOT ( resultLength ) ;
28
+ convolute ( input , borderType = 'CONSTANT' ) {
29
+ // if (input.length) // TODO CHECK SIZE
30
+ const { output : fftInput } = createPaddedFFt (
31
+ input ,
32
+ this . fft ,
33
+ this . fftLength
34
+ ) ;
26
35
27
- const fft = new FFT ( fftLength ) ;
28
-
29
- kernel = kernel . slice ( ) . reverse ( ) ;
30
- const { output : fftKernel , input : result } = createPaddedFFt (
31
- kernel ,
32
- fft ,
33
- fftLength
34
- ) ;
35
- const { output : fftInput } = createPaddedFFt ( input , fft , fftLength ) ;
36
+ for ( var i = 0 ; i < fftInput . length ; i += 2 ) {
37
+ const tmp =
38
+ fftInput [ i ] * this . fftKernel [ i ] -
39
+ fftInput [ i + 1 ] * this . fftKernel [ i + 1 ] ;
40
+ fftInput [ i + 1 ] =
41
+ fftInput [ i ] * this . fftKernel [ i + 1 ] +
42
+ fftInput [ i + 1 ] * this . fftKernel [ i ] ;
43
+ fftInput [ i ] = tmp ;
44
+ }
36
45
37
- for ( var i = 0 ; i < fftInput . length ; i += 2 ) {
38
- const tmp = fftInput [ i ] * fftKernel [ i ] - fftInput [ i + 1 ] * fftKernel [ i + 1 ] ;
39
- fftInput [ i + 1 ] =
40
- fftInput [ i ] * fftKernel [ i + 1 ] + fftInput [ i + 1 ] * fftKernel [ i ] ;
41
- fftInput [ i ] = tmp ;
42
- }
43
- const inverse = fftKernel ;
44
- fft . inverseTransform ( inverse , fftInput ) ;
45
- const r = fft . fromComplexArray ( inverse , result ) ;
46
- if ( cutBorder ) {
47
- return r . slice ( doubleOffset , input . length ) ;
48
- } else {
49
- return r . slice ( kernelOffset , kernelOffset + input . length ) ;
46
+ this . fft . inverseTransform ( this . ifftOutput , fftInput ) ;
47
+ const r = this . fft . fromComplexArray ( this . ifftOutput , this . result ) ;
48
+ switch ( borderType ) {
49
+ case 'CONSTANT' : {
50
+ return r . slice ( this . kernelOffset , this . kernelOffset + input . length ) ;
51
+ }
52
+ case 'CUT' : {
53
+ return r . slice ( this . doubleOffset , input . length ) ;
54
+ }
55
+ default : {
56
+ throw new Error ( `unexpected border type: ${ borderType } ` ) ;
57
+ }
58
+ }
50
59
}
51
60
}
52
61
62
+ export function fftConvolution ( input , kernel , borderType = 'CONSTANT' ) {
63
+ return new FFTConvolution ( input . length , kernel ) . convolute ( input , borderType ) ;
64
+ }
65
+
53
66
function createPaddedFFt ( data , fft , length ) {
54
67
const input = [ ] ;
55
68
let i = 0 ;
0 commit comments