Skip to content

Commit

Permalink
Refactor ref implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chu <justinchu@microsoft.com>
  • Loading branch information
justinchuby committed Aug 22, 2023
1 parent b712886 commit 69d0f4e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 67 deletions.
2 changes: 1 addition & 1 deletion onnx/defs/math/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2967,7 +2967,7 @@ ONNX_OPERATOR_SET_SCHEMA(
.Input(
2,
"dft_length",
"The length of the signal. "
"The length of the signal as a scalar. "
"If greater than the axis dimension, the signal will be zero-padded up to dft_length. "
"If less than the axis dimension, only the first dft_length values will be used as the signal. "
"It's an optional value. ",
Expand Down
2 changes: 1 addition & 1 deletion onnx/defs/math/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2999,7 +2999,7 @@ ONNX_OPERATOR_SET_SCHEMA(
.Input(
1,
"dft_length",
"The length of the signal. "
"The length of the signal as a scalar. "
"If greater than the axis dimension, the signal will be zero-padded up to dft_length. "
"If less than the axis dimension, only the first dft_length values will be used as the signal. "
"It's an optional value. ",
Expand Down
105 changes: 40 additions & 65 deletions onnx/reference/ops/op_dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,126 +2,101 @@

# SPDX-License-Identifier: Apache-2.0

from typing import Sequence

import numpy as np

from onnx.reference.op_run import OpRun


def _fft(x: np.ndarray, fft_length: Sequence[int], axis: int) -> np.ndarray:
assert fft_length is not None

transformed = np.fft.fft(x, fft_length[0], axis=axis)
def _fft(x: np.ndarray, fft_length: int, axis: int) -> np.ndarray:
"""Compute the FFT return the real representation of the complex result."""
transformed = np.fft.fft(x, n=fft_length, axis=axis)
real_frequencies = np.real(transformed)
imaginary_frequencies = np.imag(transformed)
# TODO(justinchuby): Just concat on the last axis and remove transpose
merged = np.vstack(
[real_frequencies[np.newaxis, ...], imaginary_frequencies[np.newaxis, ...]]
return np.concatenate(
(real_frequencies[..., np.newaxis], imaginary_frequencies[..., np.newaxis]),
axis=-1,
)
perm = np.arange(len(merged.shape))
perm[:-1] = perm[1:]
perm[-1] = 0
transposed = np.transpose(merged, list(perm))
if transposed.shape[-1] != 2:
raise RuntimeError(
f"Unexpected shape {transposed.shape}, x.shape={x.shape} "
f"fft_length={fft_length}."
)
return transposed


def _cfft(
x: np.ndarray,
fft_length: Sequence[int],
fft_length: int,
axis: int,
onesided: bool,
normalize: bool,
) -> np.ndarray:
if x.shape[-1] == 1:
tmp = x
# The input contains only the real part
signal = x
else:
# The input is a real representation of a complex signal
slices = [slice(0, x) for x in x.shape]
slices[-1] = slice(0, x.shape[-1], 2)
real = x[tuple(slices)]
slices[-1] = slice(1, x.shape[-1], 2)
imag = x[tuple(slices)]
tmp = real + 1j * imag
complex_signals = np.squeeze(tmp, -1)
signal = real + 1j * imag

complex_signals = np.squeeze(signal, -1)
result = _fft(complex_signals, fft_length, axis=axis)
# Post process the result based on arguments
if onesided:
slices = [slice(0, a) for a in result.shape]
slices[axis] = slice(0, result.shape[axis] // 2 + 1)
result = result[tuple(slices)] # type: ignore
result = result[tuple(slices)]
if normalize:
if len(fft_length) == 1:
result /= fft_length[0]
else:
raise NotImplementedError(
f"normalize=True not implemented for fft_length={fft_length}."
)
result /= fft_length
return result


def _ifft(
x: np.ndarray, fft_length: Sequence[int], axis: int, onesided: bool
) -> np.ndarray:
signals = np.fft.ifft(x, fft_length[0], axis=axis)
def _ifft(x: np.ndarray, fft_length: int, axis: int, onesided: bool) -> np.ndarray:
signals = np.fft.ifft(x, fft_length, axis=axis)
real_signals = np.real(signals)
imaginary_signals = np.imag(signals)
# TODO(justinchuby): Just concat on the last axis and remove transpose
merged = np.vstack(
[real_signals[np.newaxis, ...], imaginary_signals[np.newaxis, ...]]
merged = np.concatenate(
(real_signals[..., np.newaxis], imaginary_signals[..., np.newaxis]),
axis=-1,
)
perm = np.arange(len(merged.shape))
perm[:-1] = perm[1:]
perm[-1] = 0
transposed = np.transpose(merged, list(perm))
if transposed.shape[-1] != 2:
raise RuntimeError(
f"Unexpected shape {transposed.shape}, x.shape={x.shape} "
f"fft_length={fft_length}."
)
if onesided:
slices = [slice(a) for a in transposed.shape]
slices[axis] = slice(0, transposed.shape[axis] // 2 + 1)
return transposed[tuple(slices)]
return transposed
slices = [slice(a) for a in merged.shape]
slices[axis] = slice(0, merged.shape[axis] // 2 + 1)
return merged[tuple(slices)]
return merged


def _cifft(
x: np.ndarray, fft_length: Sequence[int], axis: int = -1, onesided: bool = False
x: np.ndarray, fft_length: int, axis: int = -1, onesided: bool = False
) -> np.ndarray:
if x.shape[-1] == 1:
tmp = x
frequencies = x
else:
slices = [slice(0, x) for x in x.shape]
slices[-1] = slice(0, x.shape[-1], 2)
real = x[tuple(slices)]
slices[-1] = slice(1, x.shape[-1], 2)
imag = x[tuple(slices)]
tmp = real + 1j * imag
complex_signals = np.squeeze(tmp, -1)
return _ifft(complex_signals, fft_length, axis=axis, onesided=onesided)
frequencies = real + 1j * imag
complex_frequencies = np.squeeze(frequencies, -1)
return _ifft(complex_frequencies, fft_length, axis=axis, onesided=onesided)


class DFT_17(OpRun):
def _run(self, x, dft_length: Sequence[int] | None = None, axis=1, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]: # type: ignore
def _run(self, x: np.ndarray, dft_length: int | None = None, axis=1, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]: # type: ignore
if dft_length is None:
dft_length = (x.shape[axis],)
dft_length = x.shape[axis]
if inverse: # type: ignore
res = _cifft(x, dft_length, axis=axis, onesided=onesided)
result = _cifft(x, dft_length, axis=axis, onesided=onesided)
else:
res = _cfft(x, dft_length, axis=axis, onesided=onesided, normalize=False)
return (res.astype(x.dtype),)
result = _cfft(x, dft_length, axis=axis, onesided=onesided, normalize=False)
return (result.astype(x.dtype),)


class DFT_20(OpRun):
def _run(self, x, axis: int = -1, dft_length: Sequence[int] | None = None, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]: # type: ignore
def _run(self, x: np.ndarray, axis: int = -1, dft_length: int | None = None, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]: # type: ignore
if dft_length is None:
dft_length = (x.shape[axis],)
dft_length = x.shape[axis]
if inverse: # type: ignore
res = _cifft(x, dft_length, axis=axis, onesided=onesided)
result = _cifft(x, dft_length, axis=axis, onesided=onesided)
else:
res = _cfft(x, dft_length, axis=axis, onesided=onesided, normalize=False)
return (res.astype(x.dtype),)
result = _cfft(x, dft_length, axis=axis, onesided=onesided, normalize=False)
return (result.astype(x.dtype),)

0 comments on commit 69d0f4e

Please sign in to comment.