Skip to content

Commit

Permalink
Change default axis to -2
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chu <justinchu@microsoft.com>
  • Loading branch information
justinchuby committed Sep 21, 2023
1 parent e92256b commit 680059f
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 20 deletions.
20 changes: 10 additions & 10 deletions onnx/defs/math/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3048,9 +3048,9 @@ ONNX_OPERATOR_SET_SCHEMA(
.Input(
2,
"axis",
"The axis as a scalar on which to perform the DFT. Default is `-1` (last axis). "
"Negative value means counting dimensions from the back. Accepted range is `[-r, r-1]` where `r = rank(input) - 1`. "
"The last dimension is for representing complex numbers and is thus not indexed.",
"The axis as a scalar on which to perform the DFT. Default is `-2` (last signal axis). "
"Negative value means counting dimensions from the back. Accepted range is $[-r, -2] \\cup [0, r-1]$ where `r = rank(input)`. "
"The last dimension is for representing complex numbers and is thus an invalid axis.",
"tensor(int64)",
OpSchema::Optional,
true,
Expand Down Expand Up @@ -3096,8 +3096,8 @@ ONNX_OPERATOR_SET_SCHEMA(
const TensorProto* axis_tensor = ctx.getInputData(axis_arg_index);
int64_t axis;
if (axis_tensor == nullptr) {
// axis is -1 by default
axis = -1;
// axis is -2 by default
axis = -2;
} else {
// TODO(justinchuby): Create invariance checking functions to ensure shapes and sizes
// to abstrct the following logic out.
Expand All @@ -3107,12 +3107,12 @@ ONNX_OPERATOR_SET_SCHEMA(
axis = defs::math::utils::GetScalarValueFromTensor<int64_t>(axis_tensor);
}
// The last dimension is the real and imaginary parts of the value.
const int64_t rank = input_shape.dim_size() - 1;
if (rank < 1) {
fail_shape_inference("input tensor must have rank >= 1, excluding the complex dimension.");
const int64_t rank = input_shape.dim_size();
if (rank < 2) {
fail_shape_inference("input tensor must have rank >= 2, including the complex dimension.");
}
if (!(-rank <= axis && axis < rank)) {
fail_shape_inference("axis attribute value ", axis, " is invalid for a tensor of rank ", rank);
if (!(-rank <= axis && axis != -1 && axis < rank - 1)) {
fail_shape_inference("axis attribute value ", axis, " is invalid for a tensor of rank ", rank, ". Valid values are '-rank <= axis && axis != -1 && axis < rank - 1'");
}

auto axis_idx = (axis >= 0 ? axis : axis + rank);
Expand Down
14 changes: 8 additions & 6 deletions onnx/defs/math/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2975,7 +2975,9 @@ ONNX_OPERATOR_SET_SCHEMA(
static_cast<int64_t>(0))
.Attr(
"axis",
"The axis on which to perform the DFT. By default this value is set to 1, which corresponds to the first dimension after the batch index.",
"The axis on which to perform the DFT. By default this value is set to 1, which corresponds to the first dimension after the batch index."
"Negative value means counting dimensions from the back. Accepted range is $[-r, -2] \\cup [0, r-1]$ where `r = rank(input)`. "
"The last dimension is for representing complex numbers and is thus an invalid axis.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Attr(
Expand Down Expand Up @@ -3045,12 +3047,12 @@ ONNX_OPERATOR_SET_SCHEMA(
// Get the axis where the DFT will be performed.
auto axis = static_cast<int>(getAttribute(ctx, "axis", 1));
// The last dimension is the real and imaginary parts of the value.
const int64_t rank = input_shape.dim_size() - 1;
if (rank < 1) {
fail_shape_inference("input tensor must have rank >= 1, excluding the complex dimension.");
const int64_t rank = input_shape.dim_size();
if (rank < 2) {
fail_shape_inference("input tensor must have rank >= 2, including the complex dimension.");
}
if (!(-rank <= axis && axis < rank)) {
fail_shape_inference("axis attribute value ", axis, " is invalid for a tensor of rank ", rank);
if (!(-rank <= axis && axis != -1 && axis < rank)) {
fail_shape_inference("axis attribute value ", axis, " is invalid for a tensor of rank ", rank, ". Valid values are '-rank <= axis && axis != -1 && axis < rank - 1'");
}

auto axis_idx = (axis >= 0 ? axis : axis + rank);
Expand Down
8 changes: 6 additions & 2 deletions onnx/reference/ops/op_dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _ifft(x: np.ndarray, fft_length: int, axis: int, onesided: bool) -> np.ndarr


def _cifft(
x: np.ndarray, fft_length: int, axis: int = -1, onesided: bool = False
x: np.ndarray, fft_length: int, axis: int, onesided: bool = False
) -> np.ndarray:
if x.shape[-1] == 1:
frequencies = x
Expand All @@ -82,6 +82,8 @@ def _cifft(

class DFT_17(OpRun):
def _run(self, x: np.ndarray, dft_length: int | None = None, axis: int = 1, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]: # type: ignore
# Convert to positive axis
axis = axis % len(x.shape)
if dft_length is None:
dft_length = x.shape[axis]
if inverse: # type: ignore
Expand All @@ -92,7 +94,9 @@ def _run(self, x: np.ndarray, dft_length: int | None = None, axis: int = 1, inve


class DFT_20(OpRun):
def _run(self, x: np.ndarray, dft_length: int | None = None, axis: int = -1, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]: # type: ignore
def _run(self, x: np.ndarray, dft_length: int | None = None, axis: int = -2, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]: # type: ignore
# Convert to positive axis
axis = axis % len(x.shape)
if dft_length is None:
dft_length = x.shape[axis]
if inverse: # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions onnx/test/shape_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8518,11 +8518,11 @@ def test_blackmanwindow(self):
("reals_axis_0", (3, 5, 10, 1), 0, 0, 0, (3, 5, 10, 2)),
("reals_axis_1", (3, 5, 10, 1), 1, 0, 0, (3, 5, 10, 2)),
("reals_axis_2", (3, 5, 10, 1), 2, 0, 0, (3, 5, 10, 2)),
("reals_axis_neg", (3, 5, 10, 1), -1, 0, 0, (3, 5, 10, 2)),
("reals_axis_neg", (3, 5, 10, 1), -2, 0, 0, (3, 5, 10, 2)),
("reals_axis_0_onesided", (3, 5, 10, 1), 0, 1, 0, (2, 5, 10, 2)),
("reals_axis_1_onesided", (3, 5, 10, 1), 1, 1, 0, (3, 3, 10, 2)),
("reals_axis_2_onesided", (3, 5, 10, 1), 2, 1, 0, (3, 5, 6, 2)),
("reals_axis_neg_onesided", (3, 5, 10, 1), -1, 1, 0, (3, 5, 6, 2)),
("reals_axis_neg_onesided", (3, 5, 10, 1), -2, 1, 0, (3, 5, 6, 2)),
("complex_default_axis", (2, 5, 2), None, None, None, (2, 5, 2)),
("complex_onesided", (2, 5, 2), 1, 1, None, (2, 3, 2)),
("real_inverse", (2, 5, 1), 1, None, 1, (2, 5, 2)),
Expand Down

0 comments on commit 680059f

Please sign in to comment.