Skip to content

Commit

Permalink
Reland "[mlir] Use a type for representing branch points in `RegionBr…
Browse files Browse the repository at this point in the history
…anchOpInterface`"

This reverts commit b26bb30.
  • Loading branch information
zero9178 committed Aug 30, 2023
1 parent 82e851a commit 4dd744a
Show file tree
Hide file tree
Showing 23 changed files with 258 additions and 241 deletions.
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3467,10 +3467,10 @@ void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control.
void fir::IfOp::getSuccessorRegions(
std::optional<unsigned> index,
mlir::RegionBranchPoint point,
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (index) {
if (!point.isParent()) {
regions.push_back(mlir::RegionSuccessor(getResults()));
return;
}
Expand Down
15 changes: 7 additions & 8 deletions mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// any effect on the lattice that isn't already expressed by the interface
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
Expand Down Expand Up @@ -382,7 +382,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// of the branch operation itself.
void visitRegionBranchOperation(ProgramPoint point,
RegionBranchOpInterface branch,
std::optional<unsigned> regionNo,
RegionBranchPoint branchPoint,
AbstractDenseLattice *before);

/// Visit an operation for which the data flow is described by the
Expand Down Expand Up @@ -472,9 +472,8 @@ class DenseBackwardDataFlowAnalysis
/// nullptr`. The behavior can be further refined for specific pairs of "from"
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
std::optional<unsigned> regionTo, const LatticeT &after,
LatticeT *before) {
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
Expand Down Expand Up @@ -508,8 +507,8 @@ class DenseBackwardDataFlowAnalysis
static_cast<LatticeT *>(before));
}
void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, std::optional<unsigned> regionForm,
std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// regions or the parent operation itself, and set either the argument or
/// parent result lattices.
void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
std::optional<unsigned> successorIndex,
RegionBranchPoint successor,
ArrayRef<AbstractSparseLattice *> lattices);
};

Expand Down
62 changes: 62 additions & 0 deletions mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,68 @@ class RegionSuccessor {
ValueRange inputs;
};

/// This class represents a point being branched from in the methods of the
/// `RegionBranchOpInterface`.
/// One can branch from one of two kinds of places:
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
/// * A region within the parent operation.
class RegionBranchPoint {
public:
/// Returns an instance of `RegionBranchPoint` representing the parent
/// operation.
static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }

/// Creates a `RegionBranchPoint` that branches from the given region.
/// The pointer must not be null.
RegionBranchPoint(Region *region) : maybeRegion(region) {
assert(region && "Region must not be null");
}

RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}

/// Explicitly stops users from constructing with `nullptr`.
RegionBranchPoint(std::nullptr_t) = delete;

/// Constructs a `RegionBranchPoint` from the the target of a
/// `RegionSuccessor` instance.
RegionBranchPoint(RegionSuccessor successor) {
if (successor.isParent())
maybeRegion = nullptr;
else
maybeRegion = successor.getSuccessor();
}

/// Assigns a region being branched from.
RegionBranchPoint &operator=(Region &region) {
maybeRegion = &region;
return *this;
}

/// Returns true if branching from the parent op.
bool isParent() const { return maybeRegion == nullptr; }

/// Returns the region if branching from a region.
/// A null pointer otherwise.
Region *getRegionOrNull() const { return maybeRegion; }

/// Returns true if the two branch points are equal.
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return lhs.maybeRegion == rhs.maybeRegion;
}

private:
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
constexpr RegionBranchPoint() : maybeRegion(nullptr) {}

/// Internal encoding. Uses nullptr for representing branching from the parent
/// op and the region being branched from otherwise.
Region *maybeRegion;
};

inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return !(lhs == rhs);
}

/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
Expand Down
39 changes: 17 additions & 22 deletions mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let methods = [
InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when
entering the region at `index`, which was specified as a successor of
branching from `point`, which was specified as a successor of
this operation by `getEntrySuccessorRegions`, or the operands forwarded
to the operation's results when it branches back to itself. These operands
should correspond 1-1 with the successor inputs specified in
`getEntrySuccessorRegions`.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
(ins "::std::optional<unsigned>":$index), [{}],
(ins "::mlir::RegionBranchPoint":$point), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
Expand All @@ -162,22 +162,20 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
[{}], [{
$_op.getSuccessorRegions(std::nullopt, regions);
$_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
}]
>,
InterfaceMethod<[{
Returns the viable successors of a region at `index`, or the possible
successors when branching from the parent op if `index` is None. These
are the regions that may be selected during the flow of control. The
parent operation, i.e. a null `index`, may specify itself as successor,
which indicates that the control flow may not enter any region at all.
This method allows for describing which regions may be executed when
entering an operation, and which regions are executed after having
executed another region of the parent op. The successor region must be
non-empty.
Returns the viable successors of `point`. These are the regions that may
be selected during the flow of control. The parent operation, may
specify itself as successor, which indicates that the control flow may
not enter any region at all. This method allows for describing which
regions may be executed when entering an operation, and which regions
are executed after having executed another region of the parent op. The
successor region must be non-empty.
}],
"void", "getSuccessorRegions",
(ins "::std::optional<unsigned>":$index,
(ins "::mlir::RegionBranchPoint":$point,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
InterfaceMethod<[{
Expand Down Expand Up @@ -245,12 +243,10 @@ def RegionBranchTerminatorOpInterface :
let methods = [
InterfaceMethod<[{
Returns a mutable range of operands that are semantically "returned" by
passing them to the region successor given by `index`. If `index` is
None, this function returns the operands that are passed as a result to
the parent operation.
passing them to the region successor given by `point`.
}],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
(ins "::std::optional<unsigned>":$index)
(ins "::mlir::RegionBranchPoint":$point)
>,
InterfaceMethod<[{
Returns the viable region successors that are branched to after this
Expand All @@ -269,8 +265,7 @@ def RegionBranchTerminatorOpInterface :
[{
::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
regions);
.getSuccessorRegions(op->getParentRegion(), regions);
}]
>,
];
Expand All @@ -290,8 +285,8 @@ def RegionBranchTerminatorOpInterface :
// them to the region successor given by `index`. If `index` is None, this
// function returns the operands that are passed as a result to the parent
// operation.
::mlir::OperandRange getSuccessorOperands(std::optional<unsigned> index) {
return getMutableSuccessorOperands(index);
::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) {
return getMutableSuccessorOperands(point);
}
}];
}
Expand All @@ -309,7 +304,7 @@ def ReturnLike : TraitList<[
/*extraOpDeclaration=*/"",
/*extraOpDefinition=*/[{
::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
::std::optional<unsigned> index) {
::mlir::RegionBranchPoint point) {
return ::mlir::MutableOperandRange(*this);
}
}]
Expand Down
25 changes: 12 additions & 13 deletions mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
// this region predecessor that correspond to the input values of `region`. If
// an index could not be found, std::nullopt is returned instead.
auto getOperandIndexIfPred =
[&](std::optional<unsigned> predIndex) -> std::optional<unsigned> {
[&](RegionBranchPoint pred) -> std::optional<unsigned> {
SmallVector<RegionSuccessor, 2> successors;
branch.getSuccessorRegions(predIndex, successors);
branch.getSuccessorRegions(pred, successors);
for (RegionSuccessor &successor : successors) {
if (successor.getSuccessor() != region)
continue;
Expand Down Expand Up @@ -75,28 +75,27 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
};

// Check branches from the parent operation.
std::optional<unsigned> regionIndex;
if (region) {
// Determine the actual region number from the passed region.
regionIndex = region->getRegionNumber();
}
auto branchPoint = RegionBranchPoint::parent();
if (region)
branchPoint = region;

if (std::optional<unsigned> operandIndex =
getOperandIndexIfPred(/*predIndex=*/std::nullopt)) {
getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
collectUnderlyingAddressValues(
branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth,
branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
visited, output);
}
// Check branches from each child region.
Operation *op = branch.getOperation();
for (int i = 0, e = op->getNumRegions(); i != e; ++i) {
if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
for (Block &block : op->getRegion(i)) {
for (Region &region : op->getRegions()) {
if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
for (Block &block : region) {
// Try to determine possible region-branch successor operands for the
// current region.
if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
block.getTerminator())) {
collectUnderlyingAddressValues(
term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth,
term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
visited, output);
} else if (block.getNumSuccessors()) {
// Otherwise, if this terminator may exit the region we can't make
Expand Down
16 changes: 7 additions & 9 deletions mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {

// Special cases where control flow may dictate data flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
return visitRegionBranchOperation(op, branch, std::nullopt, before);
return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
before);
if (auto call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call, before);

Expand Down Expand Up @@ -368,8 +369,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// If this block is exiting from an operation with region-based control
// flow, propagate the lattice back along the control flow edge.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
visitRegionBranchOperation(block, branch,
block->getParent()->getRegionNumber(), before);
visitRegionBranchOperation(block, branch, block->getParent(), before);
return;
}

Expand All @@ -396,13 +396,13 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {

void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
ProgramPoint point, RegionBranchOpInterface branch,
std::optional<unsigned> regionNo, AbstractDenseLattice *before) {
RegionBranchPoint branchPoint, AbstractDenseLattice *before) {

// The successors of the operation may be either the first operation of the
// entry block of each possible successor region, or the next operation when
// the branch is a successor of itself.
SmallVector<RegionSuccessor> successors;
branch.getSuccessorRegions(regionNo, successors);
branch.getSuccessorRegions(branchPoint, successors);
for (const RegionSuccessor &successor : successors) {
const AbstractDenseLattice *after;
if (successor.isParent() || successor.getSuccessor()->empty()) {
Expand All @@ -423,10 +423,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
else
after = getLatticeFor(point, &successorBlock->front());
}
std::optional<unsigned> successorNo =
successor.isParent() ? std::optional<unsigned>()
: successor.getSuccessor()->getRegionNumber();
visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after,

visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after,
before);
}
}
Expand Down
22 changes: 8 additions & 14 deletions mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
return visitRegionSuccessors({branch}, branch,
/*successorIndex=*/std::nullopt,
/*successor=*/RegionBranchPoint::parent(),
resultLattices);
}

Expand Down Expand Up @@ -167,8 +167,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {

// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
return visitRegionSuccessors(
block, branch, block->getParent()->getRegionNumber(), argLattices);
return visitRegionSuccessors(block, branch, block->getParent(),
argLattices);
}

// Otherwise, we can't reason about the data-flow.
Expand Down Expand Up @@ -212,8 +212,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {

void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint point, RegionBranchOpInterface branch,
std::optional<unsigned> successorIndex,
ArrayRef<AbstractSparseLattice *> lattices) {
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
Expand All @@ -224,11 +223,11 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(

// Check if the predecessor is the parent op.
if (op == branch) {
operands = branch.getEntrySuccessorOperands(successorIndex);
operands = branch.getEntrySuccessorOperands(successor);
// Otherwise, try to deduce the operands from a region return-like op.
} else if (auto regionTerminator =
dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
operands = regionTerminator.getSuccessorOperands(successorIndex);
operands = regionTerminator.getSuccessorOperands(successor);
}

if (!operands) {
Expand Down Expand Up @@ -501,10 +500,7 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
BitVector unaccounted(op->getNumOperands(), true);

for (RegionSuccessor &successor : successors) {
Region *region = successor.getSuccessor();
OperandRange operands =
region ? branch.getEntrySuccessorOperands(region->getRegionNumber())
: branch.getEntrySuccessorOperands({});
OperandRange operands = branch.getEntrySuccessorOperands(successor);
MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
ValueRange inputs = successor.getSuccessorInputs();
for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
Expand Down Expand Up @@ -538,9 +534,7 @@ void AbstractSparseBackwardDataFlowAnalysis::

for (const RegionSuccessor &successor : successors) {
ValueRange inputs = successor.getSuccessorInputs();
Region *region = successor.getSuccessor();
OperandRange operands = terminator.getSuccessorOperands(
region ? region->getRegionNumber() : std::optional<unsigned>{});
OperandRange operands = terminator.getSuccessorOperands(successor);
MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
meet(getLatticeElement(opOperand.get()),
Expand Down

0 comments on commit 4dd744a

Please sign in to comment.