Skip to content

Commit

Permalink
Rotate axis order
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chu <justinchu@microsoft.com>
  • Loading branch information
justinchuby committed Sep 19, 2023
1 parent 373c478 commit f35f162
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions onnx/defs/math/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3036,22 +3036,22 @@ ONNX_OPERATOR_SET_SCHEMA(
OpSchema::NonDifferentiable)
.Input(
1,
"axis",
"The axis as a scalar on which to perform the DFT. Default is `-1` (last axis). "
"Negative value means counting dimensions from the back. Accepted range is `[-r, r-1]` where `r = rank(input)-1`. "
"The last dimension is for representing complex numbers and is thus not indexed.",
"tensor(int64)",
"dft_length",
"The length of the signal as a scalar. "
"If greater than the axis dimension, the signal will be zero-padded up to `dft_length`. "
"If less than the axis dimension, only the first `dft_length` values will be used as the signal. ",
"T2",
OpSchema::Optional,
true,
1,
OpSchema::NonDifferentiable)
.Input(
2,
"dft_length",
"The length of the signal as a scalar. "
"If greater than the axis dimension, the signal will be zero-padded up to `dft_length`. "
"If less than the axis dimension, only the first `dft_length` values will be used as the signal. ",
"T2",
"axis",
"The axis as a scalar on which to perform the DFT. Default is `-1` (last axis). "
"Negative value means counting dimensions from the back. Accepted range is `[-r, r-1]` where `r = rank(input) - 1`. "
"The last dimension is for representing complex numbers and is thus not indexed.",
"tensor(int64)",
OpSchema::Optional,
true,
1,
Expand All @@ -3073,8 +3073,8 @@ ONNX_OPERATOR_SET_SCHEMA(
bool inverse = static_cast<bool>(getAttribute(ctx, "inverse", 0));

const size_t input_arg_index = 0;
const size_t axis_arg_index = 1;
const size_t dft_length_arg_index = 2;
const size_t dft_length_arg_index = 1;
const size_t axis_arg_index = 2;
const size_t output_index = 0;

if (inverse && is_onesided) {
Expand Down Expand Up @@ -3144,7 +3144,6 @@ ONNX_OPERATOR_SET_SCHEMA(
if (axis_dimension.has_dim_value()) {
auto original_signal_size = axis_dimension.dim_value();
auto half_signal_size = (original_signal_size >> 1) + 1;
printf("original_signal_size: %lld, half_signal_size: %lld\n", original_signal_size, half_signal_size);
result_shape_proto.mutable_dim(axis_idx)->set_dim_value(half_signal_size);
} else {
// Clear the value and param (which would otherwie be inherited from the input).
Expand Down

0 comments on commit f35f162

Please sign in to comment.