diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 7cfd6d3a98df8..04daed8cd40a2 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -393,7 +393,7 @@ Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc, OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) { // addi(x, 0) -> x - if (matchPattern(adaptor.getRhs(), m_Zero())) + if (matchPattern(adaptor.getRhs(), m_Zero()) && getLhs() != *this) return getLhs(); // addi(subi(a, b), b) -> a diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 74e4a822b4fd7..93468dd79808f 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -555,7 +555,8 @@ bool GreedyPatternRewriteDriver::processWorklist() { replacements.push_back(constOp->getResult(0)); } - if (materializationSucceeded) { + if (materializationSucceeded && + !llvm::equal(replacements, op->getResults())) { rewriter.replaceOp(op, replacements); changed = true; LLVM_DEBUG(logSuccessfulFolding(dumpRootOp)); diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 8e02c06a0a293..ed987e555926c 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1248,3 +1248,12 @@ func.func @test_materialize_failure() -> i64 { %u = index.castu %const : index to i64 return %u: i64 } + +// ----- + +// Make sure that the canonicalizer does not fold infinitely. + +// CHECK: %[[c0:.*]] = arith.constant 0 : index +%c0 = arith.constant 0 : index +// CHECK: %[[add:.*]] = arith.addi %[[c0]], %[[add]] : index +%0 = arith.addi %c0, %0 : index