diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp index d2eceadb7d78b9..23ef80dfb02d85 100644 --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -1519,51 +1519,98 @@ struct CombineIfs : public OpRewritePattern { if (!prevIf) return failure(); - if (nextIf.getCondition() != prevIf.getCondition()) - return failure(); + // Determine the logical then/else blocks when prevIf's + // condition is used. Null means the block does not exist + // in that case (e.g. empty else). If neither of these + // are set, the two conditions cannot be compared. + Block *nextThen = nullptr; + Block *nextElse = nullptr; + if (nextIf.getCondition() == prevIf.getCondition()) { + nextThen = nextIf.thenBlock(); + if (!nextIf.getElseRegion().empty()) + nextElse = nextIf.elseBlock(); + } + if (arith::XOrIOp notv = + nextIf.getCondition().getDefiningOp()) { + if (notv.getLhs() == prevIf.getCondition() && + matchPattern(notv.getRhs(), m_One())) { + nextElse = nextIf.thenBlock(); + if (!nextIf.getElseRegion().empty()) + nextThen = nextIf.elseBlock(); + } + } + if (arith::XOrIOp notv = + prevIf.getCondition().getDefiningOp()) { + if (notv.getLhs() == nextIf.getCondition() && + matchPattern(notv.getRhs(), m_One())) { + nextElse = nextIf.thenBlock(); + if (!nextIf.getElseRegion().empty()) + nextThen = nextIf.elseBlock(); + } + } - // Don't permit merging if a result of the first if is used - // within the second. - if (llvm::any_of(prevIf->getUsers(), - [&](Operation *user) { return nextIf->isAncestor(user); })) + if (!nextThen && !nextElse) return failure(); + SmallVector prevElseYielded; + if (!prevIf.getElseRegion().empty()) + prevElseYielded = prevIf.elseYield().getOperands(); + // Replace all uses of return values of op within nextIf with the + // corresponding yields + for (auto it : llvm::zip(prevIf.getResults(), + prevIf.thenYield().getOperands(), prevElseYielded)) + for (OpOperand &use : + llvm::make_early_inc_range(std::get<0>(it).getUses())) { + if (nextThen && nextThen->getParent()->isAncestor( + use.getOwner()->getParentRegion())) { + rewriter.startRootUpdate(use.getOwner()); + use.set(std::get<1>(it)); + rewriter.finalizeRootUpdate(use.getOwner()); + } else if (nextElse && nextElse->getParent()->isAncestor( + use.getOwner()->getParentRegion())) { + rewriter.startRootUpdate(use.getOwner()); + use.set(std::get<2>(it)); + rewriter.finalizeRootUpdate(use.getOwner()); + } + } + SmallVector mergedTypes(prevIf.getResultTypes()); llvm::append_range(mergedTypes, nextIf.getResultTypes()); IfOp combinedIf = rewriter.create( - nextIf.getLoc(), mergedTypes, nextIf.getCondition(), /*hasElse=*/false); + nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false); rewriter.eraseBlock(&combinedIf.getThenRegion().back()); - YieldOp thenYield = prevIf.thenYield(); - YieldOp thenYield2 = nextIf.thenYield(); - - combinedIf.getThenRegion().getBlocks().splice( - combinedIf.getThenRegion().getBlocks().begin(), - prevIf.getThenRegion().getBlocks()); - - rewriter.mergeBlocks(nextIf.thenBlock(), combinedIf.thenBlock()); - rewriter.setInsertionPointToEnd(combinedIf.thenBlock()); - - SmallVector mergedYields(thenYield.getOperands()); - llvm::append_range(mergedYields, thenYield2.getOperands()); - rewriter.create(thenYield2.getLoc(), mergedYields); - rewriter.eraseOp(thenYield); - rewriter.eraseOp(thenYield2); + rewriter.inlineRegionBefore(prevIf.getThenRegion(), + combinedIf.getThenRegion(), + combinedIf.getThenRegion().begin()); + + if (nextThen) { + YieldOp thenYield = combinedIf.thenYield(); + YieldOp thenYield2 = cast(nextThen->getTerminator()); + rewriter.mergeBlocks(nextThen, combinedIf.thenBlock()); + rewriter.setInsertionPointToEnd(combinedIf.thenBlock()); + + SmallVector mergedYields(thenYield.getOperands()); + llvm::append_range(mergedYields, thenYield2.getOperands()); + rewriter.create(thenYield2.getLoc(), mergedYields); + rewriter.eraseOp(thenYield); + rewriter.eraseOp(thenYield2); + } - combinedIf.getElseRegion().getBlocks().splice( - combinedIf.getElseRegion().getBlocks().begin(), - prevIf.getElseRegion().getBlocks()); + rewriter.inlineRegionBefore(prevIf.getElseRegion(), + combinedIf.getElseRegion(), + combinedIf.getElseRegion().begin()); - if (!nextIf.getElseRegion().empty()) { + if (nextElse) { if (combinedIf.getElseRegion().empty()) { - combinedIf.getElseRegion().getBlocks().splice( - combinedIf.getElseRegion().getBlocks().begin(), - nextIf.getElseRegion().getBlocks()); + rewriter.inlineRegionBefore(*nextElse->getParent(), + combinedIf.getElseRegion(), + combinedIf.getElseRegion().begin()); } else { YieldOp elseYield = combinedIf.elseYield(); - YieldOp elseYield2 = nextIf.elseYield(); - rewriter.mergeBlocks(nextIf.elseBlock(), combinedIf.elseBlock()); + YieldOp elseYield2 = cast(nextElse->getTerminator()); + rewriter.mergeBlocks(nextElse, combinedIf.elseBlock()); rewriter.setInsertionPointToEnd(combinedIf.elseBlock()); diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 80af8174d2f592..86c478ec4eb683 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1119,6 +1119,79 @@ func @combineIfs4(%arg0 : i1, %arg2: i64) { // CHECK-NEXT: "test.secondCodeTrue"() : () -> () // CHECK-NEXT: } +// CHECK-LABEL: @combineIfsUsed +// CHECK-SAME: %[[arg0:.+]]: i1 +func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) { + %res = scf.if %arg0 -> i32 { + %v = "test.firstCodeTrue"() : () -> i32 + scf.yield %v : i32 + } else { + %v2 = "test.firstCodeFalse"() : () -> i32 + scf.yield %v2 : i32 + } + %res2 = scf.if %arg0 -> i32 { + %v = "test.secondCodeTrue"(%res) : (i32) -> i32 + scf.yield %v : i32 + } else { + %v2 = "test.secondCodeFalse"(%res) : (i32) -> i32 + scf.yield %v2 : i32 + } + return %res, %res2 : i32, i32 +} +// CHECK-NEXT: %[[res:.+]]:2 = scf.if %[[arg0]] -> (i32, i32) { +// CHECK-NEXT: %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32 +// CHECK-NEXT: %[[tval:.+]] = "test.secondCodeTrue"(%[[tval0]]) : (i32) -> i32 +// CHECK-NEXT: scf.yield %[[tval0]], %[[tval]] : i32, i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32 +// CHECK-NEXT: %[[fval:.+]] = "test.secondCodeFalse"(%[[fval0]]) : (i32) -> i32 +// CHECK-NEXT: scf.yield %[[fval0]], %[[fval]] : i32, i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[res]]#0, %[[res]]#1 : i32, i32 + +// CHECK-LABEL: @combineIfsNot +// CHECK-SAME: %[[arg0:.+]]: i1 +func @combineIfsNot(%arg0 : i1, %arg2: i64) { + %true = arith.constant true + %not = arith.xori %arg0, %true : i1 + scf.if %arg0 { + "test.firstCodeTrue"() : () -> () + scf.yield + } + scf.if %not { + "test.secondCodeTrue"() : () -> () + scf.yield + } + return +} + +// CHECK-NEXT: scf.if %[[arg0]] { +// CHECK-NEXT: "test.firstCodeTrue"() : () -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: "test.secondCodeTrue"() : () -> () +// CHECK-NEXT: } + +// CHECK-LABEL: @combineIfsNot2 +// CHECK-SAME: %[[arg0:.+]]: i1 +func @combineIfsNot2(%arg0 : i1, %arg2: i64) { + %true = arith.constant true + %not = arith.xori %arg0, %true : i1 + scf.if %not { + "test.firstCodeTrue"() : () -> () + scf.yield + } + scf.if %arg0 { + "test.secondCodeTrue"() : () -> () + scf.yield + } + return +} + +// CHECK-NEXT: scf.if %[[arg0]] { +// CHECK-NEXT: "test.secondCodeTrue"() : () -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: "test.firstCodeTrue"() : () -> () +// CHECK-NEXT: } // ----- // CHECK-LABEL: func @propagate_into_execute_region