diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index a441fd82546e3..ddea3a7eae590 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -65,7 +65,8 @@ def AssertOp : CF_Op<"assert",
 //===----------------------------------------------------------------------===//
 
 def BranchOp : CF_Op<"br", [
-    DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+    DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands",
+                                                  "getSuccessorForwardOperands"]>,
     Pure, Terminator
   ]> {
   let summary = "Branch operation";
@@ -114,8 +115,8 @@ def BranchOp : CF_Op<"br", [
 
 def CondBranchOp
     : CF_Op<"cond_br", [AttrSizedOperandSegments,
-                        DeclareOpInterfaceMethods<
-                            BranchOpInterface, ["getSuccessorForOperands"]>,
+                        DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands",
+                            "getSuccessorForwardOperands"]>,
                         WeightedBranchOpInterface, Pure, Terminator]> {
   let summary = "Conditional branch operation";
   let description = [{
@@ -241,7 +242,8 @@ def CondBranchOp
 
 def SwitchOp : CF_Op<"switch",
     [AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands",
+                                                   "getSuccessorForwardOperands"]>,
      Pure, Terminator]> {
   let summary = "Switch operation";
   let description = [{
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 06fa724e05fab..d32be0c63acc7 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -98,6 +98,17 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
       (ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
       [{ return lhs == rhs; }]
     >,
+    InterfaceMethod<[{
+        Returns the operands of this operation that are forwarded to the block
+        arguments of `successor`. If `successor` is not a successor of this
+        operation, or no operands are forwarded to it, an empty ValueRange is
+        returned. If `successor` appears more than once in the successor list,
+        only the operands of one occurrence are returned; use
+        `getSuccessorOperands` with the successor index to tell them apart.
+      }],
+      "ValueRange", "getSuccessorForwardOperands",
+      (ins "Block *":$successor), [{}],[{ return {};}]
+    >,
   ];
 
   let verify = [{
diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6ca4e13d159ac..125a7c6f6f5e7 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -90,6 +90,13 @@ class ReductionNode {
   /// corresponding region.
   LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
 
+  /// Overload of `initialize` that clones `parentModule` through the
+  /// caller-provided `mapper`, so that the caller can afterwards look up the
+  /// clones of blocks and operations of `parentModule`. Fails if
+  /// `parentRegion` has no blocks.
+  LogicalResult initialize(ModuleOp parentModule, Region &parentRegion,
+                           IRMapping &mapper);
+
 private:
   /// A custom BFS iterator. The difference between
   /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 435c37bc95aac..f6eb0f05911b8 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -296,6 +296,12 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
   return getDest();
 }
 
+ValueRange BranchOp::getSuccessorForwardOperands(Block *successor) {
+  if (successor == getDest())
+    return getDestOperands();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // CondBranchOp
 //===----------------------------------------------------------------------===//
@@ -583,6 +589,14 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+ValueRange CondBranchOp::getSuccessorForwardOperands(Block *successor) {
+  if (successor == getTrueDest())
+    return getTrueOperands();
+  if (successor == getFalseDest())
+    return getFalseOperands();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // SwitchOp
 //===----------------------------------------------------------------------===//
@@ -1034,6 +1048,18 @@ void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
       .add(context);
 }
 
+ValueRange SwitchOp::getSuccessorForwardOperands(Block *successor) {
+  // The default destination takes precedence over the cases.
+  if (successor == getDefaultDestination())
+    return getDefaultOperands();
+  // Case operands are indexed by the position of their case destination.
+  SuccessorRange caseDests = getCaseDestinations();
+  auto it = llvm::find(caseDests, successor);
+  if (it == caseDests.end())
+    return {};
+  return getCaseOperands(std::distance(caseDests.begin(), it));
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index 68864e373c993..b18a4bca04fcb 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(MLIRReduce
   MLIRPass
   MLIRRewrite
   MLIRTransformUtils
+  MLIRControlFlowDialect
 
   DEPENDS
   MLIRReducerIncGen
diff --git a/mlir/lib/Reducer/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp
index 11aeaf77b4642..897aae0becf33 100644
--- a/mlir/lib/Reducer/ReductionNode.cpp
+++ b/mlir/lib/Reducer/ReductionNode.cpp
@@ -45,6 +45,19 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule,
   return success();
 }
 
+LogicalResult ReductionNode::initialize(ModuleOp parentModule,
+                                        Region &targetRegion,
+                                        IRMapping &mapper) {
+  // The cloned region is located through the first block of targetRegion, so
+  // a region without blocks cannot be handled.
+  if (targetRegion.empty())
+    return failure();
+  module = cast<ModuleOp>(parentModule->clone(mapper));
+  Block *block = mapper.lookup(&*targetRegion.begin());
+  region = block->getParent();
+  return success();
+}
+
 /// If we haven't explored any variants from this node, we will create N
 /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
 /// max element in `ranges` and create 2 new variants for each call.
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 83497143d9669..12358f7d71688 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -14,7 +14,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Reducer/Passes.h"
 #include "mlir/Reducer/ReductionNode.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
@@ -24,6 +27,9 @@
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Allocator.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "reduction-tree"
 
 namespace mlir {
 #define GEN_PASS_DEF_REDUCTIONTREEPASS
@@ -184,6 +190,137 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
   return failure();
 }
 
+/// Returns the first terminator in `region` that has more than one successor
+/// and implements BranchOpInterface (e.g. cf.cond_br or cf.switch), or
+/// nullptr if there is none.
+static Operation *getBranchTerminatorInRegion(Region &region) {
+  for (Block &block : region.getBlocks()) {
+    if (block.getNumSuccessors() <= 1)
+      continue;
+    // Only terminators that expose their successor operands can be rewritten
+    // into an unconditional branch.
+    Operation *terminator = block.getTerminator();
+    if (isa<BranchOpInterface>(terminator))
+      return terminator;
+  }
+  return nullptr;
+}
+
+/// Replaces `terminator` with a `cf.br` to its successor number
+/// `successorIdx`, forwarding the operands that `terminator` passes to that
+/// successor. The operands are selected by successor index rather than by
+/// destination block, because one block may appear as several successors with
+/// different operands.
+static void replaceWithBranchToSuccessor(OpBuilder &builder,
+                                         Operation *terminator,
+                                         unsigned successorIdx) {
+  auto branchOp = cast<BranchOpInterface>(terminator);
+  ValueRange forwardedOperands =
+      branchOp.getSuccessorOperands(successorIdx).getForwardedOperands();
+  builder.setInsertionPoint(terminator);
+  cf::BranchOp::create(builder, terminator->getLoc(),
+                       terminator->getSuccessor(successorIdx),
+                       forwardedOperands);
+  terminator->erase();
+}
+
+/// Reduces the control flow in `region` by repeatedly replacing a branching
+/// terminator with an unconditional branch to one of its successors. Every
+/// successor is tried on a clone of `module`, and the variant giving the
+/// smallest "interesting" module found so far is committed to `region`. The
+/// reduction stops at the first terminator that cannot be reduced.
+///
+/// Emits a warning and returns failure if `module` is not interesting to
+/// begin with; returns success otherwise.
+static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
+                                                  Region &region,
+                                                  const Tester &test) {
+  std::pair<Tester::Interestingness, size_t> initStatus =
+      test.isInteresting(module);
+
+  // While exploring the reduction tree, we always branch from an interesting
+  // node. Thus the root node must be interesting.
+  if (initStatus.first != Tester::Interestingness::True)
+    return module.emitWarning() << "uninterested module will not be reduced";
+  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+
+  // We set the simplification level to Aggressive to enable block merging.
+  GreedyRewriteConfig config;
+  config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
+  config.setUseTopDownTraversal(true);
+
+  // Populate canonicalization patterns for cf ops. When all targets of a
+  // 'cf.cond_br' or 'cf.switch' point to the same block, they will be
+  // canonicalized into a 'cf.br'.
+  MLIRContext *context = region.getContext();
+  RewritePatternSet patterns(context);
+  cf::BranchOp::getCanonicalizationPatterns(patterns, context);
+  cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
+  cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
+  FrozenRewritePatternSet fPatterns = std::move(patterns);
+
+  ReductionNode *smallestNode = nullptr;
+  OpBuilder builder(context);
+  while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
+    auto branchOp = cast<BranchOpInterface>(branchTerminator);
+    unsigned numSuccessors = branchTerminator->getNumSuccessors();
+    std::vector<ReductionNode::Range> ranges{
+        {0, std::distance(region.op_begin(), region.op_end())}};
+
+    // Try every successor on a fresh clone of the module, replacing the
+    // branching terminator with a branch to just that successor.
+    int branchIdx = -1;
+    for (unsigned i = 0; i < numSuccessors; ++i) {
+      // A 'cf.br' can only pass forwarded operands, so skip successors that
+      // also receive operands produced by the terminator itself.
+      if (branchOp.getSuccessorOperands(i).getProducedOperandCount() != 0)
+        continue;
+
+      // The node is placed in 'allocator' because it may outlive this
+      // iteration as 'smallestNode'.
+      ReductionNode *root = allocator.Allocate();
+      new (root) ReductionNode(nullptr, ranges, allocator);
+      IRMapping mapper;
+      if (failed(root->initialize(module, region, mapper)))
+        llvm_unreachable("unexpected initialization failure");
+      replaceWithBranchToSuccessor(builder, mapper.lookup(branchTerminator),
+                                   i);
+
+      // Apply canonicalization patterns to collapse the now-redundant
+      // branches.
+      (void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns,
+                                  config);
+      root->update(test.isInteresting(root->getModule()));
+
+      // Track the smallest "interesting" version of the IR found so far.
+      if (root->isInteresting() == Tester::Interestingness::True &&
+          (smallestNode == nullptr ||
+           root->getSize() < smallestNode->getSize())) {
+        smallestNode = root;
+        branchIdx = i;
+      }
+    }
+
+    // No successor gave a smaller interesting module, so the region was not
+    // changed. Looking for a terminator again would return the same one and
+    // the loop would never end, so stop here.
+    if (branchIdx == -1)
+      break;
+
+    // Commit the winning reduction to the original region and re-apply the
+    // patterns for a final cleanup.
+    replaceWithBranchToSuccessor(builder, branchTerminator, branchIdx);
+    (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+  }
+
+  // If nothing was committed above (no branching terminator was found, or
+  // none could be reduced), there might still be opportunities for linear
+  // block merging. Apply the patterns once as a final cleanup.
+  if (smallestNode == nullptr)
+    (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
+  return success();
+}
+
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
@@ -196,6 +333,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
   if (succeeded(eraseAllOpsInRegion(module, region, test)))
     return success();
 
+  (void)eraseRedundantBlocksInRegion(module, region, test);
+
   // In the second phase, we don't apply any patterns so that we only select the
   // range of operations to keep to the module stay interesting.
if (failed(findOptimal(module, region, /*patterns=*/{}, test, diff --git a/mlir/test/mlir-reduce/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree.mlir index 2aee89741b42b..b053a111e9a16 100644 --- a/mlir/test/mlir-reduce/reduction-tree.mlir +++ b/mlir/test/mlir-reduce/reduction-tree.mlir @@ -58,3 +58,68 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func.func @simple5() { return } + +// ----- + +// CHECK-LABEL: func @br_reduction +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, +// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) { +func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cf.cond_br %arg0, ^bb1, ^bb2 +^bb1: + cf.br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %0 = memref.alloc() : memref<2xf32> + cf.br ^bb3(%0 : memref<2xf32>) +^bb3(%1: memref<2xf32>): + "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return +} +// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]]) + +// ----- + +// CHECK-LABEL: func @br_reduction_loop +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, +// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) { +func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + // select ^bb1 (the CHECK line below expects %arg1, which only ^bb1 forwards) + cf.cond_br %arg0, ^bb1, ^bb2 +^bb1: + cf.br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %0 = memref.alloc() : memref<2xf32> + cf.br ^bb3(%0 : memref<2xf32>) +^bb3(%1: memref<2xf32>): + "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + // select ^bb4 + cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4 +^bb4: + return +} +// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]]) + +// ----- + +// CHECK-LABEL: func @switch_reduction +// CHECK-SAME: %[[ARG0:.*]]: i32, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, +// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) { +func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cf.switch %arg0 : i32, [ + default: ^bb3(%arg1 : 
memref<2xf32>), + 0: ^bb1, + 1: ^bb2 + ] +^bb1: + cf.br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %0 = memref.alloc() : memref<2xf32> + cf.br ^bb3(%0 : memref<2xf32>) +^bb3(%1: memref<2xf32>): + "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return +} +// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])