[torch] Add OnnxToTorch lowering for Onnx.STFT op #3492

Merged 2 commits on Jun 25, 2024
30 changes: 30 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -12533,6 +12533,36 @@ def Torch_AtenKthvalueOp : Torch_Op<"aten.kthvalue", [
let hasVerifier = 1;
}

def Torch_AtenStftOp : Torch_Op<"aten.stft", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$n_fft,
AnyTorchOptionalIntType:$hop_length,
AnyTorchOptionalIntType:$win_length,
AnyTorchOptionalTensorType:$window,
Torch_BoolType:$normalized,
AnyTorchOptionalBoolType:$onesided,
AnyTorchOptionalBoolType:$return_complex
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenStftOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenStftOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}
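
// A rough sketch (illustrative, not generated output) of how this op
// round-trips with the default assembly format above, which expects 8
// operands and 1 result:
//   %r = torch.aten.stft %self, %n_fft, %hop, %win, %window, %norm,
//        %onesided, %complex : !torch.vtensor<[1,128],f32>, !torch.int,
//        !torch.int, !torch.int, !torch.vtensor<[16],f32>, !torch.bool,
//        !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex<f32>>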

def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement,
HasValueSemantics,
166 changes: 166 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -3300,4 +3300,170 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*Torch_BoolType:$antialias*/ cstFalse);
return success();
});
patterns.onOp(
"STFT", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
// Operands in order: (signal, frameStep, window?, frameLength?); the
// last two are optional.
SmallVector<Value> operands;
int64_t onesided;
Torch::ValueTensorType resultType;

if (binder.tensorOperandsList(operands) ||
binder.s64IntegerAttr(onesided, "onesided", 1) ||
binder.tensorResultType(resultType))
return failure();

Value signal = operands[0];
Value frameStep = operands[1];
auto signalTy = cast<Torch::ValueTensorType>(signal.getType());
auto signalShape = signalTy.getSizes();
auto resultShape = resultType.getSizes();

// There are two possible cases for the optional inputs frameLength and
// window: either four operands are passed, with window being
// !torch.none, or three operands are passed, with window present and
// frameLength absent. In the former case we simply create a
// rectangular window consisting of ones, and in the latter we set
// frameLength to signalShape[-2] or windowShape[0], depending on
// whether window was present. Note that both window and frameLength
// can be none, meaning that either only two operands were passed, or,
// in the case of three operands, window was passed in as none and
// frameLength was absent.
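// For example (illustrative): STFT(signal, step, window, frameLength)
// uses both as given; STFT(signal, step, window) derives frameLength
// from windowShape[0]; STFT(signal, step, none, frameLength) or
// STFT(signal, step) synthesizes a rectangular window of ones, with
// frameLength defaulting to signalShape[-2] in the two-operand case.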
Value window = nullptr, frameLength = nullptr;
bool windowIsNone = true, frameLengthIsNone = true;
if (operands.size() == 3) {
window = operands[2];
windowIsNone = isa<Torch::NoneType>(window.getType());
}
if (operands.size() == 4) {
window = operands[2];
frameLength = operands[3];
windowIsNone = isa<Torch::NoneType>(window.getType());
frameLengthIsNone = isa<Torch::NoneType>(frameLength.getType());
}

ArrayRef<int64_t> windowShape;
if (frameLengthIsNone) {
if (windowIsNone) {
frameLength = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
signalShape[signalShape.size() - 2]));
} else {
windowShape =
cast<Torch::ValueTensorType>(window.getType()).getSizes();
frameLength = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0]));
}
}

Value frameLengthItem;
if (!frameLengthIsNone) {
// frameLength came from the operand list as a tensor, so extract
// its integer value.
frameLengthItem =
getItemOp<Torch::IntType>(binder, rewriter, frameLength);
} else {
// frameLength was synthesized above as a !torch.int constant.
frameLengthItem = frameLength;
}
Value frameStepItem =
getItemOp<Torch::IntType>(binder, rewriter, frameStep);

if (windowIsNone) {
auto onesResultTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({-1}), signalTy.getDtype());

Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value sizes = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
SmallVector<Value>{frameLengthItem});
window = rewriter.create<Torch::AtenOnesOp>(
binder.getLoc(), onesResultTy, sizes, none, none, none, none);
}
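
// Roughly, this emits:
//   %sizes = torch.prim.ListConstruct %frameLength
//   %window = torch.aten.ones %sizes, %none, %none, %none, %none
// (dtype, layout, device, and pin_memory all none), i.e. a rectangular
// window of length frameLength.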

FailureOr<Type> complexDtype;
if (signalTy.getDtype().isBF16()) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support for the bfloat16 type");
}
if (signalTy.getDtype().isF16()) {
complexDtype = Torch::getTypeForScalarType(
binder.op->getContext(),
torch::torch_upstream::ScalarType::ComplexHalf);
} else if (signalTy.getDtype().isF32()) {
complexDtype = Torch::getTypeForScalarType(
binder.op->getContext(),
torch::torch_upstream::ScalarType::ComplexFloat);
} else {
complexDtype = Torch::getTypeForScalarType(
binder.op->getContext(),
torch::torch_upstream::ScalarType::ComplexDouble);
}
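
// In summary (illustrative): f16 maps to complex<f16> (ComplexHalf),
// f32 to complex<f32> (ComplexFloat), and the f64 fallthrough to
// complex<f64> (ComplexDouble), mirroring how torch.stft promotes a
// real input to the corresponding complex dtype when
// return_complex=true.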

auto complexSignalTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({signalShape[0], signalShape[1]}),
complexDtype.value());

// The onnx STFT op always passes in a float input, and if the input
// is intended to be complex, its shape will be [batch][length][2],
// where [...][0] is the real component and [...][1] is the imaginary
// component. This complex input has to be made torch-compatible before
// being passed into torch.stft, so it is necessary to call
// AtenViewAsComplexOp. In the case of real input, the shape of the
// signal will be [batch][length][1], so it has to be squeezed at
// dim=2 before being passed into torch.stft.
if (signalShape[2] == 2) {
signal = rewriter.create<Torch::AtenViewAsComplexOp>(
binder.getLoc(), complexSignalTy, signal);
} else {
Value two = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
auto newSignalTy = signalTy.getWithSizesAndDtype(
ArrayRef<int64_t>({signalShape[0], signalShape[1]}),
signalTy.getDtype());
signal = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(), newSignalTy, signal, two);
}
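
// For example (shapes illustrative): a complex signal
// !torch.vtensor<[1,128,2],f32> becomes
// !torch.vtensor<[1,128],complex<f32>> via view_as_complex, while a
// real signal !torch.vtensor<[1,128,1],f32> is squeezed to
// !torch.vtensor<[1,128],f32>.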

// If no window was given, use frameLength as the window length.
Value windowLen;
if (!windowIsNone) {
windowLen = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0]));
} else {
windowLen = frameLengthItem;
}

Value falseVal =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value trueVal =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
auto stftTy = complexSignalTy.getWithSizesAndDtype(
ArrayRef<int64_t>({resultShape[0], resultShape[2], resultShape[1]}),
complexSignalTy.getDtype());

// Note: the onnx result type has shape [batch][num_frames][length][2],
// while the result of torch.stft has shape [batch][length][num_frames].
// Before the value is converted to real through torch.view_as_real, we
// must therefore permute it to match the shape of resultType. It is
// immaterial whether torch.view_as_real is called before or after the
// permutation; both orderings produce equivalent outputs.
Value stft = rewriter.create<Torch::AtenStftOp>(
binder.getLoc(), stftTy, signal, frameLengthItem, frameStepItem,
windowLen, window, falseVal, onesided ? trueVal : falseVal,
trueVal);
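
// At the Python level this corresponds roughly to
// torch.stft(signal, n_fft=frame_length, hop_length=frame_step,
//            win_length=window_len, window=window, normalized=False,
//            onesided=<attr>, return_complex=True),
// except that this aten overload performs no centering/padding, unlike
// the Python API's default center=True.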

auto permuteStftTy = complexSignalTy.getWithSizesAndDtype(
ArrayRef<int64_t>({resultShape[0], resultShape[1], resultShape[2]}),
complexSignalTy.getDtype());
Value permuteDims = createConstantIntList(binder, rewriter, {0, 2, 1});
Value permutedStft = rewriter.create<Torch::AtenPermuteOp>(
binder.getLoc(), permuteStftTy, stft, permuteDims);

rewriter.replaceOpWithNewOp<Torch::AtenViewAsRealOp>(
binder.op, resultType, permutedStft);
return success();
});
}
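
For a sense of the end-to-end conversion, here is a hand-written sketch (not a test from this PR; shapes and syntax are illustrative) of a real-valued STFT with signal length 128, frame_length 16, and frame_step 8, giving (128 - 16) / 8 + 1 = 15 frames and, with onesided = 1, 16 / 2 + 1 = 9 frequency bins:

  %result = torch.operator "onnx.STFT"(%signal, %frame_step, %none, %frame_length)
      {torch.onnx.onesided = 1 : si64}
      : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.none,
         !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32>

The pattern above would rewrite this into a squeeze of the trailing unit dimension, a ones window, a torch.aten.stft producing !torch.vtensor<[1,9,15],complex<f32>>, a permute to [1,15,9], and a final torch.view_as_real yielding the [1,15,9,2] float result.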