Skip to content

Commit

Permalink
Create DFT-20
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chu <justinchu@microsoft.com>
  • Loading branch information
justinchuby committed Aug 21, 2023
1 parent 9183bbb commit a9db6ad
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 29 deletions.
72 changes: 43 additions & 29 deletions onnx/defs/math/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2942,13 +2942,13 @@ static T get_scalar_value_from_tensor(const ONNX_NAMESPACE::TensorProto* t) {
}
}

static const char* DFT_ver17_doc = R"DOC(Computes the discrete Fourier transform of input.)DOC";
static const char* DFT_ver20_doc = R"DOC(Computes the discrete Fourier transform of input.)DOC";

ONNX_OPERATOR_SET_SCHEMA(
DFT,
17,
20,
OpSchema()
.SetDoc(DFT_ver17_doc)
.SetDoc(DFT_ver20_doc)
.Attr(
"onesided",
"If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because "
Expand All @@ -2959,11 +2959,6 @@ ONNX_OPERATOR_SET_SCHEMA(
"Values can be 0 or 1.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"axis",
"The axis on which to perform the DFT. By default this value is set to 1, which corresponds to the first dimension after the batch index.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Attr(
"inverse",
"Whether to perform the inverse discrete fourier transform. By default this value is set to 0, which corresponds to false.",
Expand All @@ -2972,8 +2967,8 @@ ONNX_OPERATOR_SET_SCHEMA(
.Input(
0,
"input",
"For real input, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][1]. "
"For complex input, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. "
"For real input, the following shape is expected: [signal_dim0][signal_dim1][signal_dim2]...[signal_dimN][1]. "
"For complex input, the following shape is expected: [signal_dim0][signal_dim1][signal_dim2]...[signal_dimN][2]. "
"The first dimension is the batch dimension. "
"The following N dimentions correspond to the signal's dimensions. "
"The final dimension represents the real and imaginary parts of the value in that order.",
Expand All @@ -2984,6 +2979,15 @@ ONNX_OPERATOR_SET_SCHEMA(
OpSchema::NonDifferentiable)
.Input(
1,
"axis",
"The axis on which to perform the DFT. By default this value is set to 0.",
"tensor(int64)",
OpSchema::Optional,
true,
1,
OpSchema::NonDifferentiable)
.Input(
2,
"dft_length",
"The length of the signal."
"If greater than the axis dimension, the signal will be zero-padded up to dft_length. "
Expand All @@ -2998,39 +3002,52 @@ ONNX_OPERATOR_SET_SCHEMA(
0,
"output",
"The Fourier Transform of the input vector."
"If onesided is 0, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. "
"If axis=1 and onesided is 1, the following shape is expected: [batch_idx][floor(signal_dim1/2)+1][signal_dim2]...[signal_dimN][2]. "
"If axis=2 and onesided is 1, the following shape is expected: [batch_idx][signal_dim1][floor(signal_dim2/2)+1]...[signal_dimN][2]. "
"If axis=N and onesided is 1, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[floor(signal_dimN/2)+1][2]. "
"If onesided is 0, the following shape is expected: [signal_dim0][signal_dim1][signal_dim2]...[signal_dimN][2]. "
"If axis=0 and onesided is 1, the following shape is expected: [floor(signal_dim0/2)+1][signal_dim1][signal_dim2]...[signal_dimN][2]. "
"If axis=1 and onesided is 1, the following shape is expected: [signal_dim0][floor(signal_dim1/2)+1][signal_dim2]...[signal_dimN][2]. "
"If axis=N and onesided is 1, the following shape is expected: [signal_dim0][signal_dim1][signal_dim2]...[floor(signal_dimN/2)+1][2]. "
"The signal_dim at the specified axis is equal to the dft_length.",
"T1")
.TypeConstraint(
"T1",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain scalar length types to int64_t.")
.TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain scalar length types to integers.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
bool is_onesided = static_cast<bool>(getAttribute(ctx, "onesided", 0));
bool inverse = static_cast<bool>(getAttribute(ctx, "inverse", 0));

const size_t input_arg_index = 0;
const size_t axis_arg_index = 1;
const size_t dft_length_arg_index = 2;
const size_t output_index = 0;

if (inverse && is_onesided) {
fail_shape_inference("is_onesided and inverse attributes cannot be enabled at the same time");
}

propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasInputShape(ctx, 0)) {
propagateElemTypeFromInputToOutput(ctx, input_arg_index, output_index);
if (!hasInputShape(ctx, input_arg_index)) {
// If no shape is available for the input, skip shape inference...
return;
}

// In general the output shape will match the input shape exactly
// So initialize the output shape with the input shape
auto& input_shape = getInputShape(ctx, 0);
auto& input_shape = getInputShape(ctx, input_arg_index);
ONNX_NAMESPACE::TensorShapeProto result_shape_proto = input_shape;

// Get the axis where the DFT will be performed.
auto axis = static_cast<int>(getAttribute(ctx, "axis", 1));
auto rank = input_shape.dim_size();
const TensorProto* axis_tensor = ctx.getInputData(axis_arg_index);
if (axis_tensor == nullptr) {
// Skip if axis is not known
return;
}
if (axis_tensor->dims_size() != 1) {
fail_shape_inference("axis input must be a scalar.");
}
const int64_t axis = get_scalar_value_from_tensor<int64_t>(axis_tensor);
const int64_t rank = input_shape.dim_size();

if (!(-rank <= axis && axis < rank)) {
fail_shape_inference("axis attribute value ", axis, " is invalid for a tensor of rank ", rank);
Expand All @@ -3040,24 +3057,21 @@ ONNX_OPERATOR_SET_SCHEMA(

// If dft_length is specified, then we should honor the shape.
// Set the output dimension to match the dft_length on the axis.
// If onesided this will be adjusted later on...
const TensorProto* dft_length = nullptr;
if (ctx.hasInput(1)) {
dft_length = ctx.getInputData(1);
// If onesided this will be adjusted in the next block
if (ctx.hasInput(dft_length_arg_index)) {
// dft_length is provided
const TensorProto* dft_length = ctx.getInputData(dft_length_arg_index);
if (dft_length == nullptr) {
// If we cannot read the dft_length, we cannot infer shape
// return...
return;
}
}

if (nullptr != dft_length) {
if (dft_length->dims_size() != 0) {
fail_shape_inference("dft_length input must be a scalar.");
}
auto dft_length_value = get_scalar_value_from_tensor<int64_t>(dft_length);
result_shape_proto.mutable_dim(axis_idx)->set_dim_value(dft_length_value);
}

// When DFT is onesided, the output shape is half the size of the input shape
// along the specified axis.
if (is_onesided) {
Expand All @@ -3080,7 +3094,7 @@ ONNX_OPERATOR_SET_SCHEMA(
auto dim_size = static_cast<int64_t>(result_shape_proto.dim_size());
result_shape_proto.mutable_dim(static_cast<int>(dim_size - 1))->set_dim_value(2);

updateOutputShape(ctx, 0, result_shape_proto);
updateOutputShape(ctx, output_index, result_shape_proto);
}));

std::function<void(OpSchema&)> CosineSumWindowOpDocGenerator(const char* name) {
Expand Down
142 changes: 142 additions & 0 deletions onnx/defs/math/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2955,4 +2955,146 @@ ONNX_OPERATOR_SET_SCHEMA(
"tensor(int64)"},
"Constrain input and output types to float/int tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));

// One-line operator summary passed to SetDoc() by the DFT schema registered at version 17.
static const char* DFT_ver17_doc = R"DOC(Computes the discrete Fourier transform of input.)DOC";

// Schema for the DFT operator as of operator-set version 17.
//
// Declares three integer attributes (onesided, axis, inverse), one required
// input ("input", type T1), one optional input ("dft_length", type T2) and a
// single output ("output", type T1), followed by a type-and-shape inference
// callback.
//
// The inference callback:
//   * rejects the combination onesided != 0 together with inverse != 0;
//   * forwards the element type of input 0 to output 0;
//   * when the input shape is known, starts from a copy of it, overwrites the
//     `axis` dimension with dft_length (if that input is present and readable),
//     replaces it with (value >> 1) + 1 when onesided is set, and always sets
//     the trailing dimension to 2.
ONNX_OPERATOR_SET_SCHEMA(
    DFT,
    17,
    OpSchema()
        .SetDoc(DFT_ver17_doc)
        .Attr(
            "onesided",
            "If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because "
            "the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m,w]=X[m,n_fft-w]*. "
            "Note if the input or window tensors are complex, then onesided output is not possible. "
            "Enabling onesided with real inputs performs a Real-valued fast Fourier transform (RFFT). "
            "When invoked with real or complex valued input, the default value is 0. "
            "Values can be 0 or 1.",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Attr(
            "axis",
            "The axis on which to perform the DFT. By default this value is set to 1, which corresponds to the first dimension after the batch index.",
            AttributeProto::INT,
            static_cast<int64_t>(1))
        .Attr(
            "inverse",
            "Whether to perform the inverse discrete fourier transform. By default this value is set to 0, which corresponds to false.",
            AttributeProto::INT,
            static_cast<int64_t>(0))
        .Input(
            0,
            "input",
            "For real input, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][1]. "
            "For complex input, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. "
            "The first dimension is the batch dimension. "
            "The following N dimentions correspond to the signal's dimensions. "
            "The final dimension represents the real and imaginary parts of the value in that order.",
            "T1",
            OpSchema::Single,
            true,
            1,
            OpSchema::NonDifferentiable)
        .Input(
            1,
            "dft_length",
            "The length of the signal."
            "If greater than the axis dimension, the signal will be zero-padded up to dft_length. "
            "If less than the axis dimension, only the first dft_length values will be used as the signal. "
            "It's an optional value. ",
            "T2",
            OpSchema::Optional,
            true,
            1,
            OpSchema::NonDifferentiable)
        .Output(
            0,
            "output",
            "The Fourier Transform of the input vector."
            "If onesided is 0, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. "
            "If axis=1 and onesided is 1, the following shape is expected: [batch_idx][floor(signal_dim1/2)+1][signal_dim2]...[signal_dimN][2]. "
            "If axis=2 and onesided is 1, the following shape is expected: [batch_idx][signal_dim1][floor(signal_dim2/2)+1]...[signal_dimN][2]. "
            "If axis=N and onesided is 1, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[floor(signal_dimN/2)+1][2]. "
            "The signal_dim at the specified axis is equal to the dft_length.",
            "T1")
        .TypeConstraint(
            "T1",
            {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
            "Constrain input and output types to float tensors.")
        .TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain scalar length types to int64_t.")
        .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
          const bool onesided_requested = static_cast<bool>(getAttribute(ctx, "onesided", 0));
          const bool inverse_requested = static_cast<bool>(getAttribute(ctx, "inverse", 0));

          // The two modes are mutually exclusive for this operator version.
          if (onesided_requested && inverse_requested) {
            fail_shape_inference("is_onesided and inverse attributes cannot be enabled at the same time");
          }

          // The element type is always forwarded, even when no shape can be inferred.
          propagateElemTypeFromInputToOutput(ctx, 0, 0);
          if (!hasInputShape(ctx, 0)) {
            return;
          }

          // Start from a copy of the input shape; only the transformed axis and
          // the trailing dimension are rewritten below.
          const auto& signal_shape = getInputShape(ctx, 0);
          ONNX_NAMESPACE::TensorShapeProto output_shape = signal_shape;

          const int requested_axis = static_cast<int>(getAttribute(ctx, "axis", 1));
          const int num_dims = signal_shape.dim_size();

          // Accept axes in [-num_dims, num_dims); anything else is an error.
          if (requested_axis < -num_dims || requested_axis >= num_dims) {
            fail_shape_inference(
                "axis attribute value ", requested_axis, " is invalid for a tensor of rank ", num_dims);
          }

          // Map a negative axis onto its non-negative equivalent.
          const int axis_idx = requested_axis < 0 ? requested_axis + num_dims : requested_axis;

          // When dft_length is supplied it dictates the extent along the axis
          // (the onesided adjustment, if any, is applied afterwards).
          if (ctx.hasInput(1)) {
            const TensorProto* dft_length = ctx.getInputData(1);
            if (dft_length == nullptr) {
              // The input is present but its data cannot be read here,
              // so no output shape is produced.
              return;
            }
            if (dft_length->dims_size() != 0) {
              fail_shape_inference("dft_length input must be a scalar.");
            }
            const auto dft_length_value = get_scalar_value_from_tensor<int64_t>(dft_length);
            output_shape.mutable_dim(axis_idx)->set_dim_value(dft_length_value);
          }

          // A onesided transform keeps (size >> 1) + 1 entries along the axis.
          if (onesided_requested) {
            auto* axis_dim = output_shape.mutable_dim(axis_idx);
            if (axis_dim->has_dim_value()) {
              axis_dim->set_dim_value((axis_dim->dim_value() >> 1) + 1);
            } else {
              // No concrete extent is available: drop any value/param copied
              // over from the input so the dimension is left unset.
              axis_dim->clear_dim_value();
              axis_dim->clear_dim_param();
            }
          }

          // The trailing dimension of the output is always set to 2.
          output_shape.mutable_dim(output_shape.dim_size() - 1)->set_dim_value(2);

          updateOutputShape(ctx, 0, output_shape);
        }));

} // namespace ONNX_NAMESPACE

0 comments on commit a9db6ad

Please sign in to comment.