diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h index e754a04b0903a..44cbb458d94fe 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -22,6 +22,7 @@ #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 6ed85c611983a..0b33ecb48b7f2 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -19,6 +19,7 @@ include "mlir/IR/RegionKindInterface.td" include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/ParallelCombiningOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -78,6 +79,7 @@ def ConditionOp : SCF_Op<"condition", [ def ExecuteRegionOp : SCF_Op<"execute_region", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, RecursiveMemoryEffects]> { let summary = "operation that executes its region exactly once"; let description = [{ @@ -161,6 +163,7 @@ def ForOp : SCF_Op<"for", ConditionallySpeculatable, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects]> { let summary = "for operation"; @@ -329,6 +332,7 @@ def ForallOp : SCF_Op<"forall", [ RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DestinationStyleOpInterface, HasParallelRegion ]> { @@ -701,6 +705,7 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> { let summary = "if-then-else operation"; @@ -806,6 +811,7 @@ def ParallelOp : SCF_Op<"parallel", "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps"]>, RecursiveMemoryEffects, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::ReduceOp">, HasParallelRegion]> { let summary = "parallel for operation"; @@ -904,6 +910,7 @@ def ParallelOp : SCF_Op<"parallel", def ReduceOp : SCF_Op<"reduce", [ Terminator, HasParent<"ParallelOp">, RecursiveMemoryEffects, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "reduce operation for scf.parallel"; @@ -987,6 +994,7 @@ def WhileOp : SCF_Op<"while", ["getEntrySuccessorOperands", "getSuccessorInputs"]>, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, RecursiveMemoryEffects, SingleBlock]> { let summary = "a generic 'while' loop"; let description = [{ @@ -1136,6 +1144,7 @@ def WhileOp : SCF_Op<"while", def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::YieldOp">, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, InterfaceMethod<[{ Hook triggered once the promotion of a slot is complete. This can - also clean up the created default value if necessary. + also clean up the created default value if necessary. The default + value may be a null value if no default value was created. This will only be called for slots declared by this operation. Must return a new promotable allocation op if this operation produced @@ -240,7 +243,7 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { operation will be called after the main mutation stage finishes (i.e., after all ops have been processed with `removeBlockingUses`). - Operations should only the replaced values if the intended + Operations should only visit the replaced values if the intended transformation applies to all the replaced values. Furthermore, replaced values must not be deleted. }], "bool", "requiresReplacedValues", (ins), [{}], @@ -263,6 +266,98 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> { ]; } +def PromotableRegionOpInterface + : OpInterface<"PromotableRegionOpInterface"> { + let description = [{ + Describes an operation for which memory slots can be promoted to SSA values + within the operation's regions. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Returns true when the provided region of the operation can be analyzed + for promotion. The provided region must be a child of the operation's + region. + The `hasValueStores` flag indicates whether the region contains + store-like operations that write to the memory slot. + }], "bool", "isRegionPromotable", + (ins + "const ::mlir::MemorySlot &":$slot, + "::mlir::Region *":$region, + "bool":$hasValueStores + ) + >, + InterfaceMethod<[{ + Called before processing the nested regions in the operation. + + Based on the `reachingDef` value representing the value in the memory + slot at the entry into the operation, `setupPromotion` fills in the + `regionsToProcess` with the reaching definition at the entry of + all its promotable regions. + + `setupPromotion` is allowed to mutate + the operation in place, including its nested regions, but cannot + delete existing operations or modify successor-bearing terminators. + Other mutations are not allowed. + + The `hasValueStores` flag indicates whether the regions contain + `store`-like operations that write to the memory slot. This field can be + used to reduce the amount of book-keeping required to track the reaching + definitions. + }], "void", "setupPromotion", + (ins + "const ::mlir::MemorySlot &":$slot, + "::mlir::Value":$reachingDef, + "bool":$hasValueStores, + "::llvm::SmallMapVector<::mlir::Region *, ::mlir::Value, 2> &":$regionsToProcess + ) + >, + InterfaceMethod<[{ + Called once the reaching definitions have been computed for all the + regions, but before the actual removal of the blocking uses. + + Returns the new reaching definition at the exit of the operation. For + this purpose, mutation of the operation is allowed under the following + constraints: + 1. If a region is deleted, all of its content must have been moved out + (not copied) to a new empty region that remains valid after the + deletion. + 2. Mutation must not change control flow within or between existing or + moved regions. This includes adding, removing or reordering blocks. + 3. Mutation must not modify or add operations that interact with the + value of the slot. + + As an example, in order to add new results to the region operation, it + is allowed to clone the operation without regions, move (without + copying) the old region content into the new regions, and delete the + original operation. + + The `entryReachingDef` is the reaching definition at the entry of the + region operation. + + The `reachingAtBlockEnd` map contains the reaching definitions after all + the terminators within the regions of the operation. If a block of the + region is not present in the map, it is either dead code or within a + region that does not interact with the value of the slot. + + The `hasValueStores` flag indicates whether the regions contain + `store`-like operations that write to the memory slot. This field can be + used to reduce the amount of book-keeping required to track the reaching + definitions. + }], + "::mlir::Value", "finalizePromotion", + (ins + "const ::mlir::MemorySlot &":$slot, + "::mlir::Value":$entryReachingDef, + "bool":$hasValueStores, + "const ::llvm::DenseMap<::mlir::Block *, ::mlir::Value> &":$reachingAtBlockEnd, + "::mlir::OpBuilder &":$builder + ) + >, + ]; +} + def DestructurableAllocationOpInterface : OpInterface<"DestructurableAllocationOpInterface"> { let description = [{ @@ -304,7 +399,7 @@ def DestructurableAllocationOpInterface >, InterfaceMethod<[{ Hook triggered once the destructuring of a slot is complete, meaning the - original slot is no longer being refered to and could be deleted. + original slot is no longer being referred to and could be deleted. This will only be called for slots declared by this operation. Must return a new destructurable allocation op if this hook creates @@ -328,7 +423,7 @@ def SafeMemorySlotAccessOpInterface let methods = [ InterfaceMethod<[{ Returns whether all accesses in this operation to the provided slot are - done in a safe manner. To be safe, the access most only access the slot + done in a safe manner. To be safe, the access must only access the slot inside the bounds that its type implies. If the safety of the accesses depends on the safety of the accesses to diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index 540423831937e..6748e2cf71804 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -82,7 +82,7 @@ std::optional memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue, OpBuilder &builder) { - if (defaultValue.use_empty()) + if (defaultValue && defaultValue.use_empty()) defaultValue.getDefiningOp()->erase(); this->erase(); return std::nullopt; diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt index b111117410ba3..fca28c5209e2d 100644 --- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRSCFDialect SCF.cpp DeviceMappingInterface.cpp + MemorySlot.cpp ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp new file mode 100644 index 0000000000000..3d61476df6014 --- /dev/null +++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp @@ -0,0 +1,355 @@ +//===- MemorySlot.cpp - Memory Slot interface implementations for SCF -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" + +using namespace mlir; +using namespace mlir::scf; + +//===----------------------------------------------------------------------===// +// Helper functions +//===----------------------------------------------------------------------===// + +/// Adds the corresponding reaching definition to the terminator of the block if +/// the terminator is of the provided type. +template +static void +updateTerminator(Block *block, Value defaultReachingDef, + const llvm::DenseMap &reachingAtBlockEnd) { + Operation *terminator = block->getTerminator(); + if (!isa(terminator)) + return; + Value blockReachingDef = reachingAtBlockEnd.lookup(block); + if (!blockReachingDef) { + // Block is dead code or the region is not using the slot, so we use the + // default provided reaching definition. + blockReachingDef = defaultReachingDef; + } + terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef}); +} + +/// Creates a shallow copy of an operation with new result types, moving the +/// regions out of the original operation and deleting the original operation. +static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op, + TypeRange resultTypes) { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + Operation *newOp = + mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands()); + rewriter.startOpModification(newOp); + rewriter.startOpModification(op); + for (unsigned int i : llvm::seq(op->getNumRegions())) + newOp->getRegion(i).takeBody(op->getRegion(i)); + rewriter.finalizeOpModification(op); + rewriter.finalizeOpModification(newOp); + + SmallVector replacementValues(newOp->getResults().drop_back()); + rewriter.replaceAllOpUsesWith(op, replacementValues); + rewriter.eraseOp(op); + return newOp; +} + +//===----------------------------------------------------------------------===// +// ExecuteRegionOp +//===----------------------------------------------------------------------===// + +bool ExecuteRegionOp::isRegionPromotable(const MemorySlot &slot, Region *region, + bool hasValueStores) { + return true; +} + +void ExecuteRegionOp::setupPromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + llvm::SmallMapVector ®ionsToProcess) { + regionsToProcess.insert({&getRegion(), reachingDef}); +} + +Value ExecuteRegionOp::finalizePromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + const llvm::DenseMap &reachingAtBlockEnd, + OpBuilder &builder) { + if (!hasValueStores) + return reachingDef; + + // Update the yield terminators to return the newly defined reaching + // definition. + for (Block &block : getRegion().getBlocks()) + updateTerminator(&block, reachingDef, reachingAtBlockEnd); + + SmallVector resultTypes(getResultTypes()); + resultTypes.push_back(slot.elemType); + + IRRewriter rewriter(builder); + Operation *newOp = + replaceWithNewResults(rewriter, getOperation(), resultTypes); + return newOp->getResults().back(); +} + +//===----------------------------------------------------------------------===// +// ForOp +//===----------------------------------------------------------------------===// + +bool ForOp::isRegionPromotable(const MemorySlot &slot, Region *region, + bool hasValueStores) { + return true; +} + +void ForOp::setupPromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + llvm::SmallMapVector ®ionsToProcess) { + Region &bodyRegion = getBodyRegion(); + if (!hasValueStores) { + regionsToProcess.insert({&bodyRegion, reachingDef}); + return; + } + + getInitArgsMutable().append(reachingDef); + bodyRegion.addArgument(slot.elemType, slot.ptr.getLoc()); + regionsToProcess.insert({&bodyRegion, bodyRegion.getArguments().back()}); +} + +Value ForOp::finalizePromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + const llvm::DenseMap &reachingAtBlockEnd, + OpBuilder &builder) { + if (!hasValueStores) + return reachingDef; + + // Update the yield terminator to return the newly defined reaching + // definition. + updateTerminator(getBody(), reachingDef, reachingAtBlockEnd); + + SmallVector resultTypes(getResultTypes()); + resultTypes.push_back(slot.elemType); + + IRRewriter rewriter(builder); + Operation *newOp = + replaceWithNewResults(rewriter, getOperation(), resultTypes); + return newOp->getResults().back(); +} + +//===----------------------------------------------------------------------===// +// ForallOp +//===----------------------------------------------------------------------===// + +bool ForallOp::isRegionPromotable(const MemorySlot &slot, Region *region, + bool hasValueStores) { + // The ForallOp body can be ran in parallel, thus does not support sequenced + // value passing. Therefore only loads can be handled. + return !hasValueStores; +} + +void ForallOp::setupPromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + llvm::SmallMapVector ®ionsToProcess) { + assert(!hasValueStores && "ForallOp does not support stores"); + regionsToProcess.insert({&getBodyRegion(), reachingDef}); +} + +Value ForallOp::finalizePromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + const llvm::DenseMap &reachingAtBlockEnd, + OpBuilder &builder) { + assert(!hasValueStores && "ForallOp does not support stores"); + return reachingDef; +} + +//===----------------------------------------------------------------------===// +// IfOp +//===----------------------------------------------------------------------===// + +bool IfOp::isRegionPromotable(const MemorySlot &slot, Region *region, + bool hasValueStores) { + return true; +} + +void IfOp::setupPromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + llvm::SmallMapVector ®ionsToProcess) { + regionsToProcess.insert({&getThenRegion(), reachingDef}); + regionsToProcess.insert({&getElseRegion(), reachingDef}); +} + +Value IfOp::finalizePromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + const llvm::DenseMap &reachingAtBlockEnd, + OpBuilder &builder) { + if (!hasValueStores) + return reachingDef; + + IRRewriter rewriter(builder); + + // Update the yield terminators to return the newly defined reaching + // definition. + updateTerminator(&getThenRegion().back(), reachingDef, + reachingAtBlockEnd); + if (getElseRegion().hasOneBlock()) { + updateTerminator(&getElseRegion().back(), reachingDef, + reachingAtBlockEnd); + } else { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.createBlock(&getElseRegion()); + YieldOp::create(rewriter, getOperation()->getLoc(), reachingDef); + } + + SmallVector resultTypes(getResultTypes()); + resultTypes.push_back(slot.elemType); + + Operation *newOp = + replaceWithNewResults(rewriter, getOperation(), resultTypes); + return newOp->getResults().back(); +} + +//===----------------------------------------------------------------------===// +// IndexSwitchOp +//===----------------------------------------------------------------------===// + +bool IndexSwitchOp::isRegionPromotable(const MemorySlot &slot, Region *region, + bool hasValueStores) { + return true; +} + +void IndexSwitchOp::setupPromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + llvm::SmallMapVector ®ionsToProcess) { + regionsToProcess.insert({&getDefaultRegion(), reachingDef}); + for (Region &caseRegion : getCaseRegions()) + regionsToProcess.insert({&caseRegion, reachingDef}); +} + +Value IndexSwitchOp::finalizePromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + const llvm::DenseMap &reachingAtBlockEnd, + OpBuilder &builder) { + if (!hasValueStores) + return reachingDef; + + IRRewriter rewriter(builder); + + // Update the yield terminators to return the newly defined reaching + // definition. + updateTerminator(&getDefaultRegion().back(), reachingDef, + reachingAtBlockEnd); + for (Region &caseRegion : getCaseRegions()) + updateTerminator(&caseRegion.back(), reachingDef, + reachingAtBlockEnd); + + SmallVector resultTypes(getResultTypes()); + resultTypes.push_back(slot.elemType); + + Operation *newOp = + replaceWithNewResults(rewriter, getOperation(), resultTypes); + return newOp->getResults().back(); +} + +//===----------------------------------------------------------------------===// +// ParallelOp +//===----------------------------------------------------------------------===// + +bool ParallelOp::isRegionPromotable(const MemorySlot &slot, Region *region, + bool hasValueStores) { + // The ParallelOp body can be ran in parallel, thus does not support sequenced + // value passing. Therefore only loads can be handled. + return !hasValueStores; +} + +void ParallelOp::setupPromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + llvm::SmallMapVector ®ionsToProcess) { + assert(!hasValueStores && "ParallelOp does not support stores"); + regionsToProcess.insert({&getBodyRegion(), reachingDef}); +} + +Value ParallelOp::finalizePromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + const llvm::DenseMap &reachingAtBlockEnd, + OpBuilder &builder) { + assert(!hasValueStores && "ParallelOp does not support stores"); + return reachingDef; +} + +//===----------------------------------------------------------------------===// +// ReduceOp +//===----------------------------------------------------------------------===// + +bool ReduceOp::isRegionPromotable(const MemorySlot &slot, Region *region, + bool hasValueStores) { + // The ReduceOp body can be ran in parallel, thus does not support sequenced + // value passing. Therefore only loads can be handled. + return !hasValueStores; +} + +void ReduceOp::setupPromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + llvm::SmallMapVector ®ionsToProcess) { + assert(!hasValueStores && "ReduceOp does not support stores"); + for (Region &reduction : getReductions()) + regionsToProcess.insert({&reduction, reachingDef}); +} + +Value ReduceOp::finalizePromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + const llvm::DenseMap &reachingAtBlockEnd, + OpBuilder &builder) { + assert(!hasValueStores && "ReduceOp does not support stores"); + return reachingDef; +} + +//===----------------------------------------------------------------------===// +// WhileOp +//===----------------------------------------------------------------------===// + +bool WhileOp::isRegionPromotable(const MemorySlot &slot, Region *region, + bool hasValueStores) { + return true; +} + +void WhileOp::setupPromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + llvm::SmallMapVector ®ionsToProcess) { + Region &beforeRegion = getBefore(); + Region &afterRegion = getAfter(); + if (!hasValueStores) { + regionsToProcess.insert({&beforeRegion, reachingDef}); + regionsToProcess.insert({&afterRegion, reachingDef}); + return; + } + + getInitsMutable().append(reachingDef); + + beforeRegion.addArgument(slot.elemType, slot.ptr.getLoc()); + regionsToProcess.insert({&beforeRegion, beforeRegion.getArguments().back()}); + + afterRegion.addArgument(slot.elemType, slot.ptr.getLoc()); + regionsToProcess.insert({&afterRegion, afterRegion.getArguments().back()}); +} + +Value WhileOp::finalizePromotion( + const MemorySlot &slot, Value reachingDef, bool hasValueStores, + const llvm::DenseMap &reachingAtBlockEnd, + OpBuilder &builder) { + if (!hasValueStores) + return reachingDef; + + // Update the yield terminators to return the newly defined reaching + // definition. + updateTerminator(&getBefore().back(), + getBefore().getArguments().back(), + reachingAtBlockEnd); + updateTerminator( + &getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd); + + SmallVector resultTypes(getResultTypes()); + resultTypes.push_back(slot.elemType); + + IRRewriter rewriter(builder); + Operation *newOp = + replaceWithNewResults(rewriter, getOperation(), resultTypes); + return newOp->getResults().back(); +} diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index b3057129fb9fd..5bd0c70f1d33f 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -52,8 +52,19 @@ using namespace mlir; /// this, the value stored can be well defined at block boundaries, allowing /// the propagation of replacement through blocks. /// -/// This pass computes this transformation in four main steps. The two first -/// steps are performed during an analysis phase that does not mutate IR. +/// The way regions are handled in the transformation is by offering an +/// interface to express the behavior of the allocation value at the edges of +/// the regions: from a particular definition reaching the region operation, the +/// operation will specify what the reaching definition at the entry of its +/// regions are (potentially mutating itself, for example to add region +/// arguments). Likewise, provided a reaching definition at the end of the +/// blocks in the regions, the region operation will provide the reaching +/// definition right after itself. +/// +/// This pass computes this transformation in two main phases: an analysis +/// phase that does not mutate IR, and a transformation phase where mutation +/// happens. Each phase is handled by the `MemorySlotPromotionAnalyzer` and +/// `MemorySlotPromoter` classes respectively. /// /// The two steps of the analysis phase are the following: /// - A first step computes the list of operations that transitively use the @@ -62,36 +73,54 @@ using namespace mlir; /// the user or deleting it. Naturally, direct uses of the slot must be removed. /// Sometimes additional uses must also be removed: this is notably the case /// when a direct user of the slot cannot rewire its use and must delete itself, -/// and thus must make its users no longer use it. If any of those uses cannot -/// be removed by their users in any way, promotion cannot continue: this is -/// decided at this step. +/// and thus must make its users no longer use it. If the allocation is used in +/// nested regions, it is also ensured the region operations provide the right +/// interface to analyze the values of the allocation at the edges of its +/// regions. If any of those constraints cannot be satisfied, promotion cannot +/// continue: this is decided at this step. /// - A second step computes the list of blocks where a block argument will be /// needed ("merge points") without mutating the IR. These blocks are the blocks /// leading to a definition clash between two predecessors. Such blocks happen /// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing -/// a store, as they represent the point where a clear defining dominator stops +/// a store, as they represent the points where a clear defining dominator stops /// existing. Computing this information in advance allows making sure the /// terminators that will forward values are capable of doing so (inability to /// do so aborts promotion at this step). /// -/// At this point, promotion is guaranteed to happen, and the mutation phase can -/// begin with the following steps: -/// - A third step computes the reaching definition of the memory slot at each -/// blocking user. This is the core of the mem2reg algorithm, also known as -/// load-store forwarding. This analyses loads and stores and propagates which -/// value must be stored in the slot at each blocking user. This is achieved by -/// doing a depth-first walk of the dominator tree of the function. This is -/// sufficient because the reaching definition at the beginning of a block is -/// either its new block argument if it is a merge block, or the definition -/// reaching the end of its immediate dominator (parent in the dominator tree). -/// We can therefore propagate this information down the dominator tree to -/// proceed with renaming within blocks. -/// - The final fourth step uses the reaching definition to remove blocking uses -/// in topological order. +/// At this point, promotion is guaranteed to happen, and the transformation +/// phase can begin. For each region of the program, a two step process is +/// carried out. +/// - The first step of the per-region process computes the reaching definition +/// of the memory slot at each blocking user. This is the core of the mem2reg +/// algorithm, also known as load-store forwarding. This analyses loads and +/// stores and propagates which value must be stored in the slot at each +/// blocking user. This is achieved by doing a depth-first walk of the dominator +/// tree of the function. This is sufficient because the reaching definition at +/// the beginning of a block is either its new block argument if it is a merge +/// block, or the definition reaching the end of its immediate dominator (parent +/// in the dominator tree). We can therefore propagate this information down the +/// dominator tree to proceed with renaming within blocks. If at any point a +/// region operation that contains a use of the allocation is encountered, the +/// transformation process is triggered on the child regions of the encountered +/// operation, to obtain the reaching definition at its end and carry on with +/// the value forwarding. +/// - The second step of the per-region process uses the reaching definition to +/// remove blocking uses in topological order. Some reaching definitions may +/// be values that will be removed or modified during the blocking use removal +/// step (typically, in the case of a store that stores the result of a load). +/// To properly handle such values, this step traverses the operations to modify +/// in reverse topological order. This way, if a value that will disappear is +/// used in place of reaching definition, the logic to make it disappear will be +/// executed after the value has been used to replace an operation. For regions +/// within a PromotableRegionOpInterface, in order to correctly handle cases +/// where the finalization logic would use a reaching definition that will be +/// replaced, the finalization logic must be called before the blocking use +/// removal step, so that any use of a value that will be removed gets properly +/// replaced. /// /// For further reading, chapter three of SSA-based Compiler Design [1] -/// showcases SSA construction, where mem2reg is an adaptation of the same -/// process. +/// showcases SSA construction for control-flow graphs, where mem2reg is an +/// adaptation of the same process. /// /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022), /// Springer. @@ -100,18 +129,34 @@ namespace { using BlockingUsesMap = llvm::MapVector>; +using RegionBlockingUsesMap = + llvm::SmallMapVector; + +using RegionSet = SmallPtrSet; + +/// Information about regions that will be traversed for promotion, computed +/// during promotion analysis. +struct RegionPromotionInfo { + /// True if an operation storing to the slot is present in the region. + bool hasValueStores; +}; /// Information computed during promotion analysis used to perform actual /// promotion. struct MemorySlotPromotionInfo { /// Blocks for which at least two definitions of the slot values clash. SmallPtrSet mergePoints; - /// Contains, for each operation, which uses must be eliminated by promotion. - /// This is a DAG structure because if an operation must eliminate some of - /// its uses, it is because the defining ops of the blocking uses requested - /// it. The defining ops therefore must also have blocking uses or be the - /// starting point of the blocking uses. - BlockingUsesMap userToBlockingUses; + /// Contains, for each each region, the blocking uses for its operations. The + /// blocking uses are the uses that must be eliminated by promotion. For each + /// region, this is a DAG structure because if an operation must eliminate + /// some of its uses, it is because the defining ops of the blocking uses + /// requested it. The defining ops therefore must also have blocking uses or + /// be the starting point of the blocking uses. + RegionBlockingUsesMap userToBlockingUses; + /// Regions of which the edges must be analyzed for promotion. All regions + /// are guaranteed to be held by a PromotableRegionOpInterface, and to be + /// nested within the parent region of the slot pointer. + DenseMap regionsToPromote; }; /// Computes information for basic slot promotion. This will check that direct @@ -135,18 +180,20 @@ class MemorySlotPromotionAnalyzer { /// uses (typically, removing its users because it will delete itself to /// resolve its own blocking uses). This will fail if one of the transitive /// users cannot remove a requested use, and should prevent promotion. - LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses); - - /// Computes in which blocks the value stored in the slot is actually used, - /// meaning blocks leading to a load. This method uses `definingBlocks`, the - /// set of blocks containing a store to the slot (defining the value of the - /// slot). - SmallPtrSet - computeSlotLiveIn(SmallPtrSetImpl &definingBlocks); - - /// Computes the points in which multiple re-definitions of the slot's value - /// (stores) may conflict. - void computeMergePoints(SmallPtrSetImpl &mergePoints); + /// Resulting blocking uses are grouped by region. + /// This also ensures all the uses are within promotable regions, adding + /// information about regions to be promoted to the `regionsToPromote` map. + LogicalResult computeBlockingUses( + RegionBlockingUsesMap &userToBlockingUses, + DenseMap ®ionsToPromote); + + /// Computes the points in the provided region where multiple re-definitions + /// of the slot's value (stores) may conflict. + /// `definingBlocks` is the set of blocks containing a store to the slot, + /// either directly or inherited from a nested region. + void computeMergePoints(Region *region, + SmallPtrSetImpl &definingBlocks, + SmallPtrSetImpl &mergePoints); /// Ensures predecessors of merge points can properly provide their current /// definition of the value stored in the slot to the merge point. This can @@ -155,11 +202,17 @@ class MemorySlotPromotionAnalyzer { bool areMergePointsUsable(SmallPtrSetImpl &mergePoints); MemorySlot slot; + DominanceInfo &dominance; const DataLayout &dataLayout; }; -using BlockIndexCache = DenseMap>; +/// Maps a region to a map of blocks to their index in the region. +/// The region is identified by its entry block pointer instead of its region +/// pointer to not need to invalidate the cache when region content is moved to +/// a new region. This only supports moves of all the blocks of a region to +/// an empty region. +using BlockIndexCache = DenseMap>; /// The MemorySlotPromoter handles the state of promoting a memory slot. It /// wraps a slot and its associated allocator. This will perform the mutation of @@ -181,19 +234,39 @@ class MemorySlotPromoter { private: /// Computes the reaching definition for all the operations that require - /// promotion. `reachingDef` is the value the slot should contain at the - /// beginning of the block. This method returns the reached definition at the - /// end of the block. This method must only be called at most once per block. - Value computeReachingDefInBlock(Block *block, Value reachingDef); + /// promotion, including within nested regions needing promotion. + /// `reachingDef` is the value the slot contains at the beginning of the + /// block. This member function returns the reached definition at the end of + /// the block. If the block contains a region that needs promotion, the + /// blocking uses of that region will have been removed. This member function + /// will not remove the blocking uses contained directly in the block. + /// + /// The `reachingDef` may be a null value. In that case, a lazily-created + /// default value will be used. + /// + /// This member function must only be called at most once per block. + Value promoteInBlock(Block *block, Value reachingDef); /// Computes the reaching definition for all the operations that require - /// promotion. `reachingDef` corresponds to the initial value the - /// slot will contain before any write, typically a poison value. - /// This method must only be called at most once per region. - void computeReachingDefInRegion(Region *region, Value reachingDef); - - /// Removes the blocking uses of the slot, in topological order. - void removeBlockingUses(); + /// promotion, including within nested regions needing promotion, and removes + /// the blocking uses of the slot within the region. + /// `reachingDef` is the value the slot contains at the beginning of the + /// region. + /// + /// The `reachingDef` may be a null value. In that case, a lazily-created + /// default value will be used. + /// + /// This member function must only be called at most once per region. + void promoteInRegion(Region *region, Value reachingDef); + + /// Removes the blocking uses of the slot within the given region, in + /// reverse topological order. If the content of the region was moved out + /// to a different region, the new region will be processed instead. + void removeBlockingUses(Region *region); + + /// Links merge point block arguments to the terminators targeting the merge + /// point or remove the argument if it is not used. + void linkMergePoints(); /// Lazily-constructed default value representing the content of the slot when /// no store has been executed. This function may mutate IR. @@ -209,12 +282,30 @@ class MemorySlotPromoter { /// are only computed for promotable memory operations with blocking uses. DenseMap reachingDefs; DenseMap replacedValuesMap; + + /// Contains the reaching definition at the end of the blocks visited so far. + DenseMap reachingAtBlockEnd; + + /// Lists all the values that have been set by a memory operation as a + /// reaching definition at one point during the promotion. The accompanying + /// operation is the memory operation that originally stored the value. + llvm::SmallVector> replacedValues; + /// Operations to visit with the `visitReplacedValues` method at the end of + /// the promotion. + llvm::SmallVector toVisitReplacedValues; + /// Operations to be erased at the end of the promotion. + llvm::SmallVector toErase; + DominanceInfo &dominance; const DataLayout &dataLayout; MemorySlotPromotionInfo info; const Mem2RegStatistics &statistics; /// Shared cache of block indices of specific regions. + /// Cache entries must be invalidated before any addition, removal or + /// reordering of blocks in the corresponding region. + /// Cache entries are *NOT* invalidated if all the blocks of the corresponding + /// region are moved to an empty region. BlockIndexCache &blockIndexCache; }; @@ -251,16 +342,14 @@ Value MemorySlotPromoter::getOrCreateDefaultValue() { } LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( - BlockingUsesMap &userToBlockingUses) { + RegionBlockingUsesMap &userToBlockingUses, + DenseMap ®ionsToPromote) { // The promotion of an operation may require the promotion of further // operations (typically, removing operations that use an operation that must // delete itself). We thus need to start from the use of the slot pointer and // propagate further requests through the forward slice. - // Because this pass currently only supports analysing the parent region of - // the slot pointer, if a promotable memory op that needs promotion is within - // a graph region, the slot may only be used in a graph region and should - // therefore be ignored. + // Graph regions are not supported. Region *slotPtrRegion = slot.ptr.getParentRegion(); auto slotPtrRegionOp = dyn_cast(slotPtrRegion->getParentOp()); @@ -273,10 +362,15 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( // use it. for (OpOperand &use : slot.ptr.getUses()) { SmallPtrSet &blockingUses = - userToBlockingUses[use.getOwner()]; + userToBlockingUses[use.getOwner()->getParentRegion()][use.getOwner()]; blockingUses.insert(&use); } + // Regions that immediately contain a slot memory use that is not a store. + RegionSet regionsWithDirectUse; + // Regions that immediately contain a slot memory use that is a store. + RegionSet regionsWithDirectStore; + // Then, propagate the requirements for the removal of uses. The // topologically-sorted forward slice allows for all blocking uses of an // operation to have been computed before it is reached. Operations are @@ -286,8 +380,12 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( mlir::getForwardSlice(slot.ptr, &forwardSlice); for (Operation *user : forwardSlice) { // If the next operation has no blocking uses, everything is fine. - auto *it = userToBlockingUses.find(user); - if (it == userToBlockingUses.end()) + auto *blockingUsesMapIt = userToBlockingUses.find(user->getParentRegion()); + if (blockingUsesMapIt == userToBlockingUses.end()) + continue; + BlockingUsesMap &blockingUsesMap = blockingUsesMapIt->second; + auto *it = blockingUsesMap.find(user); + if (it == blockingUsesMap.end()) continue; SmallPtrSet &blockingUses = it->second; @@ -303,6 +401,14 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses, dataLayout)) return failure(); + + // Operations that interact with the slot's memory will be promoted using + // a reaching definition. Therefore, the operation must be within a region + // where the reaching definition can be computed. + if (promotable.storesTo(slot)) + regionsWithDirectStore.insert(user->getParentRegion()); + else + regionsWithDirectUse.insert(user->getParentRegion()); } else { // An operation that has blocking uses must be promoted. If it is not // promotable, promotion must fail. @@ -314,98 +420,66 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( assert(llvm::is_contained(user->getResults(), blockingUse->get())); SmallPtrSetImpl &newUserBlockingUseSet = - userToBlockingUses[blockingUse->getOwner()]; + blockingUsesMap[blockingUse->getOwner()]; newUserBlockingUseSet.insert(blockingUse); } } - // Because this pass currently only supports analysing the parent region of - // the slot pointer, if a promotable memory op that needs promotion is outside - // of this region, promotion must fail because it will be impossible to - // provide a valid `reachingDef` for it. - for (auto &[toPromote, _] : userToBlockingUses) - if (isa(toPromote) && - toPromote->getParentRegion() != slot.ptr.getParentRegion()) - return failure(); + // Finally, check that all the regions needed are promotable, and propagate + // the constraint to their parent regions. + auto visitRegions = [&](SmallVector ®ionsToPropagateFrom, + bool hasValueStores) { + while (!regionsToPropagateFrom.empty()) { + Region *region = regionsToPropagateFrom.pop_back_val(); - return success(); -} + if (region == slot.ptr.getParentRegion() || + regionsToPromote.contains(region)) + continue; -SmallPtrSet MemorySlotPromotionAnalyzer::computeSlotLiveIn( - SmallPtrSetImpl &definingBlocks) { - SmallPtrSet liveIn; - - // The worklist contains blocks in which it is known that the slot value is - // live-in. The further blocks where this value is live-in will be inferred - // from these. - SmallVector liveInWorkList; - - // Blocks with a load before any other store to the slot are the starting - // points of the analysis. The slot value is definitely live-in in those - // blocks. - SmallPtrSet visited; - for (Operation *user : slot.ptr.getUsers()) { - if (!visited.insert(user->getBlock()).second) - continue; + RegionPromotionInfo ®ionInfo = regionsToPromote[region]; + regionInfo.hasValueStores = hasValueStores; - for (Operation &op : user->getBlock()->getOperations()) { - if (auto memOp = dyn_cast(op)) { - // If this operation loads the slot, it is loading from it before - // ever writing to it, so the value is live-in in this block. - if (memOp.loadsFrom(slot)) { - liveInWorkList.push_back(user->getBlock()); - break; - } + auto promotableParentOp = + dyn_cast(region->getParentOp()); + if (!promotableParentOp) + return failure(); - // If we store to the slot, further loads will see that value. - // Because we did not meet any load before, the value is not live-in. - if (memOp.storesTo(slot)) - break; - } + if (!promotableParentOp.isRegionPromotable(slot, region, hasValueStores)) + return failure(); + + regionsToPropagateFrom.push_back(region->getParentRegion()); } - } - // The information is then propagated to the predecessors until a def site - // (store) is found. - while (!liveInWorkList.empty()) { - Block *liveInBlock = liveInWorkList.pop_back_val(); + return success(); + }; - if (!liveIn.insert(liveInBlock).second) - continue; + // Start with the regions that directly contain a store to give priority + // to stores in the propagation of `hasValueStores` information. + SmallVector regionsToPropagateFrom(regionsWithDirectStore.begin(), + regionsWithDirectStore.end()); + if (failed(visitRegions(regionsToPropagateFrom, true))) + return failure(); - // If a predecessor is a defining block, either: - // - It has a load before its first store, in which case it is live-in but - // has already been processed in the initialisation step. - // - It has a store before any load, in which case it is not live-in. - // We can thus at this stage insert to the worklist only predecessors that - // are not defining blocks. - for (Block *pred : liveInBlock->getPredecessors()) - if (!definingBlocks.contains(pred)) - liveInWorkList.push_back(pred); - } + // Then, propagate from the regions that directly contain non-store uses. + regionsToPropagateFrom.clear(); + regionsToPropagateFrom.append(regionsWithDirectUse.begin(), + regionsWithDirectUse.end()); + if (failed(visitRegions(regionsToPropagateFrom, false))) + return failure(); - return liveIn; + return success(); } using IDFCalculator = llvm::IDFCalculatorBase; void MemorySlotPromotionAnalyzer::computeMergePoints( + Region *region, SmallPtrSetImpl &definingBlocks, SmallPtrSetImpl &mergePoints) { - if (slot.ptr.getParentRegion()->hasOneBlock()) + if (region->hasOneBlock()) return; - IDFCalculator idfCalculator(dominance.getDomTree(slot.ptr.getParentRegion())); - - SmallPtrSet definingBlocks; - for (Operation *user : slot.ptr.getUsers()) - if (auto storeOp = dyn_cast(user)) - if (storeOp.storesTo(slot)) - definingBlocks.insert(user->getBlock()); - + IDFCalculator idfCalculator(dominance.getDomTree(region)); idfCalculator.setDefiningBlocks(definingBlocks); - SmallPtrSet liveIn = computeSlotLiveIn(definingBlocks); - idfCalculator.setLiveInBlocks(liveIn); - SmallVector mergePointsVec; idfCalculator.calculate(mergePointsVec); @@ -430,13 +504,30 @@ MemorySlotPromotionAnalyzer::computeInfo() { // promotion to happen. These operations need to resolve some of their uses, // either by rewiring them or simply deleting themselves. If any of them // cannot find a way to resolve their blocking uses, we abort the promotion. - if (failed(computeBlockingUses(info.userToBlockingUses))) + // We also compute at this stage the regions that will be analyzed for + // reaching definition information. + if (failed( + computeBlockingUses(info.userToBlockingUses, info.regionsToPromote))) return {}; + // Compute the blocks containing a store for each region, either directly or + // inherited from a nested region. As a side effect, `definingBlocks` contains + // all regions with at least one store. + DenseMap> definingBlocks; + for (Operation *user : slot.ptr.getUsers()) + if (auto storeOp = dyn_cast(user)) + if (storeOp.storesTo(slot)) + definingBlocks[user->getParentRegion()].insert(user->getBlock()); + for (auto &[region, regionInfo] : info.regionsToPromote) + if (regionInfo.hasValueStores) + definingBlocks[region->getParentRegion()].insert( + region->getParentOp()->getBlock()); + // Then, compute blocks in which two or more definitions of the allocated // variable may conflict. These blocks will need a new block argument to // accommodate this. - computeMergePoints(info.mergePoints); + for (auto &[region, defBlocks] : definingBlocks) + computeMergePoints(region, defBlocks, info.mergePoints); // The slot can be promoted if the block arguments to be created can // actually be populated with values, which may not be possible depending @@ -447,18 +538,23 @@ MemorySlotPromotionAnalyzer::computeInfo() { return info; } -Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, - Value reachingDef) { +Value MemorySlotPromoter::promoteInBlock(Block *block, Value reachingDef) { SmallVector blockOps; for (Operation &op : block->getOperations()) blockOps.push_back(&op); for (Operation *op : blockOps) { + // Promote operations that interact with the slot's memory. if (auto memOp = dyn_cast(op)) { - if (info.userToBlockingUses.contains(memOp)) + if (info.userToBlockingUses[memOp->getParentRegion()].contains(memOp)) reachingDefs.insert({memOp, reachingDef}); if (memOp.storesTo(slot)) { builder.setInsertionPointAfter(memOp); + // To not expose default value creation to the interfaces, if we have + // no reaching definition by now, we set it to the default value. + // This is slightly too eager as `getStored` may not need it. + if (!reachingDef) + reachingDef = getOrCreateDefaultValue(); Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout); assert(stored && "a memory operation storing to a slot must provide a " "new definition of the slot"); @@ -466,16 +562,73 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, replacedValuesMap[memOp] = stored; } } + + // Promote regions that contain operations that interact with the slot's + // memory. + if (auto promotableRegionOp = dyn_cast(op)) { + bool needsPromotion = false; + bool hasValueStores = false; + for (Region ®ion : op->getRegions()) { + auto regionInfoIt = info.regionsToPromote.find(®ion); + if (regionInfoIt == info.regionsToPromote.end()) + continue; + needsPromotion = true; + if (!regionInfoIt->second.hasValueStores) + continue; + + hasValueStores = true; + break; + } + + if (needsPromotion) { + llvm::SmallMapVector regionsToProcess; + + // To not expose default value creation to the interfaces, if we have + // no reaching definition by now, we set it to the default value. + // This is slightly too eager as `setupPromotion` may not need it. + if (!reachingDef) + reachingDef = getOrCreateDefaultValue(); + + promotableRegionOp.setupPromotion(slot, reachingDef, hasValueStores, + regionsToProcess); + +#ifndef NDEBUG + for (Region ®ion : op->getRegions()) + if (info.regionsToPromote.contains(®ion)) + assert( + regionsToProcess.contains(®ion) && + "reaching definition must be provided for a required region"); +#endif // NDEBUG + + for (auto &[region, reachingDef] : regionsToProcess) { + assert(region->getParentOp() == op && + "region must be part of the operation"); + if (!info.regionsToPromote.contains(region)) + continue; + promoteInRegion(region, reachingDef); + } + + builder.setInsertionPointAfter(op); + reachingDef = promotableRegionOp.finalizePromotion( + slot, reachingDef, hasValueStores, reachingAtBlockEnd, builder); + + // Blocking uses can then be removed for the regions that were promoted. + // Even though `finalizePromotion` may have moved regions to a new + // operation, `removeBlockingUses` handles this case and will redirect + // processing to the correct region. + for (auto &[region, reachingDef] : regionsToProcess) + removeBlockingUses(region); + } + } } + reachingAtBlockEnd[block] = reachingDef; return reachingDef; } -void MemorySlotPromoter::computeReachingDefInRegion(Region *region, - Value reachingDef) { - assert(reachingDef && "expected an initial reaching def to be provided"); +void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) { if (region->hasOneBlock()) { - computeReachingDefInBlock(®ion->front(), reachingDef); + promoteInBlock(®ion->front(), reachingDef); return; } @@ -486,7 +639,7 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region, SmallVector dfsStack; - auto &domTree = dominance.getDomTree(slot.ptr.getParentRegion()); + auto &domTree = dominance.getDomTree(region); dfsStack.emplace_back( {domTree.getNode(®ion->front()), reachingDef}); @@ -498,40 +651,28 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region, if (info.mergePoints.contains(block)) { BlockArgument blockArgument = block->addArgument(slot.elemType, slot.ptr.getLoc()); - builder.setInsertionPointToStart(block); - allocator.handleBlockArgument(slot, blockArgument, builder); job.reachingDef = blockArgument; - - if (statistics.newBlockArgumentAmount) - (*statistics.newBlockArgumentAmount)++; } - job.reachingDef = computeReachingDefInBlock(block, job.reachingDef); - assert(job.reachingDef); - - if (auto terminator = dyn_cast(block->getTerminator())) { - for (BlockOperand &blockOperand : terminator->getBlockOperands()) { - if (info.mergePoints.contains(blockOperand.get())) { - terminator.getSuccessorOperands(blockOperand.getOperandNumber()) - .append(job.reachingDef); - } - } - } + job.reachingDef = promoteInBlock(block, job.reachingDef); for (auto *child : job.block->children()) dfsStack.emplace_back({child, job.reachingDef}); } } -/// Gets or creates a block index mapping for `region`. +/// Gets or creates a block index mapping for the region of which the entry +/// block is `regionEntryBlock`. static const DenseMap & -getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) { - auto [it, inserted] = blockIndexCache.try_emplace(region); +getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, + Block *regionEntryBlock) { + auto [it, inserted] = blockIndexCache.try_emplace(regionEntryBlock); if (!inserted) return it->second; DenseMap &blockIndices = it->second; - SetVector topologicalOrder = getBlocksSortedByDominance(*region); + SetVector topologicalOrder = + getBlocksSortedByDominance(*regionEntryBlock->getParent()); for (auto [index, block] : llvm::enumerate(topologicalOrder)) blockIndices[block] = index; return blockIndices; @@ -540,12 +681,17 @@ getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) { /// Sorts `ops` according to dominance. Relies on the topological order of basic /// blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the /// potentially expensive recomputation of a block index map. +/// This function assumes no blocks are ever deleted or entry block changed +/// during the lifetime of the block index cache. static void dominanceSort(SmallVector &ops, Region ®ion, BlockIndexCache &blockIndexCache) { + if (region.empty()) + return; + // Produce a topological block order and construct a map to lookup the indices // of blocks. const DenseMap &topoBlockIndices = - getOrCreateBlockIndices(blockIndexCache, ®ion); + getOrCreateBlockIndices(blockIndexCache, ®ion.front()); // Combining the topological order of the basic blocks together with block // internal operation order guarantees a deterministic, dominance respecting @@ -559,79 +705,141 @@ static void dominanceSort(SmallVector &ops, Region ®ion, }); } -void MemorySlotPromoter::removeBlockingUses() { +void MemorySlotPromoter::removeBlockingUses(Region *region) { + auto *blockingUsesMapIt = info.userToBlockingUses.find(region); + if (blockingUsesMapIt == info.userToBlockingUses.end()) + return; + BlockingUsesMap &blockingUsesMap = blockingUsesMapIt->second; + if (blockingUsesMap.empty()) + return; + + // Operations may have been moved to a different region at this point. + // To cover this, we process the current region of an operation to remove + // instead of the provided region. + region = blockingUsesMap.front().first->getParentRegion(); +#ifndef NDEBUG + for (auto &[op, blockingUses] : blockingUsesMap) + assert(op->getParentRegion() == region && + "all operations must still be in the same region"); +#endif // NDEBUG + llvm::SmallVector usersToRemoveUses( - llvm::make_first_range(info.userToBlockingUses)); + llvm::make_first_range(blockingUsesMap)); // Sort according to dominance. - dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(), - blockIndexCache); + dominanceSort(usersToRemoveUses, *region, blockIndexCache); - llvm::SmallVector toErase; - // List of all replaced values in the slot. - llvm::SmallVector> replacedValuesList; - // Ops to visit with the `visitReplacedValues` method. - llvm::SmallVector toVisit; + // Iterate over the operations to rewrite in reverse dominance order. for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) { if (auto toPromoteMemOp = dyn_cast(toPromote)) { Value reachingDef = reachingDefs.lookup(toPromoteMemOp); // If no reaching definition is known, this use is outside the reach of // the slot. The default value should thus be used. + // FIXME: This is too eager, and will generate default values even for + // pure stores. This cannot be removed easily as partial stores may + // still require a default value to complete. if (!reachingDef) reachingDef = getOrCreateDefaultValue(); builder.setInsertionPointAfter(toPromote); - if (toPromoteMemOp.removeBlockingUses( - slot, info.userToBlockingUses[toPromote], builder, reachingDef, - dataLayout) == DeletionKind::Delete) + if (toPromoteMemOp.removeBlockingUses(slot, blockingUsesMap[toPromote], + builder, reachingDef, + dataLayout) == DeletionKind::Delete) toErase.push_back(toPromote); if (toPromoteMemOp.storesTo(slot)) if (Value replacedValue = replacedValuesMap[toPromoteMemOp]) - replacedValuesList.push_back({toPromoteMemOp, replacedValue}); + replacedValues.push_back({toPromoteMemOp, replacedValue}); continue; } auto toPromoteBasic = cast(toPromote); builder.setInsertionPointAfter(toPromote); - if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote], + if (toPromoteBasic.removeBlockingUses(blockingUsesMap[toPromote], builder) == DeletionKind::Delete) toErase.push_back(toPromote); if (toPromoteBasic.requiresReplacedValues()) - toVisit.push_back(toPromoteBasic); + toVisitReplacedValues.push_back(toPromoteBasic); } - for (PromotableOpInterface op : toVisit) { - builder.setInsertionPointAfter(op); - op.visitReplacedValues(replacedValuesList, builder); +} + +void MemorySlotPromoter::linkMergePoints() { + // We want to eliminate unused block arguments. In case connecting a block + // argument to its predecessor would trigger the use of the predecessor's + // unused block argument, we need to process merge points in an expanding + // worklist, `mergePointArgsToProcess`. + + SmallPtrSet mergePointArgsUnused; + SmallVector mergePointArgsToProcess; + for (Block *mergePoint : info.mergePoints) { + BlockArgument arg = mergePoint->getArguments().back(); + if (arg.use_empty()) + mergePointArgsUnused.insert(arg); + else + mergePointArgsToProcess.push_back(arg); } - for (Operation *toEraseOp : toErase) - toEraseOp->erase(); + while (!mergePointArgsToProcess.empty()) { + BlockArgument arg = mergePointArgsToProcess.pop_back_val(); + Block *mergePoint = arg.getOwner(); - assert(slot.ptr.use_empty() && - "after promotion, the slot pointer should not be used anymore"); + for (BlockOperand &use : mergePoint->getUses()) { + Value reachingDef = reachingAtBlockEnd[use.getOwner()->getBlock()]; + if (!reachingDef) + reachingDef = getOrCreateDefaultValue(); + + // If the reaching definition is a block argument of an unused merge + // point, mark it as used and process it as such later. + auto reachingDefArgument = dyn_cast(reachingDef); + if (reachingDefArgument && + mergePointArgsUnused.erase(reachingDefArgument)) + mergePointArgsToProcess.push_back(reachingDefArgument); + + BranchOpInterface user = cast(use.getOwner()); + user.getSuccessorOperands(use.getOperandNumber()).append(reachingDef); + } + + builder.setInsertionPointToStart(mergePoint); + allocator.handleBlockArgument(slot, arg, builder); + if (statistics.newBlockArgumentAmount) + (*statistics.newBlockArgumentAmount)++; + } + + for (BlockArgument arg : mergePointArgsUnused) { + Block *mergePoint = arg.getOwner(); + mergePoint->eraseArgument(mergePoint->getNumArguments() - 1); + } } std::optional MemorySlotPromoter::promoteSlot() { - computeReachingDefInRegion(slot.ptr.getParentRegion(), - getOrCreateDefaultValue()); + // Perform the promotion recursively through nested regions. The reaching + // definition starts with a null value that will be replaced by a + // lazily-created default value if the value must be passed to a promotion + // interface while no store has been encountered yet. + // Innermost regions will see their blocking uses be removed, but not the + // outermost region which we have to remove manually afterwards. This is + // because PromotableRegionOpInterface::finalizePromotion must be called + // before removeBlockingUses. + promoteInRegion(slot.ptr.getParentRegion(), nullptr); + + // Blocking uses can then be removed for the outermost region. + removeBlockingUses(slot.ptr.getParentRegion()); + + // Notify operations that requested it of the reaching definitions set by + // storing memory operations. + for (PromotableOpInterface op : toVisitReplacedValues) { + builder.setInsertionPointAfter(op); + op.visitReplacedValues(replacedValues, builder); + } - // Now that reaching definitions are known, remove all users. - removeBlockingUses(); + // Finally, connect merge points to their predecessor's reaching definitions. + linkMergePoints(); - // Update terminators in dead branches to forward default if they are - // succeeded by a merge points. - for (Block *mergePoint : info.mergePoints) { - for (BlockOperand &use : mergePoint->getUses()) { - auto user = cast(use.getOwner()); - SuccessorOperands succOperands = - user.getSuccessorOperands(use.getOperandNumber()); - assert(succOperands.size() == mergePoint->getNumArguments() || - succOperands.size() + 1 == mergePoint->getNumArguments()); - if (succOperands.size() + 1 == mergePoint->getNumArguments()) - succOperands.append(getOrCreateDefaultValue()); - } - } + for (Operation *toEraseOp : toErase) + toEraseOp->erase(); + + assert(slot.ptr.use_empty() && + "after promotion, the slot pointer should not be used anymore"); LDBG() << "Promoted memory slot: " << slot.ptr; diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir index 716a5860a0c07..779d72000d543 100644 --- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir +++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir @@ -613,24 +613,20 @@ llvm.func @use(i64) // ----- -// This test should no longer be an issue once promotion within subregions -// is supported. // CHECK-LABEL: llvm.func @subregion_block_promotion // CHECK-SAME: (%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64) -> i64 llvm.func @subregion_block_promotion(%arg0: i64, %arg1: i64) -> i64 { %0 = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[ALLOCA:.*]] = llvm.alloca + // CHECK-NOT: = llvm.alloca %1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr - // CHECK: llvm.store %[[ARG1]], %[[ALLOCA]] llvm.store %arg1, %1 {alignment = 4 : i64} : i64, !llvm.ptr - // CHECK: scf.execute_region { + // CHECK: %[[RES:.*]] = scf.execute_region -> i64 { scf.execute_region { - // CHECK: llvm.store %[[ARG0]], %[[ALLOCA]] llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr scf.yield } + // CHECK: scf.yield %[[ARG0]] : i64 // CHECK: } - // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64 // CHECK: llvm.return %[[RES]] : i64 llvm.return %2 : i64 diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir index dd68675cc4441..30a268521c69b 100644 --- a/mlir/test/Dialect/MemRef/mem2reg.mlir +++ b/mlir/test/Dialect/MemRef/mem2reg.mlir @@ -163,3 +163,16 @@ func.func @promotable_nonpromotable_intertwined() -> i32 { } func.func @use(%arg: memref) { return } + +// ----- + +// CHECK-LABEL: func.func @unused_alloca_store_loop +func.func @unused_alloca_store_loop() { + // CHECK-NOT: memref.alloca + %cst = arith.constant 1 : i32 + %alloca = memref.alloca() : memref + cf.br ^bb1 +^bb1: + memref.store %cst, %alloca[] : memref + cf.br ^bb1 +} diff --git a/mlir/test/Dialect/SCF/mem2reg-reject.mlir b/mlir/test/Dialect/SCF/mem2reg-reject.mlir new file mode 100644 index 0000000000000..9497bf47ff09b --- /dev/null +++ b/mlir/test/Dialect/SCF/mem2reg-reject.mlir @@ -0,0 +1,160 @@ +// RUN: mlir-opt %s --mem2reg --split-input-file | FileCheck %s + +// Check that a store inside a forall prevents promotion. + +// CHECK-LABEL: func.func @forall_store +// CHECK-SAME: (%[[UB:.*]]: index) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref +// CHECK: memref.store %[[C5]], %[[ALLOCA]][] +// CHECK: scf.forall (%{{.*}}) in (%[[UB]]) { +// CHECK: memref.store %[[C7]], %[[ALLOCA]][] +// CHECK: } +// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][] +// CHECK: return %[[LOAD]] : i32 +func.func @forall_store(%ub: index) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.forall (%i) in (%ub) { + memref.store %c7, %alloca[] : memref + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check that a store inside an if inside a forall prevents promotion. + +// CHECK-LABEL: func.func @forall_if_store +// CHECK-SAME: (%[[UB:.*]]: index, %[[COND:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref +// CHECK: memref.store %[[C5]], %[[ALLOCA]][] +// CHECK: scf.forall (%{{.*}}) in (%[[UB]]) { +// CHECK: scf.if %[[COND]] { +// CHECK: memref.store %[[C7]], %[[ALLOCA]][] +// CHECK: } +// CHECK: } +// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][] +// CHECK: return %[[LOAD]] : i32 +func.func @forall_if_store(%ub: index, %cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.forall (%i) in (%ub) { + scf.if %cond { + memref.store %c7, %alloca[] : memref + scf.yield + } + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check that a store inside a parallel prevents promotion. + +// CHECK-LABEL: func.func @parallel_store +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref +// CHECK: memref.store %[[C5]], %[[ALLOCA]][] +// CHECK: scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) { +// CHECK: memref.store %[[C7]], %[[ALLOCA]][] +// CHECK: scf.reduce +// CHECK: } +// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][] +// CHECK: return %[[LOAD]] : i32 +func.func @parallel_store(%lb: index, %ub: index, %step: index) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.parallel (%i) = (%lb) to (%ub) step (%step) { + memref.store %c7, %alloca[] : memref + scf.reduce + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check that a store inside an if inside a parallel prevents promotion. + +// CHECK-LABEL: func.func @parallel_if_store +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[COND:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref +// CHECK: memref.store %[[C5]], %[[ALLOCA]][] +// CHECK: scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) { +// CHECK: scf.if %[[COND]] { +// CHECK: memref.store %[[C7]], %[[ALLOCA]][] +// CHECK: } +// CHECK: scf.reduce +// CHECK: } +// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][] +// CHECK: return %[[LOAD]] : i32 +func.func @parallel_if_store(%lb: index, %ub: index, %step: index, %cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.parallel (%i) = (%lb) to (%ub) step (%step) { + scf.if %cond { + memref.store %c7, %alloca[] : memref + scf.yield + } + scf.reduce + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check that a store inside a reduce region prevents promotion. + +// CHECK-LABEL: func.func @parallel_reduce_store +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[C0:.*]] = arith.constant 0 : i32 +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref +// CHECK: memref.store %[[C5]], %[[ALLOCA]][] +// CHECK: %[[RES:.*]] = scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) init (%[[C0]]) -> i32 { +// CHECK: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK: scf.reduce(%[[C1]] : i32) { +// CHECK: ^{{.*}}(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): +// CHECK: %[[SUM:.*]] = arith.addi %[[LHS]], %[[RHS]] : i32 +// CHECK: memref.store %[[SUM]], %[[ALLOCA]][] +// CHECK: scf.reduce.return %[[SUM]] : i32 +// CHECK: } +// CHECK: } +// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][] +// CHECK: return %[[LOAD]] : i32 +func.func @parallel_reduce_store(%lb: index, %ub: index, %step: index) -> i32 { + %c5 = arith.constant 5 : i32 + %c0 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + %res = scf.parallel (%i) = (%lb) to (%ub) step (%step) init (%c0) -> i32 { + %c1 = arith.constant 1 : i32 + scf.reduce(%c1 : i32) { + ^bb0(%lhs: i32, %rhs: i32): + %sum = arith.addi %lhs, %rhs : i32 + memref.store %sum, %alloca[] : memref + scf.reduce.return %sum : i32 + } + } + %load = memref.load %alloca[] : memref + return %load : i32 +} diff --git a/mlir/test/Dialect/SCF/mem2reg.mlir b/mlir/test/Dialect/SCF/mem2reg.mlir new file mode 100644 index 0000000000000..82098b2d306b2 --- /dev/null +++ b/mlir/test/Dialect/SCF/mem2reg.mlir @@ -0,0 +1,1120 @@ +// RUN: mlir-opt %s --mem2reg --split-input-file | FileCheck %s \ +// RUN: -implicit-check-not "memref.alloca" \ +// RUN: -implicit-check-not "memref.load" \ +// RUN: -implicit-check-not "memref.store" + +// Check regions within if are promoted. + +// CHECK-LABEL: func.func @if_load_only +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[RES:.*]] = scf.if %[[COND]] -> (i32) +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @if_load_only(%cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + %res = scf.if %cond -> i32 { + %load = memref.load %alloca[] : memref + scf.yield %load : i32 + } else { + scf.yield %c5 : i32 + } + return %res : i32 +} + +// ----- + +// Check load promotion through an if with no else branch. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @if_no_else_load +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: scf.if %[[COND]] { +// CHECK: call @use(%[[C5]]) +// CHECK: } +// CHECK: call @use(%[[C5]]) +func.func @if_no_else_load(%cond: i1) { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.if %cond { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.yield + } + %load2 = memref.load %alloca[] : memref + func.call @use(%load2) : (i32) -> () + return +} + +// ----- + +// Check store promotion through an if with no else branch. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @if_no_else_store +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (i32) +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: call @use(%[[IF]]) +func.func @if_no_else_store(%cond: i1) { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.if %cond { + memref.store %c7, %alloca[] : memref + scf.yield + } + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + return +} + +// ----- + +// Check store promotion through nested ifs with no else branches. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @if_nested_store +// CHECK-SAME: (%[[COND0:.*]]: i1, %[[COND1:.*]]: i1) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[OUTER:.*]] = scf.if %[[COND0]] -> (i32) +// CHECK: %[[INNER:.*]] = scf.if %[[COND1]] -> (i32) +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: scf.yield %[[INNER]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: call @use(%[[OUTER]]) +func.func @if_nested_store(%cond0: i1, %cond1: i1) { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.if %cond0 { + scf.if %cond1 { + memref.store %c7, %alloca[] : memref + scf.yield + } + scf.yield + } + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + return +} + +// ----- + +// Check that a store coming from a load of the same slot is correctly promoted. + +// CHECK-LABEL: func.func @if_load_into_store +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[RES:.*]] = scf.if %[[COND]] -> (i32) { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @if_load_into_store(%arg1 : i1) -> i32 { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.if %arg1 { + %loaded = memref.load %alloca[] : memref + memref.store %loaded, %alloca[] : memref + scf.yield + } + %loaded2 = memref.load %alloca[] : memref + return %loaded2 : i32 +} + +// ----- + +// Check promotion of a load followed by a nested if containing a store of +// the loaded value. + +// CHECK-LABEL: func.func @if_load_then_nested_if_store +// CHECK-SAME: (%[[COND0:.*]]: i1, %[[COND1:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[RES:.*]] = scf.if %[[COND0]] -> (i32) { +// CHECK: %[[INNER:.*]] = scf.if %[[COND1]] -> (i32) { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: scf.yield %[[INNER]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @if_load_then_nested_if_store(%cond0: i1, %cond1: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.if %cond0 { + %loaded = memref.load %alloca[] : memref + scf.if %cond1 { + memref.store %loaded, %alloca[] : memref + scf.yield + } + scf.yield + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check load promotion through execute_region. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @execute_region_load +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: scf.execute_region { +// CHECK: call @use(%[[C5]]) +// CHECK: scf.yield +// CHECK: } +func.func @execute_region_load() { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.execute_region { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.yield + } + return +} + +// ----- + +// Check store promotion through execute_region. + +// CHECK-LABEL: func.func @execute_region_store +// CHECK: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[RES:.*]] = scf.execute_region -> i32 { +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @execute_region_store() -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.execute_region { + memref.store %c7, %alloca[] : memref + scf.yield + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check promotion through an execute_region with CFG control flow and a +// nested if containing a load. This ensures a block argument is created +// even in blocks with no direct slot use. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @execute_region_cfg +// CHECK-SAME: (%[[COND0:.*]]: i1, %[[COND1:.*]]: i1) +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : i32 +// CHECK: %[[RES:.*]] = scf.execute_region -> i32 { +// CHECK: cf.cond_br %[[COND0]], ^[[BB1:.*]], ^[[BB2:.*]] +// CHECK: ^[[BB1]]: +// CHECK: cf.br ^[[BB3:.*]](%[[C7]] : i32) +// CHECK: ^[[BB2]]: +// CHECK: cf.br ^[[BB3]](%[[C9]] : i32) +// CHECK: ^[[BB3]](%[[VAL:.*]]: i32): +// CHECK: scf.if %[[COND1]] { +// CHECK: call @use(%[[VAL]]) +// CHECK: } +// CHECK: scf.yield %[[VAL]] : i32 +// CHECK: } +// CHECK: call @use(%[[RES]]) +func.func @execute_region_cfg(%cond0: i1, %cond1: i1) { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %c9 = arith.constant 9 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.execute_region { + cf.cond_br %cond0, ^bb1, ^bb2 + ^bb1: + memref.store %c7, %alloca[] : memref + cf.br ^bb3 + ^bb2: + memref.store %c9, %alloca[] : memref + cf.br ^bb3 + ^bb3: + scf.if %cond1 { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.yield + } + scf.yield + } + %load2 = memref.load %alloca[] : memref + func.call @use(%load2) : (i32) -> () + return +} + +// CHECK-LABEL: func.func @execute_region_cfg_no_use_at_all +// CHECK-SAME: (%[[COND0:.*]]: i1, %[[COND1:.*]]: i1) +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : i32 +// CHECK: %[[RES:.*]] = scf.execute_region -> i32 { +// CHECK: cf.cond_br %[[COND0]], ^[[BB1:.*]], ^[[BB2:.*]] +// CHECK: ^[[BB1]]: +// CHECK: cf.br ^[[BB3:.*]](%[[C7]] : i32) +// CHECK: ^[[BB2]]: +// CHECK: cf.br ^[[BB3]](%[[C9]] : i32) +// CHECK: ^[[BB3]](%[[VAL:.*]]: i32): +// CHECK: scf.yield %[[VAL]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @execute_region_cfg_no_use_at_all(%cond0: i1, %cond1: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %c9 = arith.constant 9 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.execute_region { + cf.cond_br %cond0, ^bb1, ^bb2 + ^bb1: + memref.store %c7, %alloca[] : memref + cf.br ^bb3 + ^bb2: + memref.store %c9, %alloca[] : memref + cf.br ^bb3 + ^bb3: + scf.yield + } + %load2 = memref.load %alloca[] : memref + return %load2 : i32 +} + +// CHECK-LABEL: func.func @execute_region_cfg_with_store +// CHECK-SAME: (%[[COND0:.*]]: i1, %[[COND1:.*]]: i1) +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : i32 +// CHECK-DAG: %[[C11:.*]] = arith.constant 11 : i32 +// CHECK: %[[RES:.*]] = scf.execute_region -> i32 { +// CHECK: cf.cond_br %[[COND0]], ^[[BB1:.*]], ^[[BB2:.*]] +// CHECK: ^[[BB1]]: +// CHECK: cf.br ^[[BB3:.*]] +// CHECK: ^[[BB2]]: +// CHECK: cf.br ^[[BB3]] +// CHECK: ^[[BB3]]: +// CHECK: scf.yield %[[C11]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @execute_region_cfg_with_store(%cond0: i1, %cond1: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %c9 = arith.constant 9 : i32 + %c11 = arith.constant 11 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.execute_region { + cf.cond_br %cond0, ^bb1, ^bb2 + ^bb1: + memref.store %c7, %alloca[] : memref + cf.br ^bb3 + ^bb2: + memref.store %c9, %alloca[] : memref + cf.br ^bb3 + ^bb3: + memref.store %c11, %alloca[] : memref + scf.yield + } + %load2 = memref.load %alloca[] : memref + return %load2 : i32 +} + +// ----- + +// Check promotion through an execute_region with multiple yield terminators +// having different reaching definitions. + +// CHECK-LABEL: func.func @execute_region_multiple_yields +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : i32 +// CHECK: %[[RES:.*]] = scf.execute_region -> i32 { +// CHECK: cf.cond_br %[[COND]], ^[[BB1:.*]], ^[[BB2:.*]] +// CHECK: ^[[BB1]]: +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: ^[[BB2]]: +// CHECK: scf.yield %[[C9]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @execute_region_multiple_yields(%cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %c9 = arith.constant 9 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.execute_region { + cf.cond_br %cond, ^bb1, ^bb2 + ^bb1: + memref.store %c7, %alloca[] : memref + scf.yield + ^bb2: + memref.store %c9, %alloca[] : memref + scf.yield + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check promotion when both yield terminators share the same reaching +// definition from a store in the entry block. + +// CHECK-LABEL: func.func @execute_region_same_reaching_def +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[RES:.*]] = scf.execute_region -> i32 { +// CHECK: cf.cond_br %[[COND]], ^[[BB1:.*]], ^[[BB2:.*]] +// CHECK: ^[[BB1]]: +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: ^[[BB2]]: +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @execute_region_same_reaching_def(%cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.execute_region { + memref.store %c7, %alloca[] : memref + cf.cond_br %cond, ^bb1, ^bb2 + ^bb1: + scf.yield + ^bb2: + scf.yield + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check that a load-then-store of the same slot in the same block is promoted. + +// CHECK-LABEL: func.func @execute_region_load_into_store_same_block +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[RES:.*]] = scf.execute_region -> i32 { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @execute_region_load_into_store_same_block() -> i32 { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.execute_region { + %loaded = memref.load %alloca[] : memref + memref.store %loaded, %alloca[] : memref + scf.yield + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check that a load-then-store of the same slot across blocks is promoted. + +// CHECK-LABEL: func.func @execute_region_load_into_store_diff_block +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[RES:.*]] = scf.execute_region -> i32 { +// CHECK: cf.cond_br %[[COND]], ^[[BB1:.*]], ^[[BB2:.*]] +// CHECK: ^[[BB1]]: +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: ^[[BB2]]: +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @execute_region_load_into_store_diff_block(%cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.execute_region { + %loaded = memref.load %alloca[] : memref + cf.cond_br %cond, ^bb1, ^bb2 + ^bb1: + memref.store %loaded, %alloca[] : memref + scf.yield + ^bb2: + scf.yield + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check promotion through a for loop with a load and store in the body. + +// CHECK-LABEL: func.func @for_load_and_store +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK: %[[RES:.*]] = scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ARG:.*]] = %[[C5]]) -> (i32) { +// CHECK: %[[NEW:.*]] = arith.addi %[[ARG]], %[[C1]] : i32 +// CHECK: scf.yield %[[NEW]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @for_load_and_store(%lb: index, %ub: index, %step: index) -> i32 { + %c5 = arith.constant 5 : i32 + %c1 = arith.constant 1 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.for %i = %lb to %ub step %step { + %load = memref.load %alloca[] : memref + %new = arith.addi %load, %c1 : i32 + memref.store %new, %alloca[] : memref + scf.yield + } + %load2 = memref.load %alloca[] : memref + return %load2 : i32 +} + +// ----- + +// Check promotion adds a second iter_arg when one already exists. + +// CHECK-LABEL: func.func @for_existing_iter_arg +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[INIT:.*]]: i32) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK: %[[RES:.*]]:2 = scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[MUL_ARG:.*]] = %[[INIT]], %[[SLOT_ARG:.*]] = %[[C5]]) -> (i32, i32) { +// CHECK: %[[MUL:.*]] = arith.muli %[[MUL_ARG]], %[[MUL_ARG]] : i32 +// CHECK: %[[NEW:.*]] = arith.addi %[[SLOT_ARG]], %[[C1]] : i32 +// CHECK: scf.yield %[[MUL]], %[[NEW]] : i32, i32 +// CHECK: } +// CHECK: return %[[RES]]#1 : i32 +func.func @for_existing_iter_arg(%lb: index, %ub: index, %step: index, %init: i32) -> i32 { + %c5 = arith.constant 5 : i32 + %c1 = arith.constant 1 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + %mul_res = scf.for %i = %lb to %ub step %step iter_args(%mul_arg = %init) -> i32 { + %mul = arith.muli %mul_arg, %mul_arg : i32 + %load = memref.load %alloca[] : memref + %new = arith.addi %load, %c1 : i32 + memref.store %new, %alloca[] : memref + scf.yield %mul : i32 + } + %load2 = memref.load %alloca[] : memref + return %load2 : i32 +} + +// ----- + +// Check load-only promotion through a for loop generates no iter_arg. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @for_load_only +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: call @use(%[[C5]]) +// CHECK: } +// CHECK: return %[[C5]] : i32 +func.func @for_load_only(%lb: index, %ub: index, %step: index) -> i32 { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.for %i = %lb to %ub step %step { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.yield + } + %load2 = memref.load %alloca[] : memref + return %load2 : i32 +} + +// ----- + +// Check promotion through a for loop with a store inside an if in the body. + +// CHECK-LABEL: func.func @for_if_store +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[COND:.*]]: i1) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[RES:.*]] = scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ARG:.*]] = %[[C5]]) -> (i32) { +// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (i32) { +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[ARG]] : i32 +// CHECK: } +// CHECK: scf.yield %[[IF]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @for_if_store(%lb: index, %ub: index, %step: index, %cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.for %i = %lb to %ub step %step { + scf.if %cond { + memref.store %c7, %alloca[] : memref + scf.yield + } + scf.yield + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check that a load-then-store of the same slot in a for loop is promoted. + +// CHECK-LABEL: func.func @for_load_into_store +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[RES:.*]] = scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ARG:.*]] = %[[C5]]) -> (i32) { +// CHECK: scf.yield %[[ARG]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @for_load_into_store(%lb: index, %ub: index, %step: index) -> i32 { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.for %i = %lb to %ub step %step { + %loaded = memref.load %alloca[] : memref + memref.store %loaded, %alloca[] : memref + scf.yield + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check load promotion through a forall. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @forall_load +// CHECK-SAME: (%[[UB:.*]]: index) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: scf.forall (%{{.*}}) in (%[[UB]]) { +// CHECK: call @use(%[[C5]]) +// CHECK: } +func.func @forall_load(%ub: index) { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.forall (%i) in (%ub) { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + } + return +} + +// ----- + +// Check promotion through a forall nested inside an if with a store. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @forall_in_if +// CHECK-SAME: (%[[UB:.*]]: index, %[[COND:.*]]: i1) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[RES:.*]] = scf.if %[[COND]] -> (i32) { +// CHECK: scf.forall (%{{.*}}) in (%[[UB]]) { +// CHECK: call @use(%[[C7]]) +// CHECK: } +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @forall_in_if(%ub: index, %cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.if %cond { + memref.store %c7, %alloca[] : memref + scf.forall (%i) in (%ub) { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + } + scf.yield + } + %load2 = memref.load %alloca[] : memref + return %load2 : i32 +} + +// ----- + +// Check load promotion through a parallel. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @parallel_load +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) { +// CHECK: call @use(%[[C5]]) +// CHECK: scf.reduce +// CHECK: } +func.func @parallel_load(%lb: index, %ub: index, %step: index) { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.parallel (%i) = (%lb) to (%ub) step (%step) { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.reduce + } + return +} + +// ----- + +// Check promotion through a parallel nested inside an if with a store. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @parallel_in_if +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[COND:.*]]: i1) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[RES:.*]] = scf.if %[[COND]] -> (i32) { +// CHECK: scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) { +// CHECK: call @use(%[[C7]]) +// CHECK: scf.reduce +// CHECK: } +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @parallel_in_if(%lb: index, %ub: index, %step: index, %cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.if %cond { + memref.store %c7, %alloca[] : memref + scf.parallel (%i) = (%lb) to (%ub) step (%step) { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.reduce + } + scf.yield + } + %load2 = memref.load %alloca[] : memref + return %load2 : i32 +} + +// ----- + +// Check load promotion inside a reduce region. + +// CHECK-LABEL: func.func @parallel_reduce_load +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 +// CHECK: %[[RES:.*]] = scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) init (%[[C0]]) -> i32 { +// CHECK: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK: scf.reduce(%[[C1]] : i32) { +// CHECK: ^{{.*}}(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): +// CHECK: %[[SUM:.*]] = arith.addi %[[LHS]], %[[RHS]] : i32 +// CHECK: %[[MUL:.*]] = arith.muli %[[SUM]], %[[C5]] : i32 +// CHECK: scf.reduce.return %[[MUL]] : i32 +// CHECK: } +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @parallel_reduce_load(%lb: index, %ub: index, %step: index) -> i32 { + %c5 = arith.constant 5 : i32 + %c0 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + %res = scf.parallel (%i) = (%lb) to (%ub) step (%step) init (%c0) -> i32 { + %c1 = arith.constant 1 : i32 + scf.reduce(%c1 : i32) { + ^bb0(%lhs: i32, %rhs: i32): + %sum = arith.addi %lhs, %rhs : i32 + %load = memref.load %alloca[] : memref + %mul = arith.muli %sum, %load : i32 + scf.reduce.return %mul : i32 + } + } + return %res : i32 +} + +// ----- + +// Check load promotion in the before region of a while. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @while_load_before +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: scf.while : () -> () { +// CHECK: call @use(%[[C5]]) +// CHECK: scf.condition(%[[COND]]) +// CHECK: } do { +// CHECK: scf.yield +// CHECK: } +func.func @while_load_before(%cond: i1) { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.while : () -> () { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.condition(%cond) + } do { + scf.yield + } + return +} + +// ----- + +// Check load promotion in the after region of a while. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @while_load_after +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: scf.while : () -> () { +// CHECK: scf.condition(%[[COND]]) +// CHECK: } do { +// CHECK: call @use(%[[C5]]) +// CHECK: scf.yield +// CHECK: } +func.func @while_load_after(%cond: i1) { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.while : () -> () { + scf.condition(%cond) + } do { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.yield + } + return +} + +// ----- + +// Check promotion with a store in the before region and a load in the after. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @while_store_before_load_after +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: scf.while (%[[BEFORE:.*]] = %[[C5]]) : (i32) -> i32 { +// CHECK: scf.condition(%[[COND]]) %[[C7]] : i32 +// CHECK: } do { +// CHECK: ^{{.*}}(%[[AFTER:.*]]: i32): +// CHECK: call @use(%[[AFTER]]) +// CHECK: scf.yield %[[AFTER]] : i32 +// CHECK: } +func.func @while_store_before_load_after(%cond: i1) { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.while : () -> () { + memref.store %c7, %alloca[] : memref + scf.condition(%cond) + } do { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.yield + } + return +} + +// ----- + +// Check promotion with a store in the before region and a load after the loop. + +// CHECK-LABEL: func.func @while_store_before_load_after_loop +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[RES:.*]] = scf.while (%[[BEFORE:.*]] = %[[C5]]) : (i32) -> i32 { +// CHECK: scf.condition(%[[COND]]) %[[C7]] : i32 +// CHECK: } do { +// CHECK: ^{{.*}}(%[[AFTER:.*]]: i32): +// CHECK: scf.yield %[[AFTER]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @while_store_before_load_after_loop(%cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.while : () -> () { + memref.store %c7, %alloca[] : memref + scf.condition(%cond) + } do { + scf.yield + } + %res = memref.load %alloca[] : memref + return %res : i32 +} + +// ----- + +// Check store promotion through a while implementing a for loop from 0 to 10. + +// CHECK-LABEL: func.func @while_store +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : i32 +// CHECK: %[[RES:.*]] = scf.while (%[[BEFORE:.*]] = %[[C0]]) : (i32) -> i32 { +// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[BEFORE]], %[[C10]] : i32 +// CHECK: scf.condition(%[[COND]]) %[[BEFORE]] : i32 +// CHECK: } do { +// CHECK: ^{{.*}}(%[[AFTER:.*]]: i32): +// CHECK: %[[NEW:.*]] = arith.addi %[[AFTER]], %[[C1]] : i32 +// CHECK: scf.yield %[[NEW]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @while_store() -> i32 { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c10 = arith.constant 10 : i32 + %alloca = memref.alloca() : memref + memref.store %c0, %alloca[] : memref + scf.while : () -> () { + %val = memref.load %alloca[] : memref + %cond = arith.cmpi slt, %val, %c10 : i32 + scf.condition(%cond) + } do { + %val = memref.load %alloca[] : memref + %new = arith.addi %val, %c1 : i32 + memref.store %new, %alloca[] : memref + scf.yield + } + %res = memref.load %alloca[] : memref + return %res : i32 +} + +// ----- + +// Check that a load-then-store in the before region of a while is promoted. + +// CHECK-LABEL: func.func @while_load_into_store_before +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[RES:.*]] = scf.while (%[[BEFORE:.*]] = %[[C5]]) : (i32) -> i32 { +// CHECK: scf.condition(%[[COND]]) %[[BEFORE]] : i32 +// CHECK: } do { +// CHECK: ^{{.*}}(%[[AFTER:.*]]: i32): +// CHECK: scf.yield %[[AFTER]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @while_load_into_store_before(%cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.while : () -> () { + %loaded = memref.load %alloca[] : memref + memref.store %loaded, %alloca[] : memref + scf.condition(%cond) + } do { + scf.yield + } + %res = memref.load %alloca[] : memref + return %res : i32 +} + +// ----- + +// Check that a load-then-store in the after region of a while is promoted. + +// CHECK-LABEL: func.func @while_load_into_store +// CHECK-SAME: (%[[COND:.*]]: i1) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[RES:.*]] = scf.while (%[[BEFORE:.*]] = %[[C5]]) : (i32) -> i32 { +// CHECK: scf.condition(%[[COND]]) %[[BEFORE]] : i32 +// CHECK: } do { +// CHECK: ^{{.*}}(%[[AFTER:.*]]: i32): +// CHECK: scf.yield %[[AFTER]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @while_load_into_store(%cond: i1) -> i32 { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.while : () -> () { + scf.condition(%cond) + } do { + %loaded = memref.load %alloca[] : memref + memref.store %loaded, %alloca[] : memref + scf.yield + } + %res = memref.load %alloca[] : memref + return %res : i32 +} + +// ----- + +// Check load promotion through an index_switch default branch. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @index_switch_load_default +// CHECK-SAME: (%[[IDX:.*]]: index) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: scf.index_switch %[[IDX]] +// CHECK: default { +// CHECK: call @use(%[[C5]]) +// CHECK: } +func.func @index_switch_load_default(%idx: index) { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.index_switch %idx + default { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.yield + } + return +} + +// ----- + +// Check store promotion through an index_switch default branch. + +// CHECK-LABEL: func.func @index_switch_store_default +// CHECK-SAME: (%[[IDX:.*]]: index) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[RES:.*]] = scf.index_switch %[[IDX]] -> i32 +// CHECK: default { +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @index_switch_store_default(%idx: index) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.index_switch %idx + default { + memref.store %c7, %alloca[] : memref + scf.yield + } + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check promotion with a store in a case and a load in the default branch. + +func.func private @use(i32) + +// CHECK-LABEL: func.func @index_switch_store_case_load_default +// CHECK-SAME: (%[[IDX:.*]]: index) +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32 +// CHECK: %[[RES:.*]] = scf.index_switch %[[IDX]] -> i32 +// CHECK: case 0 { +// CHECK: scf.yield %[[C7]] : i32 +// CHECK: } +// CHECK: default { +// CHECK: call @use(%[[C5]]) +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @index_switch_store_case_load_default(%idx: index) -> i32 { + %c5 = arith.constant 5 : i32 + %c7 = arith.constant 7 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.index_switch %idx + case 0 { + memref.store %c7, %alloca[] : memref + scf.yield + } + default { + %load = memref.load %alloca[] : memref + func.call @use(%load) : (i32) -> () + scf.yield + } + %load2 = memref.load %alloca[] : memref + return %load2 : i32 +} + +// ----- + +// Check that load-then-store of the same slot in an index_switch is promoted. + +// CHECK-LABEL: func.func @index_switch_load_into_store +// CHECK-SAME: (%[[IDX:.*]]: index) +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[RES:.*]] = scf.index_switch %[[IDX]] -> i32 +// CHECK: case 0 { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: default { +// CHECK: scf.yield %[[C5]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +func.func @index_switch_load_into_store(%idx: index) -> i32 { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + scf.index_switch %idx + case 0 { + %loaded = memref.load %alloca[] : memref + memref.store %loaded, %alloca[] : memref + scf.yield + } + default { + %loaded = memref.load %alloca[] : memref + memref.store %loaded, %alloca[] : memref + scf.yield + } + %load = memref.load %alloca[] : memref + return %load : i32 +} diff --git a/mlir/test/Transforms/mem2reg.mlir b/mlir/test/Transforms/mem2reg.mlir index 4b27f3305e89d..70fbddcb25b2a 100644 --- a/mlir/test/Transforms/mem2reg.mlir +++ b/mlir/test/Transforms/mem2reg.mlir @@ -39,3 +39,77 @@ test.isolated_graph_region { %a = memref.load %slot[] : memref "test.foo"() : () -> () } + +// ----- + +// Verifies that block arguments of merge points are not abusively treated as +// the newly created block arguments. Here, ^merge has a pre-existing block +// argument (%genuine) and mem2reg adds a second one for the promoted slot. The +// slot arg then serves as the reaching definition for the follow-up merge point +// ^final. If the unused merge point propagation logic identified merge points +// by block rather than by specific block argument, it would confuse %genuine +// for the slot argument to be removed and thus not eliminate the slot which +// is unused. In other words, the genuine block argument, which is used, would +// mask that the actual slot argument is unused. + +// CHECK-LABEL: func.func @merge_point_arg_not_confused +// CHECK-SAME: (%[[COND:.*]]: i1, %[[A:.*]]: i32, %[[B:.*]]: i32) -> i32 +// CHECK: cf.cond_br %[[COND]], ^[[BB1:.*]], ^[[BB2:.*]] +// CHECK: ^[[BB1]]: +// CHECK: cf.br ^[[MERGE:.*]](%[[A]] : i32) +// CHECK: ^[[BB2]]: +// CHECK: cf.br ^[[MERGE]](%[[B]] : i32) +// CHECK: ^[[MERGE]](%[[GENUINE:.*]]: i32): +// CHECK: cf.cond_br %[[COND]], ^[[BB3:.*]], ^[[BB4:.*]] +// CHECK: ^[[BB3]]: +// CHECK: cf.br ^[[FINAL:.*]](%[[GENUINE]] : i32) +// CHECK: ^[[BB4]]: +// CHECK: %[[DUMMY:.*]] = arith.constant 0 : i32 +// CHECK: cf.br ^[[FINAL]](%[[DUMMY]] : i32) +// CHECK: ^[[FINAL]](%[[FINAL_SLOT:.*]]: i32): +// CHECK: return %[[FINAL_SLOT]] : i32 +func.func @merge_point_arg_not_confused(%cond: i1, %a: i32, %b: i32) -> i32 { + %alloca = memref.alloca() : memref + memref.store %a, %alloca[] : memref + cf.cond_br %cond, ^bb1, ^bb2 +^bb1: + memref.store %b, %alloca[] : memref + cf.br ^merge(%a : i32) +^bb2: + cf.br ^merge(%b : i32) +^merge(%genuine: i32): + cf.cond_br %cond, ^bb3, ^bb4 +^bb3: + memref.store %genuine, %alloca[] : memref + cf.br ^final +^bb4: + %dummy = arith.constant 0 : i32 + memref.store %dummy, %alloca[] : memref + cf.br ^final +^final: + %load = memref.load %alloca[] : memref + return %load : i32 +} + +// ----- + +// Check that a load inside an unknown region-bearing op prevents promotion. + +// CHECK-LABEL: func.func @unknown_region_op_load +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref +// CHECK: memref.store %[[C5]], %[[ALLOCA]][] +// CHECK: "test.one_region_op"() ({ +// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][] +// CHECK: "test.finish"() : () -> () +// CHECK: }) : () -> () +func.func @unknown_region_op_load() { + %c5 = arith.constant 5 : i32 + %alloca = memref.alloca() : memref + memref.store %c5, %alloca[] : memref + "test.one_region_op"() ({ + %load = memref.load %alloca[] : memref + "test.finish"() : () -> () + }) : () -> () + return +}