Skip to content

Commit

Permalink
Reference
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chu <justinchu@microsoft.com>
  • Loading branch information
justinchuby committed Aug 21, 2023
1 parent b1ac96a commit 632e9e1
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 53 deletions.
13 changes: 6 additions & 7 deletions onnx/defs/math/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2942,17 +2942,18 @@ static T get_scalar_value_from_tensor(const ONNX_NAMESPACE::TensorProto* t) {
}
}

static const char* DFT_ver20_doc = R"DOC(Computes the discrete Fourier transform of input.)DOC";
static const char* DFT_ver20_doc = R"DOC(Computes the discrete Fourier transform of the input.)DOC";

ONNX_OPERATOR_SET_SCHEMA(
DFT,
20,
OpSchema()
.SetDoc(DFT_ver20_doc)
.Attr(
// TODO(justinchuby): Double check how conjugate symmetry is specified
"onesided",
"If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because "
"the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m,w]=X[m,n_fft-w]*. "
"the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m, w] = X[m,n_fft-w]*. "
"Note if the input or window tensors are complex, then onesided output is not possible. "
"Enabling onesided with real inputs performs a Real-valued fast Fourier transform (RFFT). "
"When invoked with real or complex valued input, the default value is 0. "
Expand All @@ -2969,8 +2970,6 @@ ONNX_OPERATOR_SET_SCHEMA(
"input",
"For real input, the following shape is expected: [signal_dim0][signal_dim1][signal_dim2]...[signal_dimN][1]. "
"For complex input, the following shape is expected: [signal_dim0][signal_dim1][signal_dim2]...[signal_dimN][2]. "
"The first dimension is the batch dimension. "
"The following N dimensions correspond to the signal's dimensions. "
"The final dimension represents the real and imaginary parts of the value in that order.",
"T1",
OpSchema::Single,
Expand All @@ -2980,7 +2979,7 @@ ONNX_OPERATOR_SET_SCHEMA(
.Input(
1,
"axis",
"The axis on which to perform the DFT. By default this value is set to 0.",
"The axis on which to perform the DFT. By default this value is set to `-1`.",
"tensor(int64)",
OpSchema::Optional,
true,
Expand All @@ -2989,7 +2988,7 @@ ONNX_OPERATOR_SET_SCHEMA(
.Input(
2,
"dft_length",
"The length of the signal."
"The length of the signal. "
"If greater than the axis dimension, the signal will be zero-padded up to dft_length. "
"If less than the axis dimension, only the first dft_length values will be used as the signal. "
"It's an optional value. ",
Expand All @@ -3001,7 +3000,7 @@ ONNX_OPERATOR_SET_SCHEMA(
.Output(
0,
"output",
"The Fourier Transform of the input vector."
"The Fourier Transform of the input vector. "
"If onesided is 0, the following shape is expected: [signal_dim0][signal_dim1][signal_dim2]...[signal_dimN][2]. "
"If axis=0 and onesided is 1, the following shape is expected: [floor(signal_dim0/2)+1][signal_dim1][signal_dim2]...[signal_dimN][2]. "
"If axis=1 and onesided is 1, the following shape is expected: [signal_dim0][floor(signal_dim1/2)+1][signal_dim2]...[signal_dimN][2]. "
Expand Down
4 changes: 2 additions & 2 deletions onnx/defs/math/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2999,7 +2999,7 @@ ONNX_OPERATOR_SET_SCHEMA(
.Input(
1,
"dft_length",
"The length of the signal."
"The length of the signal. "
"If greater than the axis dimension, the signal will be zero-padded up to dft_length. "
"If less than the axis dimension, only the first dft_length values will be used as the signal. "
"It's an optional value. ",
Expand All @@ -3011,7 +3011,7 @@ ONNX_OPERATOR_SET_SCHEMA(
.Output(
0,
"output",
"The Fourier Transform of the input vector."
"The Fourier Transform of the input vector. "
"If onesided is 0, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. "
"If axis=1 and onesided is 1, the following shape is expected: [batch_idx][floor(signal_dim1/2)+1][signal_dim2]...[signal_dimN][2]. "
"If axis=2 and onesided is 1, the following shape is expected: [batch_idx][signal_dim1][floor(signal_dim2/2)+1]...[signal_dimN][2]. "
Expand Down
98 changes: 54 additions & 44 deletions onnx/reference/ops/op_dft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
# pylint: disable=R0913,W0221

from typing import Sequence

Expand All @@ -11,36 +10,33 @@


def _fft(x: np.ndarray, fft_length: Sequence[int], axis: int) -> np.ndarray:
if fft_length is None:
fft_length = [x.shape[axis]]
try:
ft = np.fft.fft(x, fft_length[0], axis=axis)
except TypeError:
# numpy 1.16.6, an array cannot be a key in the dictionary
# fixed in numpy 1.21.5.
ft = np.fft.fft(x, int(fft_length[0]), axis=axis)

r = np.real(ft)
i = np.imag(ft)
merged = np.vstack([r[np.newaxis, ...], i[np.newaxis, ...]])
assert fft_length is not None

transformed = np.fft.fft(x, fft_length[0], axis=axis)
real_frequencies = np.real(transformed)
imaginary_frequencies = np.imag(transformed)
# TODO(justinchuby): Just concat on the last axis and remove transpose
merged = np.vstack(
[real_frequencies[np.newaxis, ...], imaginary_frequencies[np.newaxis, ...]]
)
perm = np.arange(len(merged.shape))
perm[:-1] = perm[1:]
perm[-1] = 0
tr = np.transpose(merged, list(perm))
if tr.shape[-1] != 2:
transposed = np.transpose(merged, list(perm))
if transposed.shape[-1] != 2:
raise RuntimeError(
f"Unexpected shape {tr.shape}, x.shape={x.shape} "
f"Unexpected shape {transposed.shape}, x.shape={x.shape} "
f"fft_length={fft_length}."
)
return tr
return transposed


def _cfft(
x: np.ndarray,
fft_length: Sequence[int],
axis: int,
onesided: bool = False,
normalize: bool = False,
onesided: bool,
normalize: bool,
) -> np.ndarray:
if x.shape[-1] == 1:
tmp = x
Expand All @@ -51,43 +47,46 @@ def _cfft(
slices[-1] = slice(1, x.shape[-1], 2)
imag = x[tuple(slices)]
tmp = real + 1j * imag
c = np.squeeze(tmp, -1)
res = _fft(c, fft_length, axis=axis)
complex_signals = np.squeeze(tmp, -1)
result = _fft(complex_signals, fft_length, axis=axis)
if onesided:
slices = [slice(0, a) for a in res.shape]
slices[axis] = slice(0, res.shape[axis] // 2 + 1)
res = res[tuple(slices)] # type: ignore
slices = [slice(0, a) for a in result.shape]
slices[axis] = slice(0, result.shape[axis] // 2 + 1)
result = result[tuple(slices)] # type: ignore
if normalize:
if len(fft_length) == 1:
res /= fft_length[0]
result /= fft_length[0]
else:
raise NotImplementedError(
f"normalize=True not implemented for fft_length={fft_length}."
)
return res
return result


def _ifft(
x: np.ndarray, fft_length: Sequence[int], axis: int = -1, onesided: bool = False
x: np.ndarray, fft_length: Sequence[int], axis: int, onesided: bool
) -> np.ndarray:
ft = np.fft.ifft(x, fft_length[0], axis=axis)
r = np.real(ft)
i = np.imag(ft)
merged = np.vstack([r[np.newaxis, ...], i[np.newaxis, ...]])
signals = np.fft.ifft(x, fft_length[0], axis=axis)
real_signals = np.real(signals)
imaginary_signals = np.imag(signals)
# TODO(justinchuby): Just concat on the last axis and remove transpose
merged = np.vstack(
[real_signals[np.newaxis, ...], imaginary_signals[np.newaxis, ...]]
)
perm = np.arange(len(merged.shape))
perm[:-1] = perm[1:]
perm[-1] = 0
tr = np.transpose(merged, list(perm))
if tr.shape[-1] != 2:
transposed = np.transpose(merged, list(perm))
if transposed.shape[-1] != 2:
raise RuntimeError(
f"Unexpected shape {tr.shape}, x.shape={x.shape} "
f"Unexpected shape {transposed.shape}, x.shape={x.shape} "
f"fft_length={fft_length}."
)
if onesided:
slices = [slice(a) for a in tr.shape]
slices[axis] = slice(0, tr.shape[axis] // 2 + 1)
return tr[tuple(slices)] # type: ignore
return tr
slices = [slice(a) for a in transposed.shape]
slices[axis] = slice(0, transposed.shape[axis] // 2 + 1)
return transposed[tuple(slices)]
return transposed


def _cifft(
Expand All @@ -102,16 +101,27 @@ def _cifft(
slices[-1] = slice(1, x.shape[-1], 2)
imag = x[tuple(slices)]
tmp = real + 1j * imag
c = np.squeeze(tmp, -1)
return _ifft(c, fft_length, axis=axis, onesided=onesided)
complex_signals = np.squeeze(tmp, -1)
return _ifft(complex_signals, fft_length, axis=axis, onesided=onesided)


class DFT_17(OpRun):
def _run(self, x, dft_length: Sequence[int] | None = None, axis=1, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]: # type: ignore
if dft_length is None:
dft_length = (x.shape[axis],)
if inverse: # type: ignore
res = _cifft(x, dft_length, axis=axis, onesided=onesided)
else:
res = _cfft(x, dft_length, axis=axis, onesided=onesided, normalize=False)
return (res.astype(x.dtype),)


class DFT(OpRun):
def _run(self, x, dft_length=None, axis=None, inverse=None, onesided=None): # type: ignore
class DFT_20(OpRun):
def _run(self, x, axis: int = -1, dft_length: Sequence[int] | None = None, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]: # type: ignore
if dft_length is None:
dft_length = np.array([x.shape[axis]], dtype=np.int64)
dft_length = (x.shape[axis],)
if inverse: # type: ignore
res = _cifft(x, dft_length, axis=axis, onesided=onesided)
else:
res = _cfft(x, dft_length, axis=axis, onesided=onesided)
res = _cfft(x, dft_length, axis=axis, onesided=onesided, normalize=False)
return (res.astype(x.dtype),)

0 comments on commit 632e9e1

Please sign in to comment.