diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index ea4414fc1890e..f8588aa0ace0f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -888,7 +888,7 @@ LogicalResult tosa::RFFT2dOp::verify() { // Output width dimension expected to be input_width / 2 + 1 const int64_t outputWidth = outputType.getDimSize(2); if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) && - (outputWidth - 1) * 2 != width) + (outputWidth != (width / 2) + 1)) return emitOpError( "expected output width to be equal to input_width / 2 + 1, got ") << outputWidth; diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 4c2cda8d9c027..600afe2abbff2 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -175,6 +175,13 @@ func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tenso return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32> } +// ----- +// CHECK-LABEL: rfft2d_width1 +func.func @test_rfft2d_width1(%arg0: tensor<1x1x1xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>) { + %0, %1 = tosa.rfft2d %arg0 : (tensor<1x1x1xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>) + return %0, %1 : tensor<1x1x1xf32>, tensor<1x1x1xf32> +} + // ----- // CHECK-LABEL: rfft2d_with_local_bound func.func @test_rfft2d_with_local_bound(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {