-
Couldn't load subscription status.
- Fork 15k
[MLIR] Revamp RegionBranchOpInterface #165429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This is still somehow a WIP, we have some issues with this interface that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and aligned with their definition: - A `RegionBranchPoint` is either the parent (`RegionBranchOpInterface`) op or a `RegionBranchTerminatorOpInterface` operation in a nested region. - A `RegionSuccessor` is either one of the nested region or the parent `RegionBranchOpInterface` Some new methods with reasonnable default implementation are added to help resolving the flow of values across the RegionBranchOpInterface. It is still not trivial in the current state to walk the def-use chain backward with this interface. For example when you have the 3rd block argument in the entry block of a for-loop, finding the matching operands requires to know about the hidden loop iterator block argument and where the iterargs start. The API is designed around forward-tracking of the chain unfortunately.
|
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir-core Author: Mehdi Amini (joker-eph) ChangesThis is still somehow a WIP, we have some issues with this interface that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and aligned with their definition:
Some new methods with reasonnable default implementation are added to help resolving the flow of values across the RegionBranchOpInterface. It is still not trivial in the current state to walk the def-use chain backward with this interface. For example when you have the 3rd block argument in the entry block of a for-loop, finding the matching operands requires to know about the hidden loop iterator block argument and where the iterargs start. The API is designed around forward-tracking of the chain unfortunately. Try to reland #161575 when the build is fixed. Patch is 116.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165429.diff 38 Files Affected:
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index d0164f32d9b6a..4f97acaa88b7a 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4484,7 +4484,7 @@ void fir::IfOp::getSuccessorRegions(
llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) {
// The `then` and the `else` region branch back to the parent operation.
if (!point.isParent()) {
- regions.push_back(mlir::RegionSuccessor(getResults()));
+ regions.push_back(mlir::RegionSuccessor(getOperation(), getResults()));
return;
}
@@ -4494,7 +4494,8 @@ void fir::IfOp::getSuccessorRegions(
// Don't consider the else region if it is empty.
mlir::Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(mlir::RegionSuccessor());
+ regions.push_back(
+ mlir::RegionSuccessor(getOperation(), getOperation()->getResults()));
else
regions.push_back(mlir::RegionSuccessor(elseRegion));
}
@@ -4513,7 +4514,7 @@ void fir::IfOp::getEntrySuccessorRegions(
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
}
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 8bcfe51ad7cd1..3c87c453a4cf0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -397,7 +397,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+ RegionSuccessor regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
@@ -526,7 +526,7 @@ class DenseBackwardDataFlowAnalysis
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
+ RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
@@ -571,7 +571,7 @@ class DenseBackwardDataFlowAnalysis
}
void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
- RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+ RegionSuccessor regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 1a33ecf8b5aa9..985573476ab78 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -286,7 +286,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// and propagating therefrom.
virtual void
visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch,
- RegionBranchPoint successor,
+ RegionSuccessor successor,
ArrayRef<AbstractSparseLattice *> lattices);
};
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 66174ce0f7928..cd033c140a233 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -644,6 +644,13 @@ def ForallOp : SCF_Op<"forall", [
/// Returns true if the mapping specified for this forall op is linear.
bool usesLinearMapping();
+
+ /// RegionBranchOpInterface
+
+ OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
+ return getInits();
+ }
+
}];
}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 62e66b3dabee8..ed69287410509 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -25,7 +25,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
def AlternativesOp : TransformDialectOp<"alternatives",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -624,7 +624,7 @@ def ForeachOp : TransformDialectOp<"foreach",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
- "getSuccessorRegions", "getEntrySuccessorOperands"]>,
+ "getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
]> {
let summary = "Executes the body for each element of the payload";
@@ -1237,7 +1237,7 @@ def SelectOp : TransformDialectOp<"select",
def SequenceOp : TransformDialectOp<"sequence",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
index d095659fc4838..4079848fd203a 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -63,7 +63,7 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 7ff718ad7f241..a0a99f4953822 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -29,6 +29,7 @@ class MLIRContext;
class Operation;
class OperationName;
class OpPrintingFlags;
+class OpWithFlags;
class Type;
class Value;
@@ -199,6 +200,7 @@ class Diagnostic {
/// Stream in an Operation.
Diagnostic &operator<<(Operation &op);
+ Diagnostic &operator<<(OpWithFlags op);
Diagnostic &operator<<(Operation *op) { return *this << *op; }
/// Append an operation with the given printing flags.
Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 5569392cf0b41..b2019574a820d 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -1114,6 +1114,7 @@ class OpWithFlags {
: op(op), theFlags(flags) {}
OpPrintingFlags &flags() { return theFlags; }
const OpPrintingFlags &flags() const { return theFlags; }
+ Operation *getOperation() const { return op; }
private:
Operation *op;
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 1fcb316750230..53d461df98710 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -379,6 +379,8 @@ class RegionRange
friend RangeBaseT;
};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region ®ion);
+
} // namespace mlir
#endif // MLIR_IR_REGION_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index d63800c12d132..47afd252c6d68 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -15,10 +15,16 @@
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir {
class BranchOpInterface;
class RegionBranchOpInterface;
+class RegionBranchTerminatorOpInterface;
/// This class models how operands are forwarded to block arguments in control
/// flow. It consists of a number, denoting how many of the successors block
@@ -186,27 +192,40 @@ class RegionSuccessor {
public:
/// Initialize a successor that branches to another region of the parent
/// operation.
+ /// TODO: the default value for the regionInputs is somehow broken.
+ /// A region successor should have its input correctly set.
RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
- : region(region), inputs(regionInputs) {}
+ : successor(region), inputs(regionInputs) {
+ assert(region && "Region must not be null");
+ }
/// Initialize a successor that branches back to/out of the parent operation.
- RegionSuccessor(Operation::result_range results)
- : inputs(ValueRange(results)) {}
- /// Constructor with no arguments.
- RegionSuccessor() : inputs(ValueRange()) {}
+ /// The target must be one of the recursive parent operations.
+ RegionSuccessor(Operation *successorOp, Operation::result_range results)
+ : successor(successorOp), inputs(ValueRange(results)) {
+ assert(successorOp && "Successor op must not be null");
+ }
/// Return the given region successor. Returns nullptr if the successor is the
/// parent operation.
- Region *getSuccessor() const { return region; }
+ Region *getSuccessor() const { return dyn_cast<Region *>(successor); }
/// Return true if the successor is the parent operation.
- bool isParent() const { return region == nullptr; }
+ bool isParent() const { return isa<Operation *>(successor); }
/// Return the inputs to the successor that are remapped by the exit values of
/// the current region.
ValueRange getSuccessorInputs() const { return inputs; }
+ bool operator==(RegionSuccessor rhs) const {
+ return successor == rhs.successor && inputs == rhs.inputs;
+ }
+
+ friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {
+ return !(lhs == rhs);
+ }
+
private:
- Region *region{nullptr};
+ llvm::PointerUnion<Region *, Operation *> successor{nullptr};
ValueRange inputs;
};
@@ -214,64 +233,67 @@ class RegionSuccessor {
/// `RegionBranchOpInterface`.
/// One can branch from one of two kinds of places:
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
-/// * A region within the parent operation.
+/// * A RegionBranchTerminatorOpInterface inside a region within the parent
+// operation.
class RegionBranchPoint {
public:
/// Returns an instance of `RegionBranchPoint` representing the parent
/// operation.
static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
- /// Creates a `RegionBranchPoint` that branches from the given region.
- /// The pointer must not be null.
- RegionBranchPoint(Region *region) : maybeRegion(region) {
- assert(region && "Region must not be null");
- }
-
- RegionBranchPoint(Region ®ion) : RegionBranchPoint(®ion) {}
+ /// Creates a `RegionBranchPoint` that branches from the given terminator.
+ inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);
/// Explicitly stops users from constructing with `nullptr`.
RegionBranchPoint(std::nullptr_t) = delete;
- /// Constructs a `RegionBranchPoint` from the the target of a
- /// `RegionSuccessor` instance.
- RegionBranchPoint(RegionSuccessor successor) {
- if (successor.isParent())
- maybeRegion = nullptr;
- else
- maybeRegion = successor.getSuccessor();
- }
-
- /// Assigns a region being branched from.
- RegionBranchPoint &operator=(Region ®ion) {
- maybeRegion = ®ion;
- return *this;
- }
-
/// Returns true if branching from the parent op.
- bool isParent() const { return maybeRegion == nullptr; }
+ bool isParent() const { return predecessor == nullptr; }
- /// Returns the region if branching from a region.
+ /// Returns the terminator if branching from a region.
/// A null pointer otherwise.
- Region *getRegionOrNull() const { return maybeRegion; }
+ Operation *getTerminatorPredecessorOrNull() const { return predecessor; }
/// Returns true if the two branch points are equal.
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
- return lhs.maybeRegion == rhs.maybeRegion;
+ return lhs.predecessor == rhs.predecessor;
}
private:
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
- constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
+ constexpr RegionBranchPoint() = default;
/// Internal encoding. Uses nullptr for representing branching from the parent
- /// op and the region being branched from otherwise.
- Region *maybeRegion;
+ /// op and the region terminator being branched from otherwise.
+ Operation *predecessor = nullptr;
};
inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return !(lhs == rhs);
}
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ RegionBranchPoint point) {
+ if (point.isParent())
+ return os << "<from parent>";
+ return os << "<region #"
+ << point.getTerminatorPredecessorOrNull()
+ ->getParentRegion()
+ ->getRegionNumber()
+ << ", terminator "
+ << OpWithFlags(point.getTerminatorPredecessorOrNull(),
+ OpPrintingFlags().skipRegions())
+ << ">";
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ RegionSuccessor successor) {
+ if (successor.isParent())
+ return os << "<to parent>";
+ return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
+ << " with " << successor.getSuccessorInputs().size() << " inputs>";
+}
+
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
@@ -348,4 +370,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
/// Include the generated interface declarations.
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
+namespace mlir {
+inline RegionBranchPoint::RegionBranchPoint(
+ RegionBranchTerminatorOpInterface predecessor)
+ : predecessor(predecessor.getOperation()) {}
+} // namespace mlir
+
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index b8d08cc553caa..94242e3ba39ce 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -117,7 +117,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let description = [{
- This interface provides information for region operations that exhibit
+ This interface provides information for region-holding operations that exhibit
branching behavior between held regions. I.e., this interface allows for
expressing control flow information for region holding operations.
@@ -126,12 +126,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
be side-effect free.
A "region branch point" indicates a point from which a branch originates. It
- can indicate either a region of this op or `RegionBranchPoint::parent()`. In
- the latter case, the branch originates from outside of the op, i.e., when
- first executing this op.
+ can indicate either a terminator in any of the immediately nested region of
+ this op or `RegionBranchPoint::parent()`. In the latter case, the branch
+ originates from outside of the op, i.e., when first executing this op.
A "region successor" indicates the target of a branch. It can indicate
- either a region of this op or this op. In the former case, the region
+ either a region of this op or this op itself. In the former case, the region
successor is a region pointer and a range of block arguments to which the
"successor operands" are forwarded to. In the latter case, the control flow
leaves this op and the region successor is a range of results of this op to
@@ -151,10 +151,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
}
```
- `scf.for` has one region. The region has two region successors: the region
- itself and the `scf.for` op. %b is an entry successor operand. %c is a
- successor operand. %a is a successor block argument. %r is a successor
- result.
+ `scf.for` has one region. The `scf.yield` has two region successors: the
+ region body itself and the `scf.for` op. `%b` is an entry successor
+ operand. `%c` is a successor operand. `%a` is a successor block argument.
+ `%r` is a successor result.
}];
let cppNamespace = "::mlir";
@@ -162,16 +162,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
InterfaceMethod<[{
Returns the operands of this operation that are forwarded to the region
successor's block arguments or this operation's results when branching
- to `point`. `point` is guaranteed to be among the successors that are
+ to `successor`. `successor` is guaranteed to be among the successors that are
returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
Example: In the above example, this method returns the operand %b of the
- `scf.for` op, regardless of the value of `point`. I.e., this op always
+ `scf.for` op, regardless of the value of `successor`. I.e., this op always
forwards the same operands, regardless of whether the loop has 0 or more
iterations.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
- (ins "::mlir::RegionBranchPoint":$point), [{}],
+ (ins "::mlir::RegionSuccessor":$successor), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -224,6 +224,80 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::mlir::RegionBranchPoint":$point,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
+ InterfaceMethod<[{
+ Returns the potential region successors when branching from any
+ terminator in `region`.
+ These are the regions that may be selected during the flow of control.
+ }],
+ "void", "getSuccessorRegions",
+ (ins "::mlir::Region&":$region,
+ "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+ [{}],
+ /*defaultImplementation=*/[{
+ for (::mlir::Block &block : region) {
+ if (block.empty())
+ continue;
+ if (auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+ $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+ regions);
+ }
+ }]>,
+ InterfaceMethod<[{
+ Returns the potential branching point (predecessors) for a given successor.
+ }],
+ "void", "getPredecessors",
+ (ins "::mlir::RegionSuccessor":$successor,
+ "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
...
[truncated]
|
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Mehdi Amini (joker-eph) ChangesThis is still somehow a WIP, we have some issues with this interface that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and aligned with their definition:
Some new methods with reasonnable default implementation are added to help resolving the flow of values across the RegionBranchOpInterface. It is still not trivial in the current state to walk the def-use chain backward with this interface. For example when you have the 3rd block argument in the entry block of a for-loop, finding the matching operands requires to know about the hidden loop iterator block argument and where the iterargs start. The API is designed around forward-tracking of the chain unfortunately. Try to reland #161575 when the build is fixed. Patch is 116.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165429.diff 38 Files Affected:
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index d0164f32d9b6a..4f97acaa88b7a 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4484,7 +4484,7 @@ void fir::IfOp::getSuccessorRegions(
llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) {
// The `then` and the `else` region branch back to the parent operation.
if (!point.isParent()) {
- regions.push_back(mlir::RegionSuccessor(getResults()));
+ regions.push_back(mlir::RegionSuccessor(getOperation(), getResults()));
return;
}
@@ -4494,7 +4494,8 @@ void fir::IfOp::getSuccessorRegions(
// Don't consider the else region if it is empty.
mlir::Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(mlir::RegionSuccessor());
+ regions.push_back(
+ mlir::RegionSuccessor(getOperation(), getOperation()->getResults()));
else
regions.push_back(mlir::RegionSuccessor(elseRegion));
}
@@ -4513,7 +4514,7 @@ void fir::IfOp::getEntrySuccessorRegions(
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
}
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 8bcfe51ad7cd1..3c87c453a4cf0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -397,7 +397,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+ RegionSuccessor regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
@@ -526,7 +526,7 @@ class DenseBackwardDataFlowAnalysis
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
+ RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
@@ -571,7 +571,7 @@ class DenseBackwardDataFlowAnalysis
}
void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
- RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+ RegionSuccessor regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 1a33ecf8b5aa9..985573476ab78 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -286,7 +286,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// and propagating therefrom.
virtual void
visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch,
- RegionBranchPoint successor,
+ RegionSuccessor successor,
ArrayRef<AbstractSparseLattice *> lattices);
};
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 66174ce0f7928..cd033c140a233 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -644,6 +644,13 @@ def ForallOp : SCF_Op<"forall", [
/// Returns true if the mapping specified for this forall op is linear.
bool usesLinearMapping();
+
+ /// RegionBranchOpInterface
+
+ OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
+ return getInits();
+ }
+
}];
}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 62e66b3dabee8..ed69287410509 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -25,7 +25,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
def AlternativesOp : TransformDialectOp<"alternatives",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -624,7 +624,7 @@ def ForeachOp : TransformDialectOp<"foreach",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
- "getSuccessorRegions", "getEntrySuccessorOperands"]>,
+ "getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
]> {
let summary = "Executes the body for each element of the payload";
@@ -1237,7 +1237,7 @@ def SelectOp : TransformDialectOp<"select",
def SequenceOp : TransformDialectOp<"sequence",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
index d095659fc4838..4079848fd203a 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -63,7 +63,7 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 7ff718ad7f241..a0a99f4953822 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -29,6 +29,7 @@ class MLIRContext;
class Operation;
class OperationName;
class OpPrintingFlags;
+class OpWithFlags;
class Type;
class Value;
@@ -199,6 +200,7 @@ class Diagnostic {
/// Stream in an Operation.
Diagnostic &operator<<(Operation &op);
+ Diagnostic &operator<<(OpWithFlags op);
Diagnostic &operator<<(Operation *op) { return *this << *op; }
/// Append an operation with the given printing flags.
Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 5569392cf0b41..b2019574a820d 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -1114,6 +1114,7 @@ class OpWithFlags {
: op(op), theFlags(flags) {}
OpPrintingFlags &flags() { return theFlags; }
const OpPrintingFlags &flags() const { return theFlags; }
+ Operation *getOperation() const { return op; }
private:
Operation *op;
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 1fcb316750230..53d461df98710 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -379,6 +379,8 @@ class RegionRange
friend RangeBaseT;
};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region ®ion);
+
} // namespace mlir
#endif // MLIR_IR_REGION_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index d63800c12d132..47afd252c6d68 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -15,10 +15,16 @@
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir {
class BranchOpInterface;
class RegionBranchOpInterface;
+class RegionBranchTerminatorOpInterface;
/// This class models how operands are forwarded to block arguments in control
/// flow. It consists of a number, denoting how many of the successors block
@@ -186,27 +192,40 @@ class RegionSuccessor {
public:
/// Initialize a successor that branches to another region of the parent
/// operation.
+ /// TODO: the default value for the regionInputs is somehow broken.
+ /// A region successor should have its input correctly set.
RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
- : region(region), inputs(regionInputs) {}
+ : successor(region), inputs(regionInputs) {
+ assert(region && "Region must not be null");
+ }
/// Initialize a successor that branches back to/out of the parent operation.
- RegionSuccessor(Operation::result_range results)
- : inputs(ValueRange(results)) {}
- /// Constructor with no arguments.
- RegionSuccessor() : inputs(ValueRange()) {}
+ /// The target must be one of the recursive parent operations.
+ RegionSuccessor(Operation *successorOp, Operation::result_range results)
+ : successor(successorOp), inputs(ValueRange(results)) {
+ assert(successorOp && "Successor op must not be null");
+ }
/// Return the given region successor. Returns nullptr if the successor is the
/// parent operation.
- Region *getSuccessor() const { return region; }
+ Region *getSuccessor() const { return dyn_cast<Region *>(successor); }
/// Return true if the successor is the parent operation.
- bool isParent() const { return region == nullptr; }
+ bool isParent() const { return isa<Operation *>(successor); }
/// Return the inputs to the successor that are remapped by the exit values of
/// the current region.
ValueRange getSuccessorInputs() const { return inputs; }
+ bool operator==(RegionSuccessor rhs) const {
+ return successor == rhs.successor && inputs == rhs.inputs;
+ }
+
+ friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {
+ return !(lhs == rhs);
+ }
+
private:
- Region *region{nullptr};
+ llvm::PointerUnion<Region *, Operation *> successor{nullptr};
ValueRange inputs;
};
@@ -214,64 +233,67 @@ class RegionSuccessor {
/// `RegionBranchOpInterface`.
/// One can branch from one of two kinds of places:
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
-/// * A region within the parent operation.
+/// * A RegionBranchTerminatorOpInterface inside a region within the parent
+// operation.
class RegionBranchPoint {
public:
/// Returns an instance of `RegionBranchPoint` representing the parent
/// operation.
static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
- /// Creates a `RegionBranchPoint` that branches from the given region.
- /// The pointer must not be null.
- RegionBranchPoint(Region *region) : maybeRegion(region) {
- assert(region && "Region must not be null");
- }
-
- RegionBranchPoint(Region ®ion) : RegionBranchPoint(®ion) {}
+ /// Creates a `RegionBranchPoint` that branches from the given terminator.
+ inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);
/// Explicitly stops users from constructing with `nullptr`.
RegionBranchPoint(std::nullptr_t) = delete;
- /// Constructs a `RegionBranchPoint` from the the target of a
- /// `RegionSuccessor` instance.
- RegionBranchPoint(RegionSuccessor successor) {
- if (successor.isParent())
- maybeRegion = nullptr;
- else
- maybeRegion = successor.getSuccessor();
- }
-
- /// Assigns a region being branched from.
- RegionBranchPoint &operator=(Region ®ion) {
- maybeRegion = ®ion;
- return *this;
- }
-
/// Returns true if branching from the parent op.
- bool isParent() const { return maybeRegion == nullptr; }
+ bool isParent() const { return predecessor == nullptr; }
- /// Returns the region if branching from a region.
+ /// Returns the terminator if branching from a region.
/// A null pointer otherwise.
- Region *getRegionOrNull() const { return maybeRegion; }
+ Operation *getTerminatorPredecessorOrNull() const { return predecessor; }
/// Returns true if the two branch points are equal.
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
- return lhs.maybeRegion == rhs.maybeRegion;
+ return lhs.predecessor == rhs.predecessor;
}
private:
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
- constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
+ constexpr RegionBranchPoint() = default;
/// Internal encoding. Uses nullptr for representing branching from the parent
- /// op and the region being branched from otherwise.
- Region *maybeRegion;
+ /// op and the region terminator being branched from otherwise.
+ Operation *predecessor = nullptr;
};
inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return !(lhs == rhs);
}
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ RegionBranchPoint point) {
+ if (point.isParent())
+ return os << "<from parent>";
+ return os << "<region #"
+ << point.getTerminatorPredecessorOrNull()
+ ->getParentRegion()
+ ->getRegionNumber()
+ << ", terminator "
+ << OpWithFlags(point.getTerminatorPredecessorOrNull(),
+ OpPrintingFlags().skipRegions())
+ << ">";
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ RegionSuccessor successor) {
+ if (successor.isParent())
+ return os << "<to parent>";
+ return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
+ << " with " << successor.getSuccessorInputs().size() << " inputs>";
+}
+
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
@@ -348,4 +370,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
/// Include the generated interface declarations.
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
+namespace mlir {
+inline RegionBranchPoint::RegionBranchPoint(
+ RegionBranchTerminatorOpInterface predecessor)
+ : predecessor(predecessor.getOperation()) {}
+} // namespace mlir
+
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index b8d08cc553caa..94242e3ba39ce 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -117,7 +117,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let description = [{
- This interface provides information for region operations that exhibit
+ This interface provides information for region-holding operations that exhibit
branching behavior between held regions. I.e., this interface allows for
expressing control flow information for region holding operations.
@@ -126,12 +126,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
be side-effect free.
A "region branch point" indicates a point from which a branch originates. It
- can indicate either a region of this op or `RegionBranchPoint::parent()`. In
- the latter case, the branch originates from outside of the op, i.e., when
- first executing this op.
+ can indicate either a terminator in any of the immediately nested region of
+ this op or `RegionBranchPoint::parent()`. In the latter case, the branch
+ originates from outside of the op, i.e., when first executing this op.
A "region successor" indicates the target of a branch. It can indicate
- either a region of this op or this op. In the former case, the region
+ either a region of this op or this op itself. In the former case, the region
successor is a region pointer and a range of block arguments to which the
"successor operands" are forwarded to. In the latter case, the control flow
leaves this op and the region successor is a range of results of this op to
@@ -151,10 +151,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
}
```
- `scf.for` has one region. The region has two region successors: the region
- itself and the `scf.for` op. %b is an entry successor operand. %c is a
- successor operand. %a is a successor block argument. %r is a successor
- result.
+ `scf.for` has one region. The `scf.yield` has two region successors: the
+ region body itself and the `scf.for` op. `%b` is an entry successor
+ operand. `%c` is a successor operand. `%a` is a successor block argument.
+ `%r` is a successor result.
}];
let cppNamespace = "::mlir";
@@ -162,16 +162,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
InterfaceMethod<[{
Returns the operands of this operation that are forwarded to the region
successor's block arguments or this operation's results when branching
- to `point`. `point` is guaranteed to be among the successors that are
+ to `successor`. `successor` is guaranteed to be among the successors that are
returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
Example: In the above example, this method returns the operand %b of the
- `scf.for` op, regardless of the value of `point`. I.e., this op always
+ `scf.for` op, regardless of the value of `successor`. I.e., this op always
forwards the same operands, regardless of whether the loop has 0 or more
iterations.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
- (ins "::mlir::RegionBranchPoint":$point), [{}],
+ (ins "::mlir::RegionSuccessor":$successor), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -224,6 +224,80 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::mlir::RegionBranchPoint":$point,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
+ InterfaceMethod<[{
+ Returns the potential region successors when branching from any
+ terminator in `region`.
+ These are the regions that may be selected during the flow of control.
+ }],
+ "void", "getSuccessorRegions",
+ (ins "::mlir::Region&":$region,
+ "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+ [{}],
+ /*defaultImplementation=*/[{
+ for (::mlir::Block &block : region) {
+ if (block.empty())
+ continue;
+ if (auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+ $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+ regions);
+ }
+ }]>,
+ InterfaceMethod<[{
+ Returns the potential branching point (predecessors) for a given successor.
+ }],
+ "void", "getPredecessors",
+ (ins "::mlir::RegionSuccessor":$successor,
+ "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
...
[truncated]
|
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/207/builds/9130 Here is the relevant piece of the build log for the reference |
) Fix building ClangIR after RegionBranchOpInterface revamp (#165429)
This is still somehow a WIP, we have some issues with this interface that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and aligned with their definition: - A `RegionBranchPoint` is either the parent (`RegionBranchOpInterface`) op or a `RegionBranchTerminatorOpInterface` operation in a nested region. - A `RegionSuccessor` is either one of the nested region or the parent `RegionBranchOpInterface` Some new methods with reasonnable default implementation are added to help resolving the flow of values across the RegionBranchOpInterface. It is still not trivial in the current state to walk the def-use chain backward with this interface. For example when you have the 3rd block argument in the entry block of a for-loop, finding the matching operands requires to know about the hidden loop iterator block argument and where the iterargs start. The API is designed around forward-tracking of the chain unfortunately. Try to reland llvm#161575 ; I suspect a buildbot incremental build issue.
This change removes IREE’s local LDBG macro definition in favor of the LLVM-provided version introduced in llvm/llvm-project#143704. This update is also in preparation for llvm/llvm-project#165429, which would otherwise cause a macro redefinition error due to the existing local LDBG definition. Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
This is still somehow a WIP, we have some issues with this interface that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and aligned with their definition:
RegionBranchPointis either the parent (RegionBranchOpInterface) op or aRegionBranchTerminatorOpInterfaceoperation in a nested region.RegionSuccessoris either one of the nested region or the parentRegionBranchOpInterfaceSome new methods with reasonnable default implementation are added to help resolving the flow of values across the RegionBranchOpInterface.
It is still not trivial in the current state to walk the def-use chain backward with this interface. For example when you have the 3rd block argument in the entry block of a for-loop, finding the matching operands requires to know about the hidden loop iterator block argument and where the iterargs start. The API is designed around forward-tracking of the chain unfortunately.
Try to reland #161575 when the build is fixed.