Skip to content

Commit 94f255c

Browse files
lhutton1jpienaar
authored andcommitted
[mlir][tosa] Add RFFT2d operation
Adds the RFFT2d TOSA operation and supporting shape inference function. Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: I7e49c47cdd846cdc1b187545ef76d5cda2d5d9ad Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D142336
1 parent 2b02df7 commit 94f255c

File tree

5 files changed

+88
-0
lines changed

5 files changed

+88
-0
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
1515

1616
#include "mlir/Dialect/Traits.h"
17+
#include "mlir/IR/OpImplementation.h"
1718
#include "mlir/Interfaces/InferTypeOpInterface.h"
1819
#include "mlir/Interfaces/LoopLikeInterface.h"
1920
#include "mlir/Interfaces/SideEffectInterfaces.h"

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,34 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
270270
let hasCanonicalizer = 1;
271271
}
272272

273+
//===----------------------------------------------------------------------===//
274+
// Operator: rfft2d
275+
//===----------------------------------------------------------------------===//
276+
def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
277+
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
278+
["inferReturnTypeComponents"]>,
279+
Pure]> {
280+
let summary = "Performs RFFT2D operation on the input.";
281+
282+
let description = [{
283+
Performs a batched 2D real-valued Fast Fourier Transform over the input where
284+
the input tensor consists of real values producing complex valued output. The
285+
complex output values will be split into the output_real and output_imag
286+
tensor arguments. RFFT2D takes advantage of Hermitian symmetry to only
287+
calculate the first half of the final output axis. Imaginary values with
288+
locations (0,0), (0,W/2), (H/2,0) and (H/2,W/2) are zero.
289+
}];
290+
291+
let arguments = (ins
292+
Tosa_Tensor3D:$input
293+
);
294+
295+
let results = (outs
296+
Tosa_Tensor3D:$output_real,
297+
Tosa_Tensor3D:$output_imag
298+
);
299+
}
300+
273301
//===----------------------------------------------------------------------===//
274302
// Operator: transpose_conv2d
275303
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,31 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
387387
return success();
388388
}
389389

390+
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
391+
MLIRContext *context, ::std::optional<Location> location,
392+
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
393+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
394+
ShapeAdaptor inputShape = operands.getShape(0);
395+
396+
if (!inputShape.hasRank())
397+
return failure();
398+
399+
llvm::SmallVector<int64_t> outputShape;
400+
outputShape.resize(3, ShapedType::kDynamic);
401+
outputShape[0] = inputShape.getDimSize(0);
402+
outputShape[1] = inputShape.getDimSize(1);
403+
int64_t inWidth = inputShape.getDimSize(2);
404+
405+
// Note that we can support this calculation symbolically
406+
// in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
407+
if (inWidth != ShapedType::kDynamic)
408+
outputShape[2] = inWidth / 2 + 1;
409+
410+
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
411+
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
412+
return success();
413+
}
414+
390415
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
391416
MLIRContext *context, ::std::optional<Location> location,
392417
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32
7272
return %0 : tensor<1x32x32x8xf32>
7373
}
7474

75+
// -----
76+
// CHECK-LABEL: rfft2d
77+
func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
78+
%0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
79+
return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
80+
}
81+
7582
// -----
7683
// CHECK-LABEL: transpose_conv2d
7784
func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,3 +1189,30 @@ func.func @while_test(%arg0 : tensor<i32>, %arg1 : tensor<1xi32>) -> () {
11891189
}) : (tensor<i32>, tensor<1xi32>) -> (tensor<*xi32>, tensor<*xi32>)
11901190
return
11911191
}
1192+
1193+
// -----
1194+
1195+
// CHECK-LABEL: @test_static_rfft2d
1196+
func.func @test_static_rfft2d(%arg0: tensor<5x2x8xf32>) -> () {
1197+
// CHECK: -> (tensor<5x2x5xf32>, tensor<5x2x5xf32>)
1198+
%output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x2x8xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
1199+
return
1200+
}
1201+
1202+
// -----
1203+
1204+
// CHECK-LABEL: @test_dynamic_batch_rfft2d
1205+
func.func @test_dynamic_batch_rfft2d(%arg0 : tensor<?x2x4xf32>) -> () {
1206+
// CHECK: -> (tensor<?x2x3xf32>, tensor<?x2x3xf32>)
1207+
%output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<?x2x4xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
1208+
return
1209+
}
1210+
1211+
// -----
1212+
1213+
// CHECK-LABEL: @test_dynamic_width_rfft2d
1214+
func.func @test_dynamic_width_rfft2d(%arg0 : tensor<5x2x?xf32>) -> () {
1215+
// CHECK: -> (tensor<5x2x?xf32>, tensor<5x2x?xf32>)
1216+
%output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x2x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
1217+
return
1218+
}

0 commit comments

Comments
 (0)