Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chu <justinchu@microsoft.com>
  • Loading branch information
justinchuby committed Sep 21, 2023
1 parent 73bdff9 commit 5a84146
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 9 deletions.
3 changes: 2 additions & 1 deletion onnx/backend/test/case/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from typing_extensions import Literal

import onnx
from onnx.backend.test.case.test_case import TestCase
Expand Down Expand Up @@ -236,7 +237,7 @@ def _make_test_model_gen_version(graph: GraphProto, **kwargs: Any) -> ModelProto
# the latest opset vesion that supports before targeted opset version
def expect(
node_op: onnx.NodeProto,
inputs: Sequence[Union[np.ndarray, TensorProto]],
inputs: Sequence[Union[np.ndarray, TensorProto, Literal[""]]],
outputs: Sequence[Union[np.ndarray, TensorProto]],
name: str,
**kwargs: Any,
Expand Down
55 changes: 51 additions & 4 deletions onnx/backend/test/case/node/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,34 @@

class DFT(Base):
@staticmethod
def export() -> None:
def export_opset19() -> None:
node = onnx.helper.make_node("DFT", inputs=["x"], outputs=["y"], axis=1)
x = np.arange(0, 100).reshape(10, 10).astype(np.float32)
y = np.fft.fft(x, axis=0)

x = x.reshape(1, 10, 10, 1)
y = np.stack((y.real, y.imag), axis=2).astype(np.float32).reshape(1, 10, 10, 2)
expect(node, inputs=[x], outputs=[y], name="test_dft")
expect(
node,
inputs=[x],
outputs=[y],
name="test_dft_opset19",
opset_imports=[onnx.helper.make_opsetid("", 19)],
)

node = onnx.helper.make_node("DFT", inputs=["x"], outputs=["y"], axis=2)
x = np.arange(0, 100).reshape(10, 10).astype(np.float32)
y = np.fft.fft(x, axis=1)

x = x.reshape(1, 10, 10, 1)
y = np.stack((y.real, y.imag), axis=2).astype(np.float32).reshape(1, 10, 10, 2)
expect(node, inputs=[x], outputs=[y], name="test_dft_axis")
expect(
node,
inputs=[x],
outputs=[y],
name="test_dft_axis_opset19",
opset_imports=[onnx.helper.make_opsetid("", 19)],
)

node = onnx.helper.make_node(
"DFT", inputs=["x"], outputs=["y"], inverse=1, axis=1
Expand All @@ -39,4 +51,39 @@ def export() -> None:

x = np.stack((x.real, x.imag), axis=2).astype(np.float32).reshape(1, 10, 10, 2)
y = np.stack((y.real, y.imag), axis=2).astype(np.float32).reshape(1, 10, 10, 2)
expect(node, inputs=[x], outputs=[y], name="test_dft_inverse")
expect(
node,
inputs=[x],
outputs=[y],
name="test_dft_inverse_opset19",
opset_imports=[onnx.helper.make_opsetid("", 19)],
)

@staticmethod
def export() -> None:
node = onnx.helper.make_node("DFT", inputs=["x", "", "axis"], outputs=["y"])
x = np.arange(0, 100).reshape(10, 10).astype(np.float32)
axis = np.array(1, dtype=np.int64)
y = np.fft.fft(x, axis=0)

x = x.reshape(1, 10, 10, 1)
y = np.stack((y.real, y.imag), axis=2).astype(np.float32).reshape(1, 10, 10, 2)
expect(node, inputs=[x, "", axis], outputs=[y], name="test_dft")

node = onnx.helper.make_node("DFT", inputs=["x"], outputs=["y"])
x = np.arange(0, 100).reshape(10, 10).astype(np.float32)
axis = np.array(2, dtype=np.int64)
y = np.fft.fft(x, axis=1)

x = x.reshape(1, 10, 10, 1)
y = np.stack((y.real, y.imag), axis=2).astype(np.float32).reshape(1, 10, 10, 2)
expect(node, inputs=[x, "", axis], outputs=[y], name="test_dft_axis")

node = onnx.helper.make_node("DFT", inputs=["x"], outputs=["y"], inverse=1)
x = np.arange(0, 100, dtype=np.complex64).reshape(10, 10)
axis = np.array(1, dtype=np.int64)
y = np.fft.ifft(x, axis=0)

x = np.stack((x.real, x.imag), axis=2).astype(np.float32).reshape(1, 10, 10, 2)
y = np.stack((y.real, y.imag), axis=2).astype(np.float32).reshape(1, 10, 10, 2)
expect(node, inputs=[x, "", axis], outputs=[y], name="test_dft_inverse")
2 changes: 1 addition & 1 deletion onnx/defs/math/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2966,7 +2966,7 @@ ONNX_OPERATOR_SET_SCHEMA(
.Attr(
"onesided",
"If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because "
"the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m,w]=X[m,n_fft-w]*. "
"the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] =X[m, n_fft-w]*. "
"Note if the input or window tensors are complex, then onesided output is not possible. "
"Enabling onesided with real inputs performs a Real-valued fast Fourier transform (RFFT). "
"When invoked with real or complex valued input, the default value is 0. "
Expand Down
6 changes: 3 additions & 3 deletions onnx/test/shape_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8552,7 +8552,7 @@ def test_dft_reals(
if version < 20:
if axis is not None:
attributes["axis"] = axis
nodes = [make_node("DFT", ["input", ""], ["output"], **attributes)]
nodes = [make_node("DFT", ["input", ""], ["output"], **attributes)] # type: ignore[arg-type]
value_infos = []
else:
assert version >= 20
Expand All @@ -8564,12 +8564,12 @@ def test_dft_reals(
["axis"],
value=make_tensor("axis", TensorProto.INT64, (), (axis,)),
),
make_node("DFT", ["input", "", "axis"], ["output"], **attributes),
make_node("DFT", ["input", "", "axis"], ["output"], **attributes), # type: ignore[arg-type]
]
value_infos = [make_tensor_value_info("axis", TensorProto.INT64, ())]
else:
nodes = [
make_node("DFT", ["input", "", ""], ["output"], **attributes),
make_node("DFT", ["input", "", ""], ["output"], **attributes), # type: ignore[arg-type]
]
value_infos = []

Expand Down

0 comments on commit 5a84146

Please sign in to comment.