diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 1712af1d1eba7..149683d30d4b0 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -4484,7 +4484,7 @@ void fir::IfOp::getSuccessorRegions( llvm::SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back(mlir::RegionSuccessor(getResults())); + regions.push_back(mlir::RegionSuccessor(getOperation(), getResults())); return; } @@ -4494,7 +4494,8 @@ void fir::IfOp::getSuccessorRegions( // Don't consider the else region if it is empty. mlir::Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(mlir::RegionSuccessor()); + regions.push_back( + mlir::RegionSuccessor(getOperation(), getOperation()->getResults())); else regions.push_back(mlir::RegionSuccessor(elseRegion)); } @@ -4513,7 +4514,7 @@ void fir::IfOp::getEntrySuccessorRegions( if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } } diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h index 8bcfe51ad7cd1..3c87c453a4cf0 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -397,7 +397,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis { /// itself. virtual void visitRegionBranchControlFlowTransfer( RegionBranchOpInterface branch, RegionBranchPoint regionFrom, - RegionBranchPoint regionTo, const AbstractDenseLattice &after, + RegionSuccessor regionTo, const AbstractDenseLattice &after, AbstractDenseLattice *before) { meet(before, after); } @@ -526,7 +526,7 @@ class DenseBackwardDataFlowAnalysis /// and "to" regions. virtual void visitRegionBranchControlFlowTransfer( RegionBranchOpInterface branch, RegionBranchPoint regionFrom, - RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) { + RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) { AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer( branch, regionFrom, regionTo, after, before); } @@ -571,7 +571,7 @@ class DenseBackwardDataFlowAnalysis } void visitRegionBranchControlFlowTransfer( RegionBranchOpInterface branch, RegionBranchPoint regionForm, - RegionBranchPoint regionTo, const AbstractDenseLattice &after, + RegionSuccessor regionTo, const AbstractDenseLattice &after, AbstractDenseLattice *before) final { visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo, static_cast(after), diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h index 1a33ecf8b5aa9..985573476ab78 100644 --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -286,7 +286,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis { /// and propagating therefrom. virtual void visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch, - RegionBranchPoint successor, + RegionSuccessor successor, ArrayRef lattices); }; diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index fadd3fc10bfc4..48690151caf01 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -644,6 +644,13 @@ def ForallOp : SCF_Op<"forall", [ /// Returns true if the mapping specified for this forall op is linear. bool usesLinearMapping(); + + /// RegionBranchOpInterface + + OperandRange getEntrySuccessorOperands(RegionSuccessor successor) { + return getInits(); + } + }]; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 62e66b3dabee8..ed69287410509 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -25,7 +25,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" def AlternativesOp : TransformDialectOp<"alternatives", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, @@ -624,7 +624,7 @@ def ForeachOp : TransformDialectOp<"foreach", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + "getEntrySuccessorOperands"]>, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp"> ]> { let summary = "Executes the body for each element of the payload"; @@ -1237,7 +1237,7 @@ def SelectOp : TransformDialectOp<"select", def SequenceOp : TransformDialectOp<"sequence", [DeclareOpInterfaceMethods, MatchOpInterface, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td index d095659fc4838..4079848fd203a 100644 --- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td @@ -63,7 +63,7 @@ def KnobOp : Op, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h index 7ff718ad7f241..a0a99f4953822 100644 --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -29,6 +29,7 @@ class MLIRContext; class Operation; class OperationName; class OpPrintingFlags; +class OpWithFlags; class Type; class Value; @@ -199,6 +200,7 @@ class Diagnostic { /// Stream in an Operation. Diagnostic &operator<<(Operation &op); + Diagnostic &operator<<(OpWithFlags op); Diagnostic &operator<<(Operation *op) { return *this << *op; } /// Append an operation with the given printing flags. Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 5569392cf0b41..b2019574a820d 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -1114,6 +1114,7 @@ class OpWithFlags { : op(op), theFlags(flags) {} OpPrintingFlags &flags() { return theFlags; } const OpPrintingFlags &flags() const { return theFlags; } + Operation *getOperation() const { return op; } private: Operation *op; diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h index 1fcb316750230..53d461df98710 100644 --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -379,6 +379,8 @@ class RegionRange friend RangeBaseT; }; +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region ®ion); + } // namespace mlir #endif // MLIR_IR_REGION_H diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index d63800c12d132..47afd252c6d68 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -15,10 +15,16 @@ #define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/raw_ostream.h" namespace mlir { class BranchOpInterface; class RegionBranchOpInterface; +class RegionBranchTerminatorOpInterface; /// This class models how operands are forwarded to block arguments in control /// flow. It consists of a number, denoting how many of the successors block @@ -186,27 +192,40 @@ class RegionSuccessor { public: /// Initialize a successor that branches to another region of the parent /// operation. + /// TODO: the default value for the regionInputs is somehow broken. + /// A region successor should have its input correctly set. RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {}) - : region(region), inputs(regionInputs) {} + : successor(region), inputs(regionInputs) { + assert(region && "Region must not be null"); + } /// Initialize a successor that branches back to/out of the parent operation. - RegionSuccessor(Operation::result_range results) - : inputs(ValueRange(results)) {} - /// Constructor with no arguments. - RegionSuccessor() : inputs(ValueRange()) {} + /// The target must be one of the recursive parent operations. + RegionSuccessor(Operation *successorOp, Operation::result_range results) + : successor(successorOp), inputs(ValueRange(results)) { + assert(successorOp && "Successor op must not be null"); + } /// Return the given region successor. Returns nullptr if the successor is the /// parent operation. - Region *getSuccessor() const { return region; } + Region *getSuccessor() const { return dyn_cast(successor); } /// Return true if the successor is the parent operation. - bool isParent() const { return region == nullptr; } + bool isParent() const { return isa(successor); } /// Return the inputs to the successor that are remapped by the exit values of /// the current region. ValueRange getSuccessorInputs() const { return inputs; } + bool operator==(RegionSuccessor rhs) const { + return successor == rhs.successor && inputs == rhs.inputs; + } + + friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) { + return !(lhs == rhs); + } + private: - Region *region{nullptr}; + llvm::PointerUnion successor{nullptr}; ValueRange inputs; }; @@ -214,64 +233,67 @@ class RegionSuccessor { /// `RegionBranchOpInterface`. /// One can branch from one of two kinds of places: /// * The parent operation (aka the `RegionBranchOpInterface` implementation) -/// * A region within the parent operation. +/// * A RegionBranchTerminatorOpInterface inside a region within the parent +// operation. class RegionBranchPoint { public: /// Returns an instance of `RegionBranchPoint` representing the parent /// operation. static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); } - /// Creates a `RegionBranchPoint` that branches from the given region. - /// The pointer must not be null. - RegionBranchPoint(Region *region) : maybeRegion(region) { - assert(region && "Region must not be null"); - } - - RegionBranchPoint(Region ®ion) : RegionBranchPoint(®ion) {} + /// Creates a `RegionBranchPoint` that branches from the given terminator. + inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor); /// Explicitly stops users from constructing with `nullptr`. RegionBranchPoint(std::nullptr_t) = delete; - /// Constructs a `RegionBranchPoint` from the the target of a - /// `RegionSuccessor` instance. - RegionBranchPoint(RegionSuccessor successor) { - if (successor.isParent()) - maybeRegion = nullptr; - else - maybeRegion = successor.getSuccessor(); - } - - /// Assigns a region being branched from. - RegionBranchPoint &operator=(Region ®ion) { - maybeRegion = ®ion; - return *this; - } - /// Returns true if branching from the parent op. - bool isParent() const { return maybeRegion == nullptr; } + bool isParent() const { return predecessor == nullptr; } - /// Returns the region if branching from a region. + /// Returns the terminator if branching from a region. /// A null pointer otherwise. - Region *getRegionOrNull() const { return maybeRegion; } + Operation *getTerminatorPredecessorOrNull() const { return predecessor; } /// Returns true if the two branch points are equal. friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) { - return lhs.maybeRegion == rhs.maybeRegion; + return lhs.predecessor == rhs.predecessor; } private: // Private constructor to encourage the use of `RegionBranchPoint::parent`. - constexpr RegionBranchPoint() : maybeRegion(nullptr) {} + constexpr RegionBranchPoint() = default; /// Internal encoding. Uses nullptr for representing branching from the parent - /// op and the region being branched from otherwise. - Region *maybeRegion; + /// op and the region terminator being branched from otherwise. + Operation *predecessor = nullptr; }; inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) { return !(lhs == rhs); } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + RegionBranchPoint point) { + if (point.isParent()) + return os << ""; + return os << "getParentRegion() + ->getRegionNumber() + << ", terminator " + << OpWithFlags(point.getTerminatorPredecessorOrNull(), + OpPrintingFlags().skipRegions()) + << ">"; +} + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + RegionSuccessor successor) { + if (successor.isParent()) + return os << ""; + return os << "getRegionNumber() + << " with " << successor.getSuccessorInputs().size() << " inputs>"; +} + /// This class represents upper and lower bounds on the number of times a region /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least /// zero, but the upper bound may not be known. @@ -348,4 +370,10 @@ struct ReturnLike : public TraitBase { /// Include the generated interface declarations. #include "mlir/Interfaces/ControlFlowInterfaces.h.inc" +namespace mlir { +inline RegionBranchPoint::RegionBranchPoint( + RegionBranchTerminatorOpInterface predecessor) + : predecessor(predecessor.getOperation()) {} +} // namespace mlir + #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index b8d08cc553caa..94242e3ba39ce 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -117,7 +117,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> { def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { let description = [{ - This interface provides information for region operations that exhibit + This interface provides information for region-holding operations that exhibit branching behavior between held regions. I.e., this interface allows for expressing control flow information for region holding operations. @@ -126,12 +126,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { be side-effect free. A "region branch point" indicates a point from which a branch originates. It - can indicate either a region of this op or `RegionBranchPoint::parent()`. In - the latter case, the branch originates from outside of the op, i.e., when - first executing this op. + can indicate either a terminator in any of the immediately nested region of + this op or `RegionBranchPoint::parent()`. In the latter case, the branch + originates from outside of the op, i.e., when first executing this op. A "region successor" indicates the target of a branch. It can indicate - either a region of this op or this op. In the former case, the region + either a region of this op or this op itself. In the former case, the region successor is a region pointer and a range of block arguments to which the "successor operands" are forwarded to. In the latter case, the control flow leaves this op and the region successor is a range of results of this op to @@ -151,10 +151,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { } ``` - `scf.for` has one region. The region has two region successors: the region - itself and the `scf.for` op. %b is an entry successor operand. %c is a - successor operand. %a is a successor block argument. %r is a successor - result. + `scf.for` has one region. The `scf.yield` has two region successors: the + region body itself and the `scf.for` op. `%b` is an entry successor + operand. `%c` is a successor operand. `%a` is a successor block argument. + `%r` is a successor result. }]; let cppNamespace = "::mlir"; @@ -162,16 +162,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { InterfaceMethod<[{ Returns the operands of this operation that are forwarded to the region successor's block arguments or this operation's results when branching - to `point`. `point` is guaranteed to be among the successors that are + to `successor`. `successor` is guaranteed to be among the successors that are returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`. Example: In the above example, this method returns the operand %b of the - `scf.for` op, regardless of the value of `point`. I.e., this op always + `scf.for` op, regardless of the value of `successor`. I.e., this op always forwards the same operands, regardless of whether the loop has 0 or more iterations. }], "::mlir::OperandRange", "getEntrySuccessorOperands", - (ins "::mlir::RegionBranchPoint":$point), [{}], + (ins "::mlir::RegionSuccessor":$successor), [{}], /*defaultImplementation=*/[{ auto operandEnd = this->getOperation()->operand_end(); return ::mlir::OperandRange(operandEnd, operandEnd); @@ -224,6 +224,80 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { (ins "::mlir::RegionBranchPoint":$point, "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions) >, + InterfaceMethod<[{ + Returns the potential region successors when branching from any + terminator in `region`. + These are the regions that may be selected during the flow of control. + }], + "void", "getSuccessorRegions", + (ins "::mlir::Region&":$region, + "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), + [{}], + /*defaultImplementation=*/[{ + for (::mlir::Block &block : region) { + if (block.empty()) + continue; + if (auto terminator = + dyn_cast(block.back())) + $_op.getSuccessorRegions(RegionBranchPoint(terminator), + regions); + } + }]>, + InterfaceMethod<[{ + Returns the potential branching point (predecessors) for a given successor. + }], + "void", "getPredecessors", + (ins "::mlir::RegionSuccessor":$successor, + "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors), + [{}], + /*defaultImplementation=*/[{ + ::llvm::SmallVector<::mlir::RegionSuccessor> successors; + $_op.getSuccessorRegions(RegionBranchPoint::parent(), + successors); + if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) { + return succ.getSuccessor() == successor.getSuccessor() || + (succ.isParent() && successor.isParent()); + })) + predecessors.push_back(RegionBranchPoint::parent()); + for (Region ®ion : $_op->getRegions()) { + for (::mlir::Block &block : region) { + if (block.empty()) + continue; + if (auto terminator = + dyn_cast(block.back())) { + ::llvm::SmallVector<::mlir::RegionSuccessor> successors; + $_op.getSuccessorRegions(RegionBranchPoint(terminator), + successors); + if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) { + return succ.getSuccessor() == successor.getSuccessor() || + (succ.isParent() && successor.isParent()); + })) + predecessors.push_back(terminator); + } + } + } + }]>, + InterfaceMethod<[{ + Returns the potential values across all (predecessors) for a given successor + input, modeled by its index (its position in the list of values). + }], + "void", "getPredecessorValues", + (ins "::mlir::RegionSuccessor":$successor, + "int":$index, + "::llvm::SmallVectorImpl<::mlir::Value> &":$predecessorValues), + [{}], + /*defaultImplementation=*/[{ + ::llvm::SmallVector<::mlir::RegionBranchPoint> predecessors; + $_op.getPredecessors(successor, predecessors); + for (auto predecessor : predecessors) { + if (predecessor.isParent()) { + predecessorValues.push_back($_op.getEntrySuccessorOperands(successor)[index]); + continue; + } + auto terminator = cast(predecessor.getTerminatorPredecessorOrNull()); + predecessorValues.push_back(terminator.getSuccessorOperands(successor)[index]); + } + }]>, InterfaceMethod<[{ Populates `invocationBounds` with the minimum and maximum number of times this operation will invoke the attached regions (assuming the @@ -298,7 +372,7 @@ def RegionBranchTerminatorOpInterface : passing them to the region successor indicated by `point`. }], "::mlir::MutableOperandRange", "getMutableSuccessorOperands", - (ins "::mlir::RegionBranchPoint":$point) + (ins "::mlir::RegionSuccessor":$point) >, InterfaceMethod<[{ Returns the potential region successors that are branched to after this @@ -317,7 +391,7 @@ def RegionBranchTerminatorOpInterface : /*defaultImplementation=*/[{ ::mlir::Operation *op = $_op; ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp()) - .getSuccessorRegions(op->getParentRegion(), regions); + .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions); }] >, ]; @@ -337,8 +411,8 @@ def RegionBranchTerminatorOpInterface : // them to the region successor given by `index`. If `index` is None, this // function returns the operands that are passed as a result to the parent // operation. - ::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) { - return getMutableSuccessorOperands(point); + ::mlir::OperandRange getSuccessorOperands(::mlir::RegionSuccessor successor) { + return getMutableSuccessorOperands(successor); } }]; } @@ -504,7 +578,7 @@ def ReturnLike : TraitList<[ /*extraOpDeclaration=*/"", /*extraOpDefinition=*/[{ ::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands( - ::mlir::RegionBranchPoint point) { + ::mlir::RegionSuccessor successor) { return ::mlir::MutableOperandRange(*this); } }] diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp index a84d10d5d609d..24cb123e51877 100644 --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -16,19 +16,21 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" #include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/DebugLog.h" #include #include #include using namespace mlir; +#define DEBUG_TYPE "local-alias-analysis" + //===----------------------------------------------------------------------===// // Underlying Address Computation //===----------------------------------------------------------------------===// @@ -42,81 +44,47 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, DenseSet &visited, SmallVectorImpl &output); -/// Given a successor (`region`) of a RegionBranchOpInterface, collect all of -/// the underlying values being addressed by one of the successor inputs. If the -/// provided `region` is null, as per `RegionBranchOpInterface` this represents -/// the parent operation. -static void collectUnderlyingAddressValues(RegionBranchOpInterface branch, - Region *region, Value inputValue, - unsigned inputIndex, - unsigned maxDepth, - DenseSet &visited, - SmallVectorImpl &output) { - // Given the index of a region of the branch (`predIndex`), or std::nullopt to - // represent the parent operation, try to return the index into the outputs of - // this region predecessor that correspond to the input values of `region`. If - // an index could not be found, std::nullopt is returned instead. - auto getOperandIndexIfPred = - [&](RegionBranchPoint pred) -> std::optional { - SmallVector successors; - branch.getSuccessorRegions(pred, successors); - for (RegionSuccessor &successor : successors) { - if (successor.getSuccessor() != region) - continue; - // Check that the successor inputs map to the given input value. - ValueRange inputs = successor.getSuccessorInputs(); - if (inputs.empty()) { - output.push_back(inputValue); - break; - } - unsigned firstInputIndex, lastInputIndex; - if (region) { - firstInputIndex = cast(inputs[0]).getArgNumber(); - lastInputIndex = cast(inputs.back()).getArgNumber(); - } else { - firstInputIndex = cast(inputs[0]).getResultNumber(); - lastInputIndex = cast(inputs.back()).getResultNumber(); - } - if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) { - output.push_back(inputValue); - break; - } - return inputIndex - firstInputIndex; - } - return std::nullopt; - }; - - // Check branches from the parent operation. - auto branchPoint = RegionBranchPoint::parent(); - if (region) - branchPoint = region; - - if (std::optional operandIndex = - getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) { - collectUnderlyingAddressValues( - branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth, - visited, output); +/// Given a RegionBranchOpInterface operation (`branch`), a Value`inputValue` +/// which is an input for the provided successor (`initialSuccessor`), try to +/// find the possible sources for the value along the control flow edges. +static void collectUnderlyingAddressValues2( + RegionBranchOpInterface branch, RegionSuccessor initialSuccessor, + Value inputValue, unsigned inputIndex, unsigned maxDepth, + DenseSet &visited, SmallVectorImpl &output) { + LDBG() << "collectUnderlyingAddressValues2: " + << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); + LDBG() << " with initialSuccessor " << initialSuccessor; + LDBG() << " inputValue: " << inputValue; + LDBG() << " inputIndex: " << inputIndex; + LDBG() << " maxDepth: " << maxDepth; + ValueRange inputs = initialSuccessor.getSuccessorInputs(); + if (inputs.empty()) { + LDBG() << " input is empty, enqueue value"; + output.push_back(inputValue); + return; } - // Check branches from each child region. - Operation *op = branch.getOperation(); - for (Region ®ion : op->getRegions()) { - if (std::optional operandIndex = getOperandIndexIfPred(region)) { - for (Block &block : region) { - // Try to determine possible region-branch successor operands for the - // current region. - if (auto term = dyn_cast( - block.getTerminator())) { - collectUnderlyingAddressValues( - term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth, - visited, output); - } else if (block.getNumSuccessors()) { - // Otherwise, if this terminator may exit the region we can't make - // any assumptions about which values get passed. - output.push_back(inputValue); - return; - } - } - } + unsigned firstInputIndex, lastInputIndex; + if (isa(inputs[0])) { + firstInputIndex = cast(inputs[0]).getArgNumber(); + lastInputIndex = cast(inputs.back()).getArgNumber(); + } else { + firstInputIndex = cast(inputs[0]).getResultNumber(); + lastInputIndex = cast(inputs.back()).getResultNumber(); + } + if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) { + LDBG() << " !! Input index " << inputIndex << " out of range " + << firstInputIndex << " to " << lastInputIndex + << ", adding input value to output"; + output.push_back(inputValue); + return; + } + SmallVector predecessorValues; + branch.getPredecessorValues(initialSuccessor, inputIndex - firstInputIndex, + predecessorValues); + LDBG() << " Found " << predecessorValues.size() << " predecessor values"; + for (Value predecessorValue : predecessorValues) { + LDBG() << " Processing predecessor value: " << predecessorValue; + collectUnderlyingAddressValues(predecessorValue, maxDepth, visited, output); } } @@ -124,22 +92,28 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch, static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth, DenseSet &visited, SmallVectorImpl &output) { + LDBG() << "collectUnderlyingAddressValues (OpResult): " << result; + LDBG() << " maxDepth: " << maxDepth; + Operation *op = result.getOwner(); // If this is a view, unwrap to the source. if (ViewLikeOpInterface view = dyn_cast(op)) { if (result == view.getViewDest()) { + LDBG() << " Unwrapping view to source: " << view.getViewSource(); return collectUnderlyingAddressValues(view.getViewSource(), maxDepth, visited, output); } } // Check to see if we can reason about the control flow of this op. if (auto branch = dyn_cast(op)) { - return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result, - result.getResultNumber(), maxDepth, - visited, output); + LDBG() << " Processing region branch operation"; + return collectUnderlyingAddressValues2( + branch, RegionSuccessor(op, op->getResults()), result, + result.getResultNumber(), maxDepth, visited, output); } + LDBG() << " Adding result to output: " << result; output.push_back(result); } @@ -148,14 +122,23 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth, static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, DenseSet &visited, SmallVectorImpl &output) { + LDBG() << "collectUnderlyingAddressValues (BlockArgument): " << arg; + LDBG() << " maxDepth: " << maxDepth; + LDBG() << " argNumber: " << arg.getArgNumber(); + LDBG() << " isEntryBlock: " << arg.getOwner()->isEntryBlock(); + Block *block = arg.getOwner(); unsigned argNumber = arg.getArgNumber(); // Handle the case of a non-entry block. if (!block->isEntryBlock()) { + LDBG() << " Processing non-entry block with " + << std::distance(block->pred_begin(), block->pred_end()) + << " predecessors"; for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { auto branch = dyn_cast((*it)->getTerminator()); if (!branch) { + LDBG() << " Cannot analyze control flow, adding argument to output"; // We can't analyze the control flow, so bail out early. output.push_back(arg); return; @@ -165,10 +148,12 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, unsigned index = it.getSuccessorIndex(); Value operand = branch.getSuccessorOperands(index)[argNumber]; if (!operand) { + LDBG() << " No operand found for argument, adding to output"; // We can't analyze the control flow, so bail out early. output.push_back(arg); return; } + LDBG() << " Processing operand from predecessor: " << operand; collectUnderlyingAddressValues(operand, maxDepth, visited, output); } return; @@ -178,10 +163,35 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, Region *region = block->getParent(); Operation *op = region->getParentOp(); if (auto branch = dyn_cast(op)) { - return collectUnderlyingAddressValues(branch, region, arg, argNumber, - maxDepth, visited, output); + LDBG() << " Processing region branch operation for entry block"; + // We have to find the successor matching the region, so that the input + // arguments are correctly set. + // TODO: this isn't comprehensive: the successor may not be reachable from + // the entry block. + SmallVector successors; + branch.getSuccessorRegions(RegionBranchPoint::parent(), successors); + RegionSuccessor regionSuccessor(region); + bool found = false; + for (RegionSuccessor &successor : successors) { + if (successor.getSuccessor() == region) { + LDBG() << " Found matching region successor: " << successor; + found = true; + regionSuccessor = successor; + break; + } + } + if (!found) { + LDBG() + << " No matching region successor found, adding argument to output"; + output.push_back(arg); + return; + } + return collectUnderlyingAddressValues2( + branch, regionSuccessor, arg, argNumber, maxDepth, visited, output); } + LDBG() + << " Cannot reason about underlying address, adding argument to output"; // We can't reason about the underlying address of this argument. output.push_back(arg); } @@ -190,17 +200,26 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, DenseSet &visited, SmallVectorImpl &output) { + LDBG() << "collectUnderlyingAddressValues: " << value; + LDBG() << " maxDepth: " << maxDepth; + // Check that we don't infinitely recurse. - if (!visited.insert(value).second) + if (!visited.insert(value).second) { + LDBG() << " Value already visited, skipping"; return; + } if (maxDepth == 0) { + LDBG() << " Max depth reached, adding value to output"; output.push_back(value); return; } --maxDepth; - if (BlockArgument arg = dyn_cast(value)) + if (BlockArgument arg = dyn_cast(value)) { + LDBG() << " Processing as BlockArgument"; return collectUnderlyingAddressValues(arg, maxDepth, visited, output); + } + LDBG() << " Processing as OpResult"; collectUnderlyingAddressValues(cast(value), maxDepth, visited, output); } @@ -208,9 +227,11 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, /// Given a value, collect all of the underlying values being addressed. static void collectUnderlyingAddressValues(Value value, SmallVectorImpl &output) { + LDBG() << "collectUnderlyingAddressValues: " << value; DenseSet visited; collectUnderlyingAddressValues(value, maxUnderlyingValueSearchDepth, visited, output); + LDBG() << " Collected " << output.size() << " underlying values"; } //===----------------------------------------------------------------------===// @@ -227,19 +248,33 @@ static LogicalResult getAllocEffectFor(Value value, std::optional &effect, Operation *&allocScopeOp) { + LDBG() << "getAllocEffectFor: " << value; + // Try to get a memory effect interface for the parent operation. Operation *op; - if (BlockArgument arg = dyn_cast(value)) + if (BlockArgument arg = dyn_cast(value)) { op = arg.getOwner()->getParentOp(); - else + LDBG() << " BlockArgument, parent op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + } else { op = cast(value).getOwner(); + LDBG() << " OpResult, owner op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + } + MemoryEffectOpInterface interface = dyn_cast(op); - if (!interface) + if (!interface) { + LDBG() << " No memory effect interface found"; return failure(); + } // Try to find an allocation effect on the resource. - if (!(effect = interface.getEffectOnValue(value))) + if (!(effect = interface.getEffectOnValue(value))) { + LDBG() << " No allocation effect found on value"; return failure(); + } + + LDBG() << " Found allocation effect"; // If we found an allocation effect, try to find a scope for the allocation. // If the resource of this allocation is automatically scoped, find the parent @@ -247,6 +282,12 @@ getAllocEffectFor(Value value, if (llvm::isa( effect->getResource())) { allocScopeOp = op->getParentWithTrait(); + if (allocScopeOp) { + LDBG() << " Automatic allocation scope found: " + << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions()); + } else { + LDBG() << " Automatic allocation scope found: null"; + } return success(); } @@ -255,6 +296,12 @@ getAllocEffectFor(Value value, // For now assume allocation scope to the function scope (we don't care if // pointer escape outside function). allocScopeOp = op->getParentOfType(); + if (allocScopeOp) { + LDBG() << " Function scope found: " + << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions()); + } else { + LDBG() << " Function scope found: null"; + } return success(); } @@ -293,33 +340,44 @@ static std::optional checkDistinctObjects(Value lhs, Value rhs) { /// Given the two values, return their aliasing behavior. AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { - if (lhs == rhs) + LDBG() << "aliasImpl: " << lhs << " vs " << rhs; + + if (lhs == rhs) { + LDBG() << " Same value, must alias"; return AliasResult::MustAlias; + } + Operation *lhsAllocScope = nullptr, *rhsAllocScope = nullptr; std::optional lhsAlloc, rhsAlloc; // Handle the case where lhs is a constant. Attribute lhsAttr, rhsAttr; if (matchPattern(lhs, m_Constant(&lhsAttr))) { + LDBG() << " lhs is constant"; // TODO: This is overly conservative. Two matching constants don't // necessarily map to the same address. For example, if the two values // correspond to different symbols that both represent a definition. - if (matchPattern(rhs, m_Constant(&rhsAttr))) + if (matchPattern(rhs, m_Constant(&rhsAttr))) { + LDBG() << " rhs is also constant, may alias"; return AliasResult::MayAlias; + } // Try to find an alloc effect on rhs. If an effect was found we can't // alias, otherwise we might. - return succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope)) - ? AliasResult::NoAlias - : AliasResult::MayAlias; + bool rhsHasAlloc = + succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope)); + LDBG() << " rhs has alloc effect: " << rhsHasAlloc; + return rhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias; } // Handle the case where rhs is a constant. if (matchPattern(rhs, m_Constant(&rhsAttr))) { + LDBG() << " rhs is constant"; // Try to find an alloc effect on lhs. If an effect was found we can't // alias, otherwise we might. - return succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)) - ? AliasResult::NoAlias - : AliasResult::MayAlias; + bool lhsHasAlloc = + succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)); + LDBG() << " lhs has alloc effect: " << lhsHasAlloc; + return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias; } if (std::optional result = checkDistinctObjects(lhs, rhs)) @@ -329,9 +387,14 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { // an allocation effect. bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)); bool rhsHasAlloc = succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope)); + LDBG() << " lhs has alloc effect: " << lhsHasAlloc; + LDBG() << " rhs has alloc effect: " << rhsHasAlloc; + if (lhsHasAlloc == rhsHasAlloc) { // If both values have an allocation effect we know they don't alias, and if // neither have an effect we can't make an assumptions. + LDBG() << " Both have same alloc status: " + << (lhsHasAlloc ? "NoAlias" : "MayAlias"); return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias; } @@ -339,6 +402,7 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { // and one without. Move the one with the effect to the lhs to make the next // checks simpler. if (rhsHasAlloc) { + LDBG() << " Swapping lhs and rhs to put alloc effect on lhs"; std::swap(lhs, rhs); lhsAlloc = rhsAlloc; lhsAllocScope = rhsAllocScope; @@ -347,49 +411,74 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { // If the effect has a scoped allocation region, check to see if the // non-effect value is defined above that scope. if (lhsAllocScope) { + LDBG() << " Checking allocation scope: " + << OpWithFlags(lhsAllocScope, OpPrintingFlags().skipRegions()); // If the parent operation of rhs is an ancestor of the allocation scope, or // if rhs is an entry block argument of the allocation scope we know the two // values can't alias. Operation *rhsParentOp = rhs.getParentRegion()->getParentOp(); - if (rhsParentOp->isProperAncestor(lhsAllocScope)) + if (rhsParentOp->isProperAncestor(lhsAllocScope)) { + LDBG() << " rhs parent is ancestor of alloc scope, no alias"; return AliasResult::NoAlias; + } if (rhsParentOp == lhsAllocScope) { BlockArgument rhsArg = dyn_cast(rhs); - if (rhsArg && rhs.getParentBlock()->isEntryBlock()) + if (rhsArg && rhs.getParentBlock()->isEntryBlock()) { + LDBG() << " rhs is entry block arg of alloc scope, no alias"; return AliasResult::NoAlias; + } } } // If we couldn't reason about the relationship between the two values, // conservatively assume they might alias. + LDBG() << " Cannot reason about relationship, may alias"; return AliasResult::MayAlias; } /// Given the two values, return their aliasing behavior. AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) { - if (lhs == rhs) + LDBG() << "alias: " << lhs << " vs " << rhs; + + if (lhs == rhs) { + LDBG() << " Same value, must alias"; return AliasResult::MustAlias; + } // Get the underlying values being addressed. SmallVector lhsValues, rhsValues; collectUnderlyingAddressValues(lhs, lhsValues); collectUnderlyingAddressValues(rhs, rhsValues); + LDBG() << " lhs underlying values: " << lhsValues.size(); + LDBG() << " rhs underlying values: " << rhsValues.size(); + // If we failed to collect for either of the values somehow, conservatively // assume they may alias. - if (lhsValues.empty() || rhsValues.empty()) + if (lhsValues.empty() || rhsValues.empty()) { + LDBG() << " Failed to collect underlying values, may alias"; return AliasResult::MayAlias; + } // Check the alias results against each of the underlying values. std::optional result; for (Value lhsVal : lhsValues) { for (Value rhsVal : rhsValues) { + LDBG() << " Checking underlying values: " << lhsVal << " vs " << rhsVal; AliasResult nextResult = aliasImpl(lhsVal, rhsVal); + LDBG() << " Result: " + << (nextResult == AliasResult::MustAlias ? "MustAlias" + : nextResult == AliasResult::NoAlias ? "NoAlias" + : "MayAlias"); result = result ? result->merge(nextResult) : nextResult; } } // We should always have a valid result here. + LDBG() << " Final result: " + << (result->isMust() ? "MustAlias" + : result->isNo() ? "NoAlias" + : "MayAlias"); return *result; } @@ -398,8 +487,12 @@ AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) { //===----------------------------------------------------------------------===// ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) { + LDBG() << "getModRef: " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " on location " << location; + // Check to see if this operation relies on nested side effects. if (op->hasTrait()) { + LDBG() << " Operation has recursive memory effects, returning ModAndRef"; // TODO: To check recursive operations we need to check all of the nested // operations, which can result in a quadratic number of queries. We should // introduce some caching of some kind to help alleviate this, especially as @@ -410,38 +503,64 @@ ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) { // Otherwise, check to see if this operation has a memory effect interface. MemoryEffectOpInterface interface = dyn_cast(op); - if (!interface) + if (!interface) { + LDBG() << " No memory effect interface, returning ModAndRef"; return ModRefResult::getModAndRef(); + } // Build a ModRefResult by merging the behavior of the effects of this // operation. SmallVector effects; interface.getEffects(effects); + LDBG() << " Found " << effects.size() << " memory effects"; ModRefResult result = ModRefResult::getNoModRef(); for (const MemoryEffects::EffectInstance &effect : effects) { - if (isa(effect.getEffect())) + if (isa(effect.getEffect())) { + LDBG() << " Skipping alloc/free effect"; continue; + } // Check for an alias between the effect and our memory location. // TODO: Add support for checking an alias with a symbol reference. AliasResult aliasResult = AliasResult::MayAlias; - if (Value effectValue = effect.getValue()) + if (Value effectValue = effect.getValue()) { + LDBG() << " Checking alias between effect value " << effectValue + << " and location " << location; aliasResult = alias(effectValue, location); + LDBG() << " Alias result: " + << (aliasResult.isMust() ? "MustAlias" + : aliasResult.isNo() ? "NoAlias" + : "MayAlias"); + } else { + LDBG() << " No effect value, assuming MayAlias"; + } // If we don't alias, ignore this effect. - if (aliasResult.isNo()) + if (aliasResult.isNo()) { + LDBG() << " No alias, ignoring effect"; continue; + } // Merge in the corresponding mod or ref for this effect. if (isa(effect.getEffect())) { + LDBG() << " Adding Ref to result"; result = result.merge(ModRefResult::getRef()); } else { assert(isa(effect.getEffect())); + LDBG() << " Adding Mod to result"; result = result.merge(ModRefResult::getMod()); } - if (result.isModAndRef()) + if (result.isModAndRef()) { + LDBG() << " Result is now ModAndRef, breaking"; break; + } } + + LDBG() << " Final ModRef result: " + << (result.isModAndRef() ? "ModAndRef" + : result.isMod() ? "Mod" + : result.isRef() ? "Ref" + : "NoModRef"); return result; } diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 377f7ebe06750..0fc5b4482bf3e 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -501,11 +501,10 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, return; SmallVector successors; - if (auto terminator = dyn_cast(op)) - terminator.getSuccessorRegions(*operands, successors); - else - branch.getSuccessorRegions(op->getParentRegion(), successors); - + auto terminator = dyn_cast(op); + if (!terminator) + return; + terminator.getSuccessorRegions(*operands, successors); visitRegionBranchEdges(branch, op, successors); } diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index daa3db55b2852..0682e5f26785a 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -588,7 +588,9 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { // flow, propagate the lattice back along the control flow edge. if (auto branch = dyn_cast(block->getParentOp())) { LDBG() << " Exit block of region branch operation"; - visitRegionBranchOperation(point, branch, block->getParent(), before); + auto terminator = + cast(block->getTerminator()); + visitRegionBranchOperation(point, branch, terminator, before); return; } diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 0d2e2ed85549d..8e63ae86753b4 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -130,7 +130,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { // The results of a region branch operation are determined by control-flow. if (auto branch = dyn_cast(op)) { visitRegionSuccessors(getProgramPointAfter(branch), branch, - /*successor=*/RegionBranchPoint::parent(), + /*successor=*/{branch, branch->getResults()}, resultLattices); return success(); } @@ -279,7 +279,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation( void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( ProgramPoint *point, RegionBranchOpInterface branch, - RegionBranchPoint successor, ArrayRef lattices) { + RegionSuccessor successor, ArrayRef lattices) { const auto *predecessors = getOrCreateFor(point, point); assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors"); @@ -314,7 +314,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( visitNonControlFlowArgumentsImpl( branch, RegionSuccessor( - branch->getResults().slice(firstIndex, inputs.size())), + branch, branch->getResults().slice(firstIndex, inputs.size())), lattices, firstIndex); } else { if (!inputs.empty()) diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp index 817d71a3452ca..863f260cd4b6a 100644 --- a/mlir/lib/Analysis/SliceWalk.cpp +++ b/mlir/lib/Analysis/SliceWalk.cpp @@ -114,7 +114,7 @@ mlir::getControlFlowPredecessors(Value value) { if (!regionOp) return std::nullopt; // Add the control flow predecessor operands to the work list. - RegionSuccessor region(regionOp->getResults()); + RegionSuccessor region(regionOp, regionOp->getResults()); SmallVector predecessorOperands = getRegionPredecessorOperands( regionOp, region, opResult.getResultNumber()); return predecessorOperands; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 7e5ce26b5f733..7688ce8ecb893 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2566,8 +2566,9 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } -OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert((point.isParent() || point == getRegion()) && "invalid region point"); +OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert((successor.isParent() || successor.getSuccessor() == &getRegion()) && + "invalid region point"); // The initial operands map to the loop arguments after the induction // variable or are forwarded to the results when the trip count is zero. @@ -2576,34 +2577,41 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { void AffineForOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { - assert((point.isParent() || point == getRegion()) && "expected loop region"); + assert((point.isParent() || + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getRegion()) && + "expected loop region"); // The loop may typically branch back to its body or to the parent operation. // If the predecessor is the parent op and the trip count is known to be at // least one, branch into the body using the iterator arguments. And in cases // we know the trip count is zero, it can only branch back to its parent. std::optional tripCount = getTrivialConstantTripCount(*this); - if (point.isParent() && tripCount.has_value()) { - if (tripCount.value() > 0) { - regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - return; - } - if (tripCount.value() == 0) { - regions.push_back(RegionSuccessor(getResults())); - return; + if (tripCount.has_value()) { + if (!point.isParent()) { + // From the loop body, if the trip count is one, we can only branch back + // to the parent. + if (tripCount == 1) { + regions.push_back(RegionSuccessor(getOperation(), getResults())); + return; + } + if (tripCount == 0) + return; + } else { + if (tripCount.value() > 0) { + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); + return; + } + if (tripCount.value() == 0) { + regions.push_back(RegionSuccessor(getOperation(), getResults())); + return; + } } } - // From the loop body, if the trip count is one, we can only branch back to - // the parent. - if (!point.isParent() && tripCount == 1) { - regions.push_back(RegionSuccessor(getResults())); - return; - } - // In all other cases, the loop may branch back to itself or the parent // operation. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } /// Returns true if the affine.for has zero iterations in trivial cases. @@ -3013,7 +3021,7 @@ void AffineIfOp::getSuccessorRegions( RegionSuccessor(&getThenRegion(), getThenRegion().getArguments())); // If the "else" region is empty, branch bach into parent. if (getElseRegion().empty()) { - regions.push_back(getResults()); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } else { regions.push_back( RegionSuccessor(&getElseRegion(), getElseRegion().getArguments())); @@ -3023,7 +3031,7 @@ void AffineIfOp::getSuccessorRegions( // If the predecessor is the `else`/`then` region, then branching into parent // op is valid. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } LogicalResult AffineIfOp::verify() { diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index dc7b07d911c17..8e4a49df76b52 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -36,8 +36,9 @@ void AsyncDialect::initialize() { constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes"; -OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBodyRegion() && "invalid region index"); +OperandRange ExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBodyRegion() && + "invalid region index"); return getBodyOperands(); } @@ -53,8 +54,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) { void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `body` region branch back to the parent operation. - if (point == getBodyRegion()) { - regions.push_back(RegionSuccessor(getBodyResults())); + if (!point.isParent() && + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBodyRegion()) { + regions.push_back(RegionSuccessor(getOperation(), getBodyResults())); return; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index b593ccab060c7..36a759c279eb7 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -562,8 +562,11 @@ LogicalResult BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) { SmallVector returnOperandTypes(llvm::map_range( op.getFunctionBody().getOps(), - [](RegionBranchTerminatorOpInterface op) { - return op.getSuccessorOperands(RegionBranchPoint::parent()).getTypes(); + [&](RegionBranchTerminatorOpInterface branchOp) { + return branchOp + .getSuccessorOperands(RegionSuccessor( + op.getOperation(), op.getOperation()->getResults())) + .getTypes(); })); if (!llvm::all_equal(returnOperandTypes)) return op->emitError( @@ -942,8 +945,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) { // about, but we would need to check how many successors there are and under // which condition they are taken, etc. - MutableOperandRange operands = - op.getMutableSuccessorOperands(RegionBranchPoint::parent()); + MutableOperandRange operands = op.getMutableSuccessorOperands( + RegionSuccessor(op.getOperation(), op.getOperation()->getResults())); SmallVector updatedOwnerships; auto result = deallocation_impl::insertDeallocOpForReturnLike( diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 4754f0bfe895e..0992ce14b4afb 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -845,7 +845,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); return; } @@ -854,7 +855,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -871,7 +873,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index b5f8ddaadacdf..6c6d8d2bad55d 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2399,7 +2399,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser, void WarpExecuteOnLane0Op::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda296da5..075d37cb600e6 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { void AllocaScopeOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index a9da6c2c8320a..736f15c1ca520 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -305,7 +305,7 @@ void ExecuteRegionOp::getSuccessorRegions( } // Otherwise, the region branches back to the parent operation. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } //===----------------------------------------------------------------------===// @@ -313,10 +313,11 @@ void ExecuteRegionOp::getSuccessorRegions( //===----------------------------------------------------------------------===// MutableOperandRange -ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { - assert((point.isParent() || point == getParentOp().getAfter()) && - "condition op can only exit the loop or branch to the after" - "region"); +ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) { + assert( + (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) && + "condition op can only exit the loop or branch to the after" + "region"); // Pass all operands except the condition to the successor region. return getArgsMutable(); } @@ -334,7 +335,7 @@ void ConditionOp::getSuccessorRegions( regions.emplace_back(&whileOp.getAfter(), whileOp.getAfter().getArguments()); if (!boolAttr || !boolAttr.getValue()) - regions.emplace_back(whileOp.getResults()); + regions.emplace_back(whileOp.getOperation(), whileOp.getResults()); } //===----------------------------------------------------------------------===// @@ -657,7 +658,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) { return dyn_cast_or_null(containingOp); } -OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) { return getInitArgs(); } @@ -667,7 +668,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, // back into the operation itself. It is possible for loop not to enter the // body. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } SmallVector ForallOp::getLoopRegions() { return {&getRegion()}; } @@ -1961,9 +1962,10 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point, // parallel by multiple threads. We should not expect to branch back into // the forall body after the region's execution is complete. if (point.isParent()) - regions.push_back(RegionSuccessor(&getRegion())); + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); else - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); } //===----------------------------------------------------------------------===// @@ -2241,9 +2243,10 @@ void IfOp::print(OpAsmPrinter &p) { void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - // The `then` and the `else` region branch back to the parent operation. + // The `then` and the `else` region branch back to the parent operation or one + // of the recursive parent operations (early exit case). if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -2252,7 +2255,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -2269,7 +2273,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); } } @@ -3293,7 +3297,8 @@ void ParallelOp::getSuccessorRegions( // back into the operation itself. It is possible for loop not to enter the // body. regions.push_back(RegionSuccessor(&getRegion())); - regions.push_back(RegionSuccessor()); + regions.push_back(RegionSuccessor( + getOperation(), ResultRange{getResults().end(), getResults().end()})); } //===----------------------------------------------------------------------===// @@ -3339,7 +3344,7 @@ LogicalResult ReduceOp::verifyRegions() { } MutableOperandRange -ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { +ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) { // No operands are forwarded to the next iteration. return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0); } @@ -3422,8 +3427,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() { return getBeforeArguments(); } -OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBefore() && +OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBefore() && "WhileOp is expected to branch only to the first region"); return getInits(); } @@ -3436,15 +3441,18 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point, return; } - assert(llvm::is_contained({&getAfter(), &getBefore()}, point) && + assert(llvm::is_contained( + {&getAfter(), &getBefore()}, + point.getTerminatorPredecessorOrNull()->getParentRegion()) && "there are only two regions in a WhileOp"); // The body region always branches back to the condition region. - if (point == getAfter()) { + if (point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getAfter()) { regions.emplace_back(&getBefore(), getBefore().getArguments()); return; } - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); regions.emplace_back(&getAfter(), getAfter().getArguments()); } @@ -4353,7 +4361,7 @@ void IndexSwitchOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl &successors) { // All regions branch back to the parent op. if (!point.isParent()) { - successors.emplace_back(getResults()); + successors.emplace_back(getOperation(), getResults()); return; } diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index ae52af5009dc9..ddcbda86cf1f3 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -23,7 +23,6 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir -using namespace llvm; using namespace mlir; using scf::ForOp; using scf::WhileOp; diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp index a2f03f1e1056e..00bef707fadd3 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp @@ -21,7 +21,6 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir -using namespace llvm; using namespace mlir; using scf::LoopNest; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 5ba828918c22a..f0f22e5ef4a83 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -346,7 +346,7 @@ void AssumingOp::getSuccessorRegions( // parent, so return the correct RegionSuccessor purely based on the index // being None or 0. if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 1a9d9e158ee75..3962e3e84dd31 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2597,7 +2597,7 @@ std::optional> IterateOp::getYieldedValuesMutable() { std::optional IterateOp::getLoopResults() { return getResults(); } -OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) { return getInitArgs(); } @@ -2607,7 +2607,7 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point, // or back into the operation itself. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); // It is possible for loop not to enter the body. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 365afab3764c8..062606e7e10b6 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -96,9 +96,9 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, // AlternativesOp //===----------------------------------------------------------------------===// -OperandRange -transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { - if (!point.isParent() && getOperation()->getNumOperands() == 1) +OperandRange transform::AlternativesOp::getEntrySuccessorOperands( + RegionSuccessor successor) { + if (!successor.isParent() && getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); @@ -107,15 +107,18 @@ transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { void transform::AlternativesOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { for (Region &alternative : llvm::drop_begin( - getAlternatives(), - point.isParent() ? 0 - : point.getRegionOrNull()->getRegionNumber() + 1)) { + getAlternatives(), point.isParent() + ? 0 + : point.getTerminatorPredecessorOrNull() + ->getParentRegion() + ->getRegionNumber() + + 1)) { regions.emplace_back(&alternative, !getOperands().empty() ? alternative.getArguments() : Block::BlockArgListType()); } if (!point.isParent()) - regions.emplace_back(getOperation()->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::AlternativesOp::getRegionInvocationBounds( @@ -1740,16 +1743,18 @@ void transform::ForeachOp::getSuccessorRegions( } // Branch back to the region or the parent. - assert(point == getBody() && "unexpected region index"); + assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBody() && + "unexpected region index"); regions.emplace_back(bodyRegion, bodyRegion->getArguments()); - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } OperandRange -transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { +transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) { // Each block argument handle is mapped to a subset (one op to be precise) // of the payload of the corresponding `targets` operand of ForeachOp. - assert(point == getBody() && "unexpected region index"); + assert(successor.getSuccessor() == &getBody() && "unexpected region index"); return getOperation()->getOperands(); } @@ -2948,8 +2953,8 @@ void transform::SequenceOp::getEffects( } OperandRange -transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBody() && "unexpected region index"); +transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBody() && "unexpected region index"); if (getOperation()->getNumOperands() > 0) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), @@ -2966,8 +2971,10 @@ void transform::SequenceOp::getSuccessorRegions( return; } - assert(point == getBody() && "unexpected region index"); - regions.emplace_back(getOperation()->getResults()); + assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBody() && + "unexpected region index"); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::SequenceOp::getRegionInvocationBounds( diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index c627158e999ed..f727118f3f9a0 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" @@ -112,7 +113,7 @@ static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, } OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands( - RegionBranchPoint point) { + RegionSuccessor successor) { // No operands will be forwarded to the region(s). return getOperands().slice(0, 0); } @@ -128,7 +129,7 @@ void transform::tune::AlternativesOp::getSuccessorRegions( for (Region &alternative : getAlternatives()) regions.emplace_back(&alternative, Block::BlockArgListType()); else - regions.emplace_back(getOperation()->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::tune::AlternativesOp::getRegionInvocationBounds( diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 776b5c6588c71..f4c9242ed3479 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -138,6 +138,10 @@ Diagnostic &Diagnostic::operator<<(Operation &op) { return appendOp(op, OpPrintingFlags()); } +Diagnostic &Diagnostic::operator<<(OpWithFlags op) { + return appendOp(*op.getOperation(), op.flags()); +} + Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) { std::string str; llvm::raw_string_ostream os(str); diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index 46b6298076d48..15a941f380225 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -253,6 +253,21 @@ void Region::OpIterator::skipOverBlocksWithNoOps() { operation = block->begin(); } +llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, Region ®ion) { + if (!region.getParentOp()) { + os << "Region has no parent op"; + } else { + os << "Region #" << region.getRegionNumber() << " in operation " + << region.getParentOp()->getName(); + } + for (auto it : llvm::enumerate(region.getBlocks())) { + os << "\n Block #" << it.index() << ":"; + for (Operation &op : it.value().getOperations()) + os << "\n " << OpWithFlags(&op, OpPrintingFlags().skipRegions()); + } + return os; +} + //===----------------------------------------------------------------------===// // RegionRange //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index ca3f7666dba8a..1e56810ff7aaf 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -9,7 +9,9 @@ #include #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/Support/DebugLog.h" using namespace mlir; @@ -38,20 +40,31 @@ SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, std::optional detail::getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor) { + LDBG() << "Getting branch successor argument for operand index " + << operandIndex << " in successor block"; + OperandRange forwardedOperands = operands.getForwardedOperands(); // Check that the operands are valid. - if (forwardedOperands.empty()) + if (forwardedOperands.empty()) { + LDBG() << "No forwarded operands, returning nullopt"; return std::nullopt; + } // Check to ensure that this operand is within the range. unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); if (operandIndex < operandsStart || - operandIndex >= (operandsStart + forwardedOperands.size())) + operandIndex >= (operandsStart + forwardedOperands.size())) { + LDBG() << "Operand index " << operandIndex << " out of range [" + << operandsStart << ", " + << (operandsStart + forwardedOperands.size()) + << "), returning nullopt"; return std::nullopt; + } // Index the successor. unsigned argIndex = operands.getProducedOperandCount() + operandIndex - operandsStart; + LDBG() << "Computed argument index " << argIndex << " for successor block"; return successor->getArgument(argIndex); } @@ -59,9 +72,15 @@ detail::getBranchSuccessorArgument(const SuccessorOperands &operands, LogicalResult detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands) { + LDBG() << "Verifying branch successor operands for successor #" << succNo + << " in operation " << op->getName(); + // Check the count. unsigned operandCount = operands.size(); Block *destBB = op->getSuccessor(succNo); + LDBG() << "Branch has " << operandCount << " operands, target block has " + << destBB->getNumArguments() << " arguments"; + if (operandCount != destBB->getNumArguments()) return op->emitError() << "branch has " << operandCount << " operands for successor #" << succNo @@ -69,13 +88,22 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, << destBB->getNumArguments(); // Check the types. + LDBG() << "Checking type compatibility for " + << (operandCount - operands.getProducedOperandCount()) + << " forwarded operands"; for (unsigned i = operands.getProducedOperandCount(); i != operandCount; ++i) { - if (!cast(op).areTypesCompatible( - operands[i].getType(), destBB->getArgument(i).getType())) + Type operandType = operands[i].getType(); + Type argType = destBB->getArgument(i).getType(); + LDBG() << "Checking type compatibility: operand type " << operandType + << " vs argument type " << argType; + + if (!cast(op).areTypesCompatible(operandType, argType)) return op->emitError() << "type mismatch for bb argument #" << i << " of successor #" << succNo; } + + LDBG() << "Branch successor operand verification successful"; return success(); } @@ -126,15 +154,15 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) { static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, - RegionBranchPoint succRegionNo) { + RegionSuccessor succRegionNo) { diag << "from "; - if (Region *region = sourceNo.getRegionOrNull()) - diag << "Region #" << region->getRegionNumber(); + if (Operation *op = sourceNo.getTerminatorPredecessorOrNull()) + diag << "Operation " << op->getName(); else diag << "parent operands"; diag << " to "; - if (Region *region = succRegionNo.getRegionOrNull()) + if (Region *region = succRegionNo.getSuccessor()) diag << "Region #" << region->getRegionNumber(); else diag << "parent results"; @@ -145,13 +173,12 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the /// types of the inputs that flow to a successor region. static LogicalResult -verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, - function_ref(RegionBranchPoint)> +verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, + RegionBranchPoint sourcePoint, + function_ref(RegionSuccessor)> getInputsTypesForRegion) { - auto regionInterface = cast(op); - SmallVector successors; - regionInterface.getSuccessorRegions(sourcePoint, successors); + branchOp.getSuccessorRegions(sourcePoint, successors); for (RegionSuccessor &succ : successors) { FailureOr sourceTypes = getInputsTypesForRegion(succ); @@ -160,10 +187,14 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); if (sourceTypes->size() != succInputsTypes.size()) { - InFlightDiagnostic diag = op->emitOpError("region control flow edge "); + InFlightDiagnostic diag = + branchOp->emitOpError("region control flow edge "); + std::string succStr; + llvm::raw_string_ostream os(succStr); + os << succ; return printRegionEdgeName(diag, sourcePoint, succ) << ": source has " << sourceTypes->size() - << " operands, but target successor needs " + << " operands, but target successor " << os.str() << " needs " << succInputsTypes.size(); } @@ -171,8 +202,10 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { Type sourceType = std::get<0>(typesIdx.value()); Type inputType = std::get<1>(typesIdx.value()); - if (!regionInterface.areTypesCompatible(sourceType, inputType)) { - InFlightDiagnostic diag = op->emitOpError("along control flow edge "); + + if (!branchOp.areTypesCompatible(sourceType, inputType)) { + InFlightDiagnostic diag = + branchOp->emitOpError("along control flow edge "); return printRegionEdgeName(diag, sourcePoint, succ) << ": source type #" << typesIdx.index() << " " << sourceType << " should match input type #" << typesIdx.index() << " " @@ -180,6 +213,7 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, } } } + return success(); } @@ -187,34 +221,18 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { auto regionInterface = cast(op); - auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange { - return regionInterface.getEntrySuccessorOperands(point).getTypes(); + auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange { + return regionInterface.getEntrySuccessorOperands(successor).getTypes(); }; // Verify types along control flow edges originating from the parent. - if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(), - inputTypesFromParent))) + if (failed(verifyTypesAlongAllEdges( + regionInterface, RegionBranchPoint::parent(), inputTypesFromParent))) return failure(); - auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { - if (lhs.size() != rhs.size()) - return false; - for (auto types : llvm::zip(lhs, rhs)) { - if (!regionInterface.areTypesCompatible(std::get<0>(types), - std::get<1>(types))) { - return false; - } - } - return true; - }; - // Verify types along control flow edges originating from each region. for (Region ®ion : op->getRegions()) { - - // Since there can be multiple terminators implementing the - // `RegionBranchTerminatorOpInterface`, all should have the same operand - // types when passing them to the same region. - + // Collect all return-like terminators in the region. SmallVector regionReturnOps; for (Block &block : region) if (!block.empty()) @@ -227,33 +245,20 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { if (regionReturnOps.empty()) continue; - auto inputTypesForRegion = - [&](RegionBranchPoint point) -> FailureOr { - std::optional regionReturnOperands; - for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { - auto terminatorOperands = regionReturnOp.getSuccessorOperands(point); - - if (!regionReturnOperands) { - regionReturnOperands = terminatorOperands; - continue; - } - - // Found more than one ReturnLike terminator. Make sure the operand - // types match with the first one. - if (!areTypesCompatible(regionReturnOperands->getTypes(), - terminatorOperands.getTypes())) { - InFlightDiagnostic diag = op->emitOpError("along control flow edge"); - return printRegionEdgeName(diag, region, point) - << " operands mismatch between return-like terminators"; - } - } - - // All successors get the same set of operand types. - return TypeRange(regionReturnOperands->getTypes()); - }; - - if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion))) - return failure(); + // Verify types along control flow edges originating from each return-like + // terminator. + for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { + + auto inputTypesForRegion = + [&](RegionSuccessor successor) -> FailureOr { + OperandRange terminatorOperands = + regionReturnOp.getSuccessorOperands(successor); + return TypeRange(terminatorOperands.getTypes()); + }; + if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp, + inputTypesForRegion))) + return failure(); + } } return success(); @@ -272,31 +277,74 @@ using StopConditionFn = function_ref visited)>; static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn) { auto op = cast(begin->getParentOp()); + LDBG() << "Starting region graph traversal from region #" + << begin->getRegionNumber() << " in operation " << op->getName(); + SmallVector visited(op->getNumRegions(), false); visited[begin->getRegionNumber()] = true; + LDBG() << "Initialized visited array with " << op->getNumRegions() + << " regions"; // Retrieve all successors of the region and enqueue them in the worklist. SmallVector worklist; auto enqueueAllSuccessors = [&](Region *region) { - SmallVector successors; - op.getSuccessorRegions(region, successors); - for (RegionSuccessor successor : successors) - if (!successor.isParent()) - worklist.push_back(successor.getSuccessor()); + LDBG() << "Enqueuing successors for region #" << region->getRegionNumber(); + SmallVector operandAttributes(op->getNumOperands()); + for (Block &block : *region) { + if (block.empty()) + continue; + auto terminator = + dyn_cast(block.back()); + if (!terminator) + continue; + SmallVector successors; + operandAttributes.resize(terminator->getNumOperands()); + terminator.getSuccessorRegions(operandAttributes, successors); + LDBG() << "Found " << successors.size() + << " successors from terminator in block"; + for (RegionSuccessor successor : successors) { + if (!successor.isParent()) { + worklist.push_back(successor.getSuccessor()); + LDBG() << "Added region #" + << successor.getSuccessor()->getRegionNumber() + << " to worklist"; + } else { + LDBG() << "Skipping parent successor"; + } + } + } }; enqueueAllSuccessors(begin); + LDBG() << "Initial worklist size: " << worklist.size(); // Process all regions in the worklist via DFS. while (!worklist.empty()) { Region *nextRegion = worklist.pop_back_val(); - if (stopConditionFn(nextRegion, visited)) + LDBG() << "Processing region #" << nextRegion->getRegionNumber() + << " from worklist (remaining: " << worklist.size() << ")"; + + if (stopConditionFn(nextRegion, visited)) { + LDBG() << "Stop condition met for region #" + << nextRegion->getRegionNumber() << ", returning true"; return true; - if (visited[nextRegion->getRegionNumber()]) + } + llvm::dbgs() << "Region: " << nextRegion << "\n"; + if (!nextRegion->getParentOp()) { + llvm::errs() << "Region " << *nextRegion << " has no parent op\n"; + return false; + } + if (visited[nextRegion->getRegionNumber()]) { + LDBG() << "Region #" << nextRegion->getRegionNumber() + << " already visited, skipping"; continue; + } visited[nextRegion->getRegionNumber()] = true; + LDBG() << "Marking region #" << nextRegion->getRegionNumber() + << " as visited"; enqueueAllSuccessors(nextRegion); } + LDBG() << "Traversal completed, returning false"; return false; } @@ -322,18 +370,26 @@ static bool isRegionReachable(Region *begin, Region *r) { /// mutually exclusive if they are not reachable from each other as per /// RegionBranchOpInterface::getSuccessorRegions. bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { + LDBG() << "Checking if operations are in mutually exclusive regions: " + << a->getName() << " and " << b->getName(); + assert(a && "expected non-empty operation"); assert(b && "expected non-empty operation"); auto branchOp = a->getParentOfType(); while (branchOp) { + LDBG() << "Checking branch operation " << branchOp->getName(); + // Check if b is inside branchOp. (We already know that a is.) if (!branchOp->isProperAncestor(b)) { + LDBG() << "Operation b is not inside branchOp, checking next ancestor"; // Check next enclosing RegionBranchOpInterface. branchOp = branchOp->getParentOfType(); continue; } + LDBG() << "Both operations are inside branchOp, finding their regions"; + // b is contained in branchOp. Retrieve the regions in which `a` and `b` // are contained. Region *regionA = nullptr, *regionB = nullptr; @@ -341,63 +397,136 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { if (r.findAncestorOpInRegion(*a)) { assert(!regionA && "already found a region for a"); regionA = &r; + LDBG() << "Found region #" << r.getRegionNumber() << " for operation a"; } if (r.findAncestorOpInRegion(*b)) { assert(!regionB && "already found a region for b"); regionB = &r; + LDBG() << "Found region #" << r.getRegionNumber() << " for operation b"; } } assert(regionA && regionB && "could not find region of op"); + LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #" + << regionB->getRegionNumber(); + // `a` and `b` are in mutually exclusive regions if both regions are // distinct and neither region is reachable from the other region. - return regionA != regionB && !isRegionReachable(regionA, regionB) && - !isRegionReachable(regionB, regionA); + bool regionsAreDistinct = (regionA != regionB); + bool aNotReachableFromB = !isRegionReachable(regionA, regionB); + bool bNotReachableFromA = !isRegionReachable(regionB, regionA); + + LDBG() << "Regions distinct: " << regionsAreDistinct + << ", A not reachable from B: " << aNotReachableFromB + << ", B not reachable from A: " << bNotReachableFromA; + + bool mutuallyExclusive = + regionsAreDistinct && aNotReachableFromB && bNotReachableFromA; + LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive; + + return mutuallyExclusive; } // Could not find a common RegionBranchOpInterface among a's and b's // ancestors. + LDBG() << "No common RegionBranchOpInterface found, operations are not " + "mutually exclusive"; return false; } bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { + LDBG() << "Checking if region #" << index << " is repetitive in operation " + << getOperation()->getName(); + Region *region = &getOperation()->getRegion(index); - return isRegionReachable(region, region); + bool isRepetitive = isRegionReachable(region, region); + + LDBG() << "Region #" << index << " is repetitive: " << isRepetitive; + return isRepetitive; } bool RegionBranchOpInterface::hasLoop() { + LDBG() << "Checking if operation " << getOperation()->getName() + << " has loops"; + SmallVector entryRegions; getSuccessorRegions(RegionBranchPoint::parent(), entryRegions); - for (RegionSuccessor successor : entryRegions) - if (!successor.isParent() && - traverseRegionGraph(successor.getSuccessor(), - [](Region *nextRegion, ArrayRef visited) { - // Interrupt traversal if the region was already - // visited. - return visited[nextRegion->getRegionNumber()]; - })) - return true; + LDBG() << "Found " << entryRegions.size() << " entry regions"; + + for (RegionSuccessor successor : entryRegions) { + if (!successor.isParent()) { + LDBG() << "Checking entry region #" + << successor.getSuccessor()->getRegionNumber() << " for loops"; + + bool hasLoop = + traverseRegionGraph(successor.getSuccessor(), + [](Region *nextRegion, ArrayRef visited) { + // Interrupt traversal if the region was already + // visited. + return visited[nextRegion->getRegionNumber()]; + }); + + if (hasLoop) { + LDBG() << "Found loop in entry region #" + << successor.getSuccessor()->getRegionNumber(); + return true; + } + } else { + LDBG() << "Skipping parent successor"; + } + } + + LDBG() << "No loops found in operation"; return false; } Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { + LDBG() << "Finding enclosing repetitive region for operation " + << op->getName(); + while (Region *region = op->getParentRegion()) { + LDBG() << "Checking region #" << region->getRegionNumber() + << " in operation " << region->getParentOp()->getName(); + op = region->getParentOp(); - if (auto branchOp = dyn_cast(op)) - if (branchOp.isRepetitiveRegion(region->getRegionNumber())) + if (auto branchOp = dyn_cast(op)) { + LDBG() + << "Found RegionBranchOpInterface, checking if region is repetitive"; + if (branchOp.isRepetitiveRegion(region->getRegionNumber())) { + LDBG() << "Found repetitive region #" << region->getRegionNumber(); return region; + } + } else { + LDBG() << "Parent operation does not implement RegionBranchOpInterface"; + } } + + LDBG() << "No enclosing repetitive region found"; return nullptr; } Region *mlir::getEnclosingRepetitiveRegion(Value value) { + LDBG() << "Finding enclosing repetitive region for value"; + Region *region = value.getParentRegion(); while (region) { + LDBG() << "Checking region #" << region->getRegionNumber() + << " in operation " << region->getParentOp()->getName(); + Operation *op = region->getParentOp(); - if (auto branchOp = dyn_cast(op)) - if (branchOp.isRepetitiveRegion(region->getRegionNumber())) + if (auto branchOp = dyn_cast(op)) { + LDBG() + << "Found RegionBranchOpInterface, checking if region is repetitive"; + if (branchOp.isRepetitiveRegion(region->getRegionNumber())) { + LDBG() << "Found repetitive region #" << region->getRegionNumber(); return region; + } + } else { + LDBG() << "Parent operation does not implement RegionBranchOpInterface"; + } region = op->getParentRegion(); } + + LDBG() << "No enclosing repetitive region found for value"; return nullptr; } diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index e0c65b0e09774..41f3f9d76a3b1 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -432,8 +432,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // Return the successors of `region` if the latter is not null. Else return // the successors of `regionBranchOp`. - auto getSuccessors = [&](Region *region = nullptr) { - auto point = region ? region : RegionBranchPoint::parent(); + auto getSuccessors = [&](RegionBranchPoint point) { SmallVector successors; regionBranchOp.getSuccessorRegions(point, successors); return successors; @@ -456,7 +455,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // `nonForwardedOperands`. auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) { nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true); - for (const RegionSuccessor &successor : getSuccessors()) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint::parent())) { for (OpOperand *opOperand : getForwardedOpOperands(successor)) nonForwardedOperands.reset(opOperand->getOperandNumber()); } @@ -469,10 +469,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, for (Region ®ion : regionBranchOp->getRegions()) { if (region.empty()) continue; + // TODO: this isn't correct in face of multiple terminators. Operation *terminator = region.front().getTerminator(); nonForwardedRets[terminator] = BitVector(terminator->getNumOperands(), true); - for (const RegionSuccessor &successor : getSuccessors(®ion)) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint( + cast(terminator)))) { for (OpOperand *opOperand : getForwardedOpOperands(successor, terminator)) nonForwardedRets[terminator].reset(opOperand->getOperandNumber()); @@ -489,8 +492,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, DenseMap &argsToKeep, Region *region = nullptr) { Operation *terminator = region ? region->front().getTerminator() : nullptr; + RegionBranchPoint point = + terminator + ? RegionBranchPoint( + cast(terminator)) + : RegionBranchPoint::parent(); - for (const RegionSuccessor &successor : getSuccessors(region)) { + for (const RegionSuccessor &successor : getSuccessors(point)) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor, terminator), @@ -517,7 +525,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, resultsOrArgsToKeepChanged = false; // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`. - for (const RegionSuccessor &successor : getSuccessors()) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint::parent())) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor), @@ -551,7 +560,9 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, if (region.empty()) continue; Operation *terminator = region.front().getTerminator(); - for (const RegionSuccessor &successor : getSuccessors(®ion)) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint( + cast(terminator)))) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor, terminator), diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index 37fc86b18e7f0..3f481ad5dbba7 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -373,7 +373,7 @@ func.func @reduceReturn_not_inside_reduce(%arg0 : f32) { func.func @std_if_incorrect_yield(%arg0: i1, %arg1: f32) { - // expected-error@+1 {{region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 2}} + // expected-error@+1 {{region control flow edge from Operation scf.yield to parent results: source has 1 operands, but target successor needs 2}} %x, %y = scf.if %arg0 -> (f32, f32) { %0 = arith.addf %arg1, %arg1 : f32 scf.yield %0 : f32 @@ -544,7 +544,7 @@ func.func @while_invalid_terminator() { func.func @while_cross_region_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to Region #1: source has 0 operands, but target successor needs 1}} + // expected-error@+1 {{region control flow edge from Operation scf.condition to Region #1: source has 0 operands, but target successor needs 1}} scf.while : () -> () { scf.condition(%true) } do { @@ -557,7 +557,7 @@ func.func @while_cross_region_type_mismatch() { func.func @while_cross_region_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{'scf.while' op along control flow edge from Region #0 to Region #1: source type #0 'i1' should match input type #0 'i32'}} + // expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: source type #0 'i1' should match input type #0 'i32'}} %0 = scf.while : () -> (i1) { scf.condition(%true) %true : i1 } do { @@ -570,7 +570,7 @@ func.func @while_cross_region_type_mismatch() { func.func @while_result_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 0}} + // expected-error@+1 {{region control flow edge from Operation scf.condition to parent results: source has 1 operands, but target successor needs 0}} scf.while : () -> () { scf.condition(%true) %true : i1 } do { diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp index eb0d9801e7d3f..7a7a58384fbb8 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp @@ -66,7 +66,7 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis { void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, RegionBranchPoint regionFrom, - RegionBranchPoint regionTo, + RegionSuccessor regionTo, const NextAccess &after, NextAccess *before) override; @@ -240,7 +240,7 @@ void NextAccessAnalysis::visitCallControlFlowTransfer( void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( RegionBranchOpInterface branch, RegionBranchPoint regionFrom, - RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) { + RegionSuccessor regionTo, const NextAccess &after, NextAccess *before) { LDBG() << "visitRegionBranchControlFlowTransfer: " << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); LDBG() << " regionFrom: " << (regionFrom.isParent() ? "parent" : "region"); diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 53055fea215b7..beb9b4e51fcea 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -633,8 +633,9 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { parser.getCurrentLocation(), result.operands); } -OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && +OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, + successor.getSuccessor()) && "invalid region index"); return getOperands(); } @@ -643,10 +644,11 @@ void RegionIfOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { // We always branch to the join region. if (!point.isParent()) { - if (point != getJoinRegion()) + if (point.getTerminatorPredecessorOrNull()->getParentRegion() != + &getJoinRegion()) regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); else - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -673,7 +675,7 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, if (point.isParent()) regions.emplace_back(&getRegion()); else - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); } void AnyCondOp::getRegionInvocationBounds( @@ -1107,11 +1109,11 @@ void LoopBlockOp::getSuccessorRegions( if (point.isParent()) return; - regions.emplace_back((*this)->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } -OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBody()); +OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBody()); return MutableOperandRange(getInitMutable()); } @@ -1120,8 +1122,8 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { //===----------------------------------------------------------------------===// MutableOperandRange -LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { - if (point.isParent()) +LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) { + if (successor.isParent()) return getExitArgMutable(); return getNextIterArgMutable(); } @@ -1213,7 +1215,7 @@ void TestStoreWithARegion::getSuccessorRegions( if (point.isParent()) regions.emplace_back(&getBody(), getBody().front().getArguments()); else - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } //===----------------------------------------------------------------------===// @@ -1227,7 +1229,7 @@ void TestStoreWithALoopRegion::getSuccessorRegions( // enter the body. regions.emplace_back( RegionSuccessor(&getBody(), getBody().front().getArguments())); - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 6329d61ba691b..75c6c54960088 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2580,7 +2580,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term", def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [ NoTerminator, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods ]> { let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases); let regions = (region VariadicRegion>:$caseRegions); diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp index f1aae15393fd3..2e6950fca6be2 100644 --- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp @@ -13,17 +13,24 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Parser/Parser.h" +#include "llvm/Support/DebugLog.h" #include using namespace mlir; /// A dummy op that is also a terminator. -struct DummyOp : public Op { +struct DummyOp : public Op { using Op::Op; static ArrayRef getAttributeNames() { return {}; } static StringRef getOperationName() { return "cftest.dummy_op"; } + + MutableOperandRange getMutableSuccessorOperands(RegionSuccessor point) { + return MutableOperandRange(getOperation(), 0, 0); + } }; /// All regions of this op are mutually exclusive. @@ -39,6 +46,8 @@ struct MutuallyExclusiveRegionsOp // Regions have no successors. void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) {} + using RegionBranchOpInterface::Trait< + MutuallyExclusiveRegionsOp>::getSuccessorRegions; }; /// All regions of this op call each other in a large circle. @@ -53,13 +62,18 @@ struct LoopRegionsOp void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - if (Region *region = point.getRegionOrNull()) { - if (point == (*this)->getRegion(1)) + if (point.getTerminatorPredecessorOrNull()) { + Region *region = + point.getTerminatorPredecessorOrNull()->getParentRegion(); + if (region == &(*this)->getRegion(1)) // This region also branches back to the parent. - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation()->getParentOp(), + getOperation()->getParentOp()->getResults())); regions.push_back(RegionSuccessor(region)); } } + using RegionBranchOpInterface::Trait::getSuccessorRegions; }; /// Each region branches back it itself or the parent. @@ -75,11 +89,17 @@ struct DoubleLoopRegionsOp void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - if (Region *region = point.getRegionOrNull()) { - regions.push_back(RegionSuccessor()); + if (point.getTerminatorPredecessorOrNull()) { + Region *region = + point.getTerminatorPredecessorOrNull()->getParentRegion(); + regions.push_back( + RegionSuccessor(getOperation()->getParentOp(), + getOperation()->getParentOp()->getResults())); regions.push_back(RegionSuccessor(region)); } } + using RegionBranchOpInterface::Trait< + DoubleLoopRegionsOp>::getSuccessorRegions; }; /// Regions are executed sequentially. @@ -93,11 +113,15 @@ struct SequentialRegionsOp // Region 0 has Region 1 as a successor. void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - if (point == (*this)->getRegion(0)) { + if (point.getTerminatorPredecessorOrNull() && + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &(*this)->getRegion(0)) { Operation *thisOp = this->getOperation(); regions.push_back(RegionSuccessor(&thisOp->getRegion(1))); } } + using RegionBranchOpInterface::Trait< + SequentialRegionsOp>::getSuccessorRegions; }; /// A dialect putting all the above together.