From 33c1c10ae79d9f2bf8e6b75bf580981eb5dbb7fd Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Tue, 22 Oct 2024 22:13:53 +0800 Subject: [PATCH] [mlir][tosa] Add a verifier for `tosa.mul` This PR adds a verifier check for tosa.mul, requiring that the shift be 0 for float types. --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 + mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 14 ++++---------- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 8 ++++++++ mlir/test/Dialect/Tosa/invalid.mlir | 9 +++++++++ mlir/test/Dialect/Tosa/ops.mlir | 2 +- 5 files changed, 23 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 3bb5ceb0f4695..6e7d575ac26df 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -811,6 +811,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [ ); let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index c88f4db27c304..495f1b4f10b02 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -78,16 +78,6 @@ static Value createLinalgBodyCalculationForElementwiseOp( if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); - // tosa::MulOp - if (isa(op) && isa(elementTy)) { - if (dyn_cast(op).getShift() != 0) { - (void)rewriter.notifyMatchFailure(op, - "Cannot have shift value for float"); - return nullptr; - } - return rewriter.create(loc, resultTypes, args); - } - // tosa::IntDivOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); @@ -99,6 +89,10 @@ static Value createLinalgBodyCalculationForElementwiseOp( return rewriter.create(loc, resultTypes, one, args[0]); } + // tosa::MulOp + if (isa(op) && isa(elementTy)) + return rewriter.create(loc, resultTypes, args); + if (isa(op) && isa(elementTy)) { Value a = args[0]; Value b = args[1]; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 1f3e19fe90c6d..631d3c48f2df0 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -865,6 +865,14 @@ LogicalResult tosa::SliceOp::verify() { return success(); } +LogicalResult tosa::MulOp::verify() { + Type elementTy = getInput1().getType().getElementType(); + if (isa(elementTy) && getShift() != 0) + return emitOpError() << "require shift to be 0 for float type"; + + return success(); +} + LogicalResult tosa::TableOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TableOp::Adaptor adaptor, diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index b9298b6664353..f1b1707a0c40d 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -609,3 +609,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } + +// ----- + +// CHECK-LABEL: test_mul_invalid_shift +func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}} + %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index a1600fd33c54b..a756588a7cc0d 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te // ----- // CHECK-LABEL: mul func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> + %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> }