diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h index 6ca4e13d159ac..125a7c6f6f5e7 100644 --- a/mlir/include/mlir/Reducer/ReductionNode.h +++ b/mlir/include/mlir/Reducer/ReductionNode.h @@ -90,6 +90,9 @@ class ReductionNode { /// corresponding region. LogicalResult initialize(ModuleOp parentModule, Region &parentRegion); + LogicalResult initialize(ModuleOp parentModule, Region &parentRegion, + IRMapping &mapper); + private: /// A custom BFS iterator. The difference between /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic. diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt index 68864e373c993..b18a4bca04fcb 100644 --- a/mlir/lib/Reducer/CMakeLists.txt +++ b/mlir/lib/Reducer/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_library(MLIRReduce MLIRPass MLIRRewrite MLIRTransformUtils + MLIRControlFlowDialect DEPENDS MLIRReducerIncGen diff --git a/mlir/lib/Reducer/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp index 11aeaf77b4642..897aae0becf33 100644 --- a/mlir/lib/Reducer/ReductionNode.cpp +++ b/mlir/lib/Reducer/ReductionNode.cpp @@ -45,6 +45,16 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule, return success(); } +LogicalResult ReductionNode::initialize(ModuleOp parentModule, + Region &targetRegion, + IRMapping &mapper) { + module = cast(parentModule->clone(mapper)); + // Use the first block of targetRegion to locate the cloned region. + Block *block = mapper.lookup(&*targetRegion.begin()); + region = block->getParent(); + return success(); +} + /// If we haven't explored any variants from this node, we will create N /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the /// max element in `ranges` and create 2 new variants for each call. diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp index 2244475e268fe..a4cad5d19c725 100644 --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -14,16 +14,23 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/DialectInterface.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Reducer/Passes.h" #include "mlir/Reducer/ReductionNode.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Reducer/Tester.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Allocator.h" +#include "llvm/Support/DebugLog.h" + +#define DEBUG_TYPE "reduction-tree" namespace mlir { #define GEN_PASS_DEF_REDUCTIONTREEPASS @@ -182,11 +189,190 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region ®ion, return failure(); } +/// Searches for an unvisited branch terminator within the given region based on +/// the specified conditionality. This helper scans blocks in the \p region to +/// find a terminator that has not yet been processed (not in \p visited). If +/// \p isConditional is true, it looks for terminators with multiple successors +/// (e.g., cf.cond_br). Otherwise, it looks for single-successor terminators +/// (e.g., cf.br). +static Operation *getBranchTerminatorInRegion(Region ®ion, + DenseSet &visited, + bool isConditional = true) { + auto it = llvm::find_if(region.getBlocks(), [&](Block &block) { + if (!block.mightHaveTerminator()) + return false; + size_t numSucc = block.getNumSuccessors(); + Operation *term = block.getTerminator(); + return !visited.contains(term) && + (isConditional ? numSucc > 1 : numSucc == 1); + }); + return it != region.end() ? it->getTerminator() : nullptr; +} + +/// Prunes unreachable blocks from the CFG using the \p worklist. This function +/// iteratively removes blocks that have no predecessors. When a block is +/// erased, its successors are added to the worklist as they may consequently +/// become unreachable. This ensures a cascading deletion of dead-end paths in +/// the control flow graph. +static void pruneCFGEdges(SetVector &workList, IRRewriter &rewriter) { + while (!workList.empty()) { + Block *b = workList.front(); + workList.erase(workList.begin()); + if (b->hasNoPredecessors()) { + for (Block *it : b->getSuccessors()) + workList.insert(it); + rewriter.eraseBlock(b); + } + } +} + +/// Reduces the control flow in a region by iteratively forcing branching +/// terminators to point to a single successor. It evaluates each potential +/// branch path and commits the reduction that results in the smallest +/// "interesting" module. +static LogicalResult reduceConditionalsInRegion(ModuleOp module, Region ®ion, + const Tester &test) { + std::pair initStatus = + test.isInteresting(module); + + if (initStatus.first != Tester::Interestingness::True) + return module.emitWarning() << "uninterested module will not be reduced"; + llvm::SpecificBumpPtrAllocator allocator; + + ReductionNode *smallestNode = nullptr; + mlir::IRRewriter rewriter(region.getContext()); + DenseSet visited; + + // This loop attempts to convert conditional branch operations into + // unconditional ones. + while (Operation *branchTerminator = + getBranchTerminatorInRegion(region, visited)) { + size_t numSuccessor = branchTerminator->getNumSuccessors(); + std::vector ranges{ + {0, std::distance(region.op_begin(), region.op_end())}}; + // Iterate through each successor of the branching terminator to try + // reducing the control flow to a single-path execution. + int branchIdx = -1; + for (int i = 0, e = numSuccessor; i < e; ++i) { + // We allocate memory on the heap because the object will be assigned to + // 'smallestNode'. + ReductionNode *root = allocator.Allocate(); + new (root) ReductionNode(nullptr, ranges, allocator); + mlir::IRMapping mapper; + if (failed(root->initialize(module, region, mapper))) + llvm_unreachable("unexpected initialization failure"); + + Operation *tergetTerminator = mapper.lookup(branchTerminator); + Block *selectedBlock = tergetTerminator->getSuccessor(i); + auto branchOp = cast(tergetTerminator); + mlir::SuccessorOperands selectedBlockOperands = + branchOp.getSuccessorOperands(i); + rewriter.setInsertionPointAfter(tergetTerminator); + cf::BranchOp::create(rewriter, tergetTerminator->getLoc(), selectedBlock, + selectedBlockOperands.getForwardedOperands()); + auto succs = llvm::to_vector(tergetTerminator->getSuccessors()); + succs.erase(succs.begin() + i); + SetVector workList(succs.begin(), succs.end()); + rewriter.eraseOp(tergetTerminator); + pruneCFGEdges(workList, rewriter); + root->update(test.isInteresting(root->getModule())); + if (root->isInteresting() == Tester::Interestingness::True && + (smallestNode == nullptr || + root->getSize() < smallestNode->getSize())) { + smallestNode = root; + branchIdx = i; + } + } + + if (branchIdx != -1) { + Block *selectedBlock = branchTerminator->getSuccessor(branchIdx); + auto branchOp = cast(branchTerminator); + mlir::SuccessorOperands selectedBlockOperands = + branchOp.getSuccessorOperands(branchIdx); + rewriter.setInsertionPointAfter(branchTerminator); + cf::BranchOp::create(rewriter, branchTerminator->getLoc(), selectedBlock, + selectedBlockOperands.getForwardedOperands()); + + auto succs = llvm::to_vector(branchOp->getSuccessors()); + succs.erase(succs.begin() + branchIdx); + SetVector workList(succs.begin(), succs.end()); + rewriter.eraseOp(branchOp); + pruneCFGEdges(workList, rewriter); + } else { + // Insert 'branchTerminator' into visited to prevent it from being + // processed again. + visited.insert(branchTerminator); + } + } + return success(); +} + +/// Simplifies the Control Flow Graph (CFG) by merging blocks that have a +/// single-successor / single-predecessor relationship. This function leverages +/// the canonicalization patterns of 'cf.br' to perform the merge +static LogicalResult reduceBlockMergeInRegion(ModuleOp module, Region ®ion, + const Tester &test) { + std::pair initStatus = + test.isInteresting(module); + + if (initStatus.first != Tester::Interestingness::True) + return module.emitWarning() << "uninterested module will not be reduced"; + llvm::SpecificBumpPtrAllocator allocator; + + GreedyRewriteConfig config; + auto context = region.getContext(); + RewritePatternSet patterns(context); + cf::BranchOp::getCanonicalizationPatterns(patterns, context); + FrozenRewritePatternSet fPatterns = std::move(patterns); + + mlir::IRRewriter rewriter(context); + DenseSet visited; + while (Operation *branchTerminator = + getBranchTerminatorInRegion(region, visited, false)) { + std::vector ranges{ + {0, std::distance(region.op_begin(), region.op_end())}}; + ReductionNode *root = allocator.Allocate(); + new (root) ReductionNode(nullptr, ranges, allocator); + mlir::IRMapping mapper; + if (failed(root->initialize(module, region, mapper))) + llvm_unreachable("unexpected initialization failure"); + Operation *tergetTerminator = mapper.lookup(branchTerminator); + bool changed = false; + (void)applyOpPatternsGreedily(tergetTerminator, fPatterns, config, + &changed); + root->update(test.isInteresting(root->getModule())); + + // If the changed variable is false, it indicates that the pattern failed to + // apply. We should insert it into visited to prevent it from being + // processed again. + if (changed && root->isInteresting() == Tester::Interestingness::True) + (void)applyOpPatternsGreedily(branchTerminator, fPatterns, config); + else + visited.insert(branchTerminator); + } + return success(); +} + +static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module, + Region ®ion, + const Tester &test) { + /// We separate the reduction control flow graph process into 2 steps. + + // we attempts to simplify conditional branches into unconditional ones by + // picking the "interesting" path. + (void)reduceConditionalsInRegion(module, region, test); + + // We merge redundant blocks that have single-successor/single-predecessor + // relationships using canonicalization patterns. + (void)reduceBlockMergeInRegion(module, region, test); + return success(); +} + template static LogicalResult findOptimal(ModuleOp module, Region ®ion, const FrozenRewritePatternSet &patterns, const Tester &test) { - // We separate the reduction process into 3 steps, the first one is to erase + // We separate the reduction process into 4 steps, the first one is to erase // redundant operations and the second one is to apply the reducer patterns. // In the first phase, we attempt to erase all operations within the entire @@ -194,12 +380,16 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion, if (succeeded(eraseAllOpsInRegion(module, region, test))) return success(); - // In the second phase, we don't apply any patterns so that we only select the + // In the second phase, we attempt to eliminate redundant blocks. This reduces + // the program's execution paths. + (void)eraseRedundantBlocksInRegion(module, region, test); + + // In the third phase, we don't apply any patterns so that we only select the // range of operations to keep to the module stay interesting. if (failed(findOptimal(module, region, /*patterns=*/{}, test, /*eraseOpNotInRange=*/true))) return failure(); - // In the third phase, we suppose that no operation is redundant, so we try + // In the fourth phase, we suppose that no operation is redundant, so we try // to rewrite the operation into simpler form. return findOptimal(module, region, patterns, test, /*eraseOpNotInRange=*/false); diff --git a/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir index b235ca14d693a..3fd7d13a10d13 100644 --- a/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir +++ b/mlir/test/mlir-reduce/reduction-tree/reduction-tree.mlir @@ -58,3 +58,101 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func.func @simple5() { return } + +// ----- + +// CHECK-LABEL: func @br_reduction +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, +// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) { +func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cf.cond_br %arg0, ^bb1, ^bb2 +^bb1: + cf.br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %0 = memref.alloc() : memref<2xf32> + cf.br ^bb3(%0 : memref<2xf32>) +^bb3(%1: memref<2xf32>): + "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return +} +// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]]) + +// ----- + +// CHECK-LABEL: func @br_reduction_loop +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, +// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) { +func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + // select ^bb2 + cf.cond_br %arg0, ^bb1, ^bb2 +^bb1: + cf.br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %0 = memref.alloc() : memref<2xf32> + cf.br ^bb3(%0 : memref<2xf32>) +^bb3(%1: memref<2xf32>): + "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + // select ^bb4 + cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4 +^bb4: + return +} +// CHECK: cf.br ^bb1(%[[ARG1]] : memref<2xf32>) +// CHECK: ^bb1(%[[VAL_0:.*]]: memref<2xf32>): +// CHECK: "test.op_crash"(%[[VAL_0]], %[[ARG2]]) +// CHECK: cf.br ^bb1(%[[VAL_0]] : memref<2xf32>) + +// ----- + +// CHECK-LABEL: func @switch_reduction +// CHECK-SAME: %[[ARG0:.*]]: i32, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, +// CHECK-SAME: %[[ARG2:.*]]: memref<3xf32>) +func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<3xf32>) { + cf.switch %arg0 : i32, [ + default: ^bb3(%arg1 : memref<2xf32>), + 0: ^bb1(%arg2: memref<3xf32>), + 1: ^bb2 + ] +^bb1(%0: memref<3xf32>): + cf.br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %1 = memref.alloc() : memref<2xf32> + cf.br ^bb3(%1 : memref<2xf32>) +^bb3(%2: memref<2xf32>): + "test.op_crash"(%2, %arg2) : (memref<2xf32>, memref<3xf32>) -> () + return +} +// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]]) + +// ----- + +// This test verifies the ability to reduce unreachable code. + +// CHECK-LABEL: func @unreachable_code +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, +// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) +func.func @unreachable_code(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cf.br ^bb1 +^bb1: + cf.cond_br %arg0, ^bb2, ^bb3 +^bb2: + cf.br ^bb1 +^bb3: + %alloc = memref.alloc() : memref<2xf32> + cf.br ^bb1 +^bb4(%0: memref<2xf32>): + "test.op_crash"(%0, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + cf.cond_br %arg0, ^bb4(%0 : memref<2xf32>), ^bb5 +^bb5: + return +} +// CHECK: cf.br ^bb1 +// CHECK: ^bb1: +// CHECK: cf.br ^bb1 +// CHECK: ^bb2(%[[VAL_0:.*]]: memref<2xf32>): +// CHECK: "test.op_crash"(%[[VAL_0]], %[[ARG2]]) +// CHECK: cf.br ^bb2(%[[VAL_0]] : memref<2xf32>)