diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td index 104f2741e5678..cbcc2a017ac3a 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td @@ -250,6 +250,8 @@ def Tosa_MaxShapeOp : Tosa_ElementwiseShapeOp<"max_shape", [Pure]> { ); let results = (outs Tosa_Shape:$output); + + let hasFolder = 1; } //===----------------------------------------------------------------------===// @@ -268,6 +270,8 @@ def Tosa_MinShapeOp : Tosa_ElementwiseShapeOp<"min_shape", [Pure]> { ); let results = (outs Tosa_Shape:$output); + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index ed1584f93a367..42033ce8a3b02 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1116,7 +1116,33 @@ struct ModFoldAdaptor { } }; -struct FoldGreaterAdaptor { +struct MaxFoldAdaptor { + static FailureOr fold(const APInt &lhs, const APInt &rhs, + bool isUnsigned) { + if (lhs.getBitWidth() != rhs.getBitWidth()) + return failure(); + return lhs.getSExtValue() >= rhs.getSExtValue() ? lhs : rhs; + } + + static FailureOr fold(const APFloat &lhs, const APFloat &rhs) { + return lhs >= rhs ? lhs : rhs; + } +}; + +struct MinFoldAdaptor { + static FailureOr fold(const APInt &lhs, const APInt &rhs, + bool isUnsigned) { + if (lhs.getBitWidth() != rhs.getBitWidth()) + return failure(); + return lhs.getSExtValue() <= rhs.getSExtValue() ? lhs : rhs; + } + + static FailureOr fold(const APFloat &lhs, const APFloat &rhs) { + return lhs <= rhs ? lhs : rhs; + } +}; + +struct GreaterFoldAdaptor { static FailureOr fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned) { return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs)); @@ -1127,7 +1153,7 @@ struct FoldGreaterAdaptor { } }; -struct FoldGreaterEqualAdaptor { +struct GreaterEqualFoldAdaptor { static FailureOr fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned) { return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs)); @@ -1138,7 +1164,7 @@ struct FoldGreaterEqualAdaptor { } }; -struct FoldEqualAdaptor { +struct EqualFoldAdaptor { static FailureOr fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned) { return APInt(1, lhs == rhs); @@ -1247,8 +1273,9 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) { APInt l = lhsAttr.getSplatValue(); APInt r = rhsAttr.getSplatValue(); if (!r.isZero()) { + auto intTy = dyn_cast(resultETy); auto const result = - DivFoldAdaptor::fold(l, r, /*isUnsigned*/ false); + DivFoldAdaptor::fold(l, r, intTy.isUnsigned()); if (failed(result)) return {}; return DenseElementsAttr::get(resultTy, result.value()); @@ -1396,7 +1423,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder(lhsAttr, rhsAttr, resultTy); + return binaryFolder(lhsAttr, rhsAttr, resultTy); } OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { @@ -1409,7 +1436,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder(lhsAttr, rhsAttr, resultTy); + return binaryFolder(lhsAttr, rhsAttr, resultTy); } OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { @@ -1432,7 +1459,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder(lhsAttr, rhsAttr, resultTy); + return binaryFolder(lhsAttr, rhsAttr, resultTy); } OpFoldResult CastOp::fold(FoldAdaptor adaptor) { @@ -1909,3 +1936,11 @@ OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) { OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) { return binaryFold(this); } + +OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) { + return binaryFold(this); +} + +OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) { + return binaryFold(this); +} diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir index a8bcb9d52f000..c3186279a30ae 100644 --- a/mlir/test/Dialect/Tosa/constant_folding.mlir +++ b/mlir/test/Dialect/Tosa/constant_folding.mlir @@ -875,3 +875,25 @@ func.func @test_no_fold_mod_shape_negative_overflow() -> !tosa.shape<6> { %c = tosa.mod_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> return %c : !tosa.shape<6> } + +// ----- + +// CHECK-LABEL: @test_max_shape +// CHECK: tosa.const_shape {values = dense<[24, 7, 65, 33, 39, 5]> : tensor<6xindex>} : () -> !tosa.shape<6> +func.func @test_max_shape() -> !tosa.shape<6> { + %a = tosa.const_shape {values = dense<[24, 7, 65, 33, 39, 1]> : tensor<6xindex>} : () -> !tosa.shape<6> + %b = tosa.const_shape {values = dense<[11, 2, 12, 13, 15, 5]> : tensor<6xindex>} : () -> !tosa.shape<6> + %c = tosa.max_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> + return %c : !tosa.shape<6> +} + +// ----- + +// CHECK-LABEL: @test_min_shape +// CHECK: tosa.const_shape {values = dense<[11, 2, 12, 13, 15, 1]> : tensor<6xindex>} : () -> !tosa.shape<6> +func.func @test_min_shape() -> !tosa.shape<6> { + %a = tosa.const_shape {values = dense<[24, 7, 65, 33, 39, 1]> : tensor<6xindex>} : () -> !tosa.shape<6> + %b = tosa.const_shape {values = dense<[11, 2, 12, 13, 15, 5]> : tensor<6xindex>} : () -> !tosa.shape<6> + %c = tosa.min_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> + return %c : !tosa.shape<6> +}