diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index ab1e9c0ec7adc8..6d62e470706e53 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -625,11 +625,13 @@ void fir::BoxAddrOp::build(mlir::OpBuilder &builder, mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) { if (auto *v = getVal().getDefiningOp()) { if (auto box = mlir::dyn_cast(v)) { - if (!box.getSlice()) // Fold only if not sliced + // Fold only if not sliced + if (!box.getSlice() && box.getMemref().getType() == getType()) return box.getMemref(); } if (auto box = mlir::dyn_cast(v)) - return box.getMemref(); + if (box.getMemref().getType() == getType()) + return box.getMemref(); } return {}; } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 56d5e0fed76185..ff72becc8dfa77 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) { setOperand(src); return getResult(); } + // trunci(zexti(a)) -> a // trunci(sexti(a)) -> a - return src; + if (srcType == dstType) + return src; } // trunci(trunci(a)) -> trunci(a)) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 7444f70a46e935..26c39ff3523434 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -771,6 +771,8 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } ShapedType inputTy = llvm::cast(getInput().getType()); \ if (!inputTy.hasRank()) \ return {}; \ + if (inputTy != getType()) \ + return {}; \ if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \ return getInput(); \ return {}; \ diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 1d3200bf5c8217..8a23adae3c00e6 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1602,9 +1602,10 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { return llvm::isa(type) ? llvm::cast(type).getRank() : 0; }; + // If splat or broadcast from a scalar, just return the source scalar. unsigned broadcastSrcRank = getRank(source.getType()); - if (broadcastSrcRank == 0) + if (broadcastSrcRank == 0 && source.getType() == extractOp.getType()) return source; unsigned extractResultRank = getRank(extractOp.getType()); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 2cabfcd24d3559..d1565047658776 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -486,14 +486,11 @@ LogicalResult OpBuilder::tryFold(Operation *op, // Populate the results with the folded results. Dialect *dialect = op->getDialect(); - for (auto it : llvm::zip(foldResults, opResults.getTypes())) { + for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) { Type expectedType = std::get<1>(it); // Normal values get pushed back directly. if (auto value = llvm::dyn_cast_if_present(std::get<0>(it))) { - if (value.getType() != expectedType) - return cleanupFailure(); - results.push_back(value); continue; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 87be08712ea35b..a726790391a0c5 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -606,13 +606,30 @@ void Operation::setSuccessor(Block *block, unsigned index) { getBlockOperands()[index].set(block); } +#ifndef NDEBUG +/// Assert that the folded results (in case of values) have the same type as +/// the results of the given op. +static void checkFoldResultTypes(Operation *op, + SmallVectorImpl &results) { + if (!results.empty()) + for (auto [ofr, opResult] : llvm::zip_equal(results, op->getResults())) + if (auto value = ofr.dyn_cast()) + assert(value.getType() == opResult.getType() && + "folder produced value of incorrect type"); +} +#endif // NDEBUG + /// Attempt to fold this operation using the Op's registered foldHook. LogicalResult Operation::fold(ArrayRef operands, SmallVectorImpl &results) { // If we have a registered operation definition matching this one, use it to // try to constant fold the operation. - if (succeeded(name.foldHook(this, operands, results))) + if (succeeded(name.foldHook(this, operands, results))) { +#ifndef NDEBUG + checkFoldResultTypes(this, results); +#endif // NDEBUG return success(); + } // Otherwise, fall back on the dialect hook to handle it. Dialect *dialect = getDialect(); @@ -623,7 +640,12 @@ LogicalResult Operation::fold(ArrayRef operands, if (!interface) return failure(); - return interface->fold(this, operands, results); + LogicalResult status = interface->fold(this, operands, results); +#ifndef NDEBUG + if (succeeded(status)) + checkFoldResultTypes(this, results); +#endif // NDEBUG + return status; } LogicalResult Operation::fold(SmallVectorImpl &results) { diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp index 90ee5ba51de3ad..eb4dcb251a2280 100644 --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -247,10 +247,6 @@ OperationFolder::processFoldResults(Operation *op, // Check if the result was an SSA value. if (auto repl = llvm::dyn_cast_if_present(foldResults[i])) { - if (repl.getType() != op->getResult(i).getType()) { - results.clear(); - return failure(); - } results.emplace_back(repl); continue; } diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir index bc463fefe65342..4f0095ed7e8cf4 100644 --- a/mlir/test/Transforms/test-canonicalize.mlir +++ b/mlir/test/Transforms/test-canonicalize.mlir @@ -70,19 +70,6 @@ func.func @test_commutative_multi_cst(%arg0: i32, %arg1: i32) -> (i32, i32) { return %y, %z: i32, i32 } -// CHECK-LABEL: func @typemismatch - -func.func @typemismatch() -> i32 { - %c42 = arith.constant 42.0 : f32 - - // The "passthrough_fold" folder will naively return its operand, but we don't - // want to fold here because of the type mismatch. - - // CHECK: "test.passthrough_fold" - %0 = "test.passthrough_fold"(%c42) : (f32) -> (i32) - return %0 : i32 -} - // CHECK-LABEL: test_dialect_canonicalizer func.func @test_dialect_canonicalizer() -> (i32) { %0 = "test.dialect_canonicalizable"() : () -> (i32) diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 6897b6f95f0d05..d8cf6e4719cede 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -310,16 +310,6 @@ builtin.module { // ----- -// The "passthrough_fold" folder will naively return its operand, but we don't -// want to fold here because of the type mismatch. -func.func @typemismatch(%arg: f32) -> i32 { - // expected-remark@+1 {{op 'test.passthrough_fold' is not legalizable}} - %0 = "test.passthrough_fold"(%arg) : (f32) -> (i32) - "test.return"(%0) : (i32) -> () -} - -// ----- - // expected-remark @below {{applyPartialConversion failed}} module { func.func private @callee(%0 : f32) -> f32 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 21400a60e65321..a1b30705f16a98 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -542,10 +542,6 @@ OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { return {}; } -OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) { - return getOperand(); -} - OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { int64_t sum = 0; if (auto value = dyn_cast_or_null(adaptor.getOp())) diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 96f66c2ca06ecf..70ccc71883e3c1 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1363,13 +1363,6 @@ def TestOpFoldWithFoldAdaptor let hasFolder = 1; } -// An op that always fold itself. -def TestPassthroughFold : TEST_Op<"passthrough_fold"> { - let arguments = (ins AnyType:$op); - let results = (outs AnyType); - let hasFolder = 1; -} - def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> { let arguments = (ins); let results = (outs I32);