-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
fft.py
84 lines (65 loc) · 2.48 KB
/
fft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from chainer import backend
from chainer import function_node
from chainer.utils import type_check
class FFT(function_node.FunctionNode):

    """Fourier transform of a complex signal given as a (real, imag) pair.

    The node looks up ``self._method`` on the ``fft`` namespace of the
    inputs' array module (this file uses ``'fft'`` and ``'ifft'``) and calls
    it with no extra arguments. Complex values only exist inside
    :meth:`forward`; the graph sees two real arrays on both sides.

    Args:
        method (str): Name of the transform function on ``xp.fft``.

    """

    def __init__(self, method):
        self._method = method

    def check_type_forward(self, in_types):
        # Expect inputs named ``real`` and ``imag`` that share one floating
        # dtype and one shape, with at least one axis to transform.
        type_check._argname(in_types, ('real', 'imag'))
        real_type, imag_type = in_types
        type_check.expect(
            real_type.dtype.kind == 'f',
            real_type.ndim > 0,
            real_type.shape == imag_type.shape,
            real_type.dtype == imag_type.dtype,
        )

    def forward(self, inputs):
        xp = backend.get_array_module(*inputs)
        re_in, im_in = inputs
        # Combine into one complex array, run the requested transform, then
        # split the spectrum back into real arrays of the original dtypes.
        signal = re_in + im_in * 1j
        transform = getattr(xp.fft, self._method)
        spectrum = transform(signal)
        return (
            spectrum.real.astype(re_in.dtype, copy=False),
            spectrum.imag.astype(im_in.dtype, copy=False),
        )

    def backward(self, inputs, grads):
        grad_re, grad_im = grads
        xp = backend.get_array_module(*grads)
        # An absent upstream gradient is replaced by zeros shaped like the
        # one that is present.
        if grad_re is None:
            grad_re = xp.zeros_like(grad_im.data)
        if grad_im is None:
            grad_im = xp.zeros_like(grad_re.data)
        # With g = grad_re + i*grad_im, the input gradient is F^H g, which
        # equals conj(F conj(g)) because the DFT matrix is symmetric. Since
        # i*conj(g) = grad_im + i*grad_re, running the same transform on the
        # swapped pair and swapping its two outputs back yields F^H g.
        out_first, out_second = FFT(self._method).apply((grad_im, grad_re))
        return out_second, out_first
def fft(x):
    """Compute the forward discrete Fourier transform of a complex signal.

    The complex signal is supplied as two separate real-valued arrays and
    transformed with the array module's ``fft.fft`` using its default
    arguments.

    Args:
        x (tuple): Pair ``(real, imag)``. ``real`` holds the real part and
            ``imag`` the imaginary part of the signal; each may be a
            :class:`~chainer.Variable` or an :ref:`ndarray`. Both must have
            the same floating-point dtype and shape, with at least one
            dimension.

    Returns:
        tuple: Pair ``(ry, iy)`` holding the real part ``ry`` and the
        imaginary part ``iy`` of the transformed signal, each cast back to
        the input dtype.

    .. note::
       Only the ``(real, imag)`` tuple form is accepted for now; passing
       complex numbers directly is planned for the future.

    """
    real_part, imag_part = x
    node = FFT('fft')
    return node.apply((real_part, imag_part))
def ifft(x):
    """Compute the inverse discrete Fourier transform of a complex signal.

    The complex signal is supplied as two separate real-valued arrays and
    transformed with the array module's ``fft.ifft`` using its default
    arguments.

    Args:
        x (tuple): Pair ``(real, imag)``. ``real`` holds the real part and
            ``imag`` the imaginary part of the signal; each may be a
            :class:`~chainer.Variable` or an :ref:`ndarray`. Both must have
            the same floating-point dtype and shape, with at least one
            dimension.

    Returns:
        tuple: Pair ``(ry, iy)`` holding the real part ``ry`` and the
        imaginary part ``iy`` of the transformed signal, each cast back to
        the input dtype.

    .. note::
       Only the ``(real, imag)`` tuple form is accepted for now; passing
       complex numbers directly is planned for the future.

    """
    real_part, imag_part = x
    node = FFT('ifft')
    return node.apply((real_part, imag_part))