-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][dataflow] Add visitBranchRegionArgument interface to SparseBackwardDataFlowAnalysis and apply it in LivenessAnalysis/RemoveDeadValues #169816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…nalysis and apply it to LivenessAnalysis.
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: lonely eagle (linuxlonelyeagle) ChangesAdd visitBranchRegionArgument interface to SparseBackwardAataflowBackwardAnalysis, because the current SparseBackwardAataflowBackwardAnalysis cannot access all SSA values, such as, the loop's IV. Now we can use isitBranchRegionArgument to visit it. Apply it in LivenessAnalysis/RemoveDeadValues, solved the issue of IV liveness in the loop.,please refer to the tests added in the PR. Full diff: https://github.com/llvm/llvm-project/pull/169816.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..80d63ad5715ac 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -87,6 +87,8 @@ class LivenessAnalysis : public SparseBackwardDataFlowAnalysis<Liveness> {
void visitCallOperand(OpOperand &operand) override;
void setToExitState(Liveness *lattice) override;
+
+ void visitBranchRegionArgument(BlockArgument &argument) override;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 360d3c7e62000..097da72fb6bb3 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -431,6 +431,8 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
// Visit operands on branch instructions that are not forwarded.
virtual void visitBranchOperand(OpOperand &operand) = 0;
+ virtual void visitBranchRegionArgument(BlockArgument &argument) = 0;
+
// Visit operands on call instructions that are not forwarded.
virtual void visitCallOperand(OpOperand &operand) = 0;
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 20be50c8e8a5b..70ee411b03c99 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -137,7 +137,6 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
// Populating such blocks in `blocks`.
bool mayLive = false;
SmallVector<Block *, 4> blocks;
- SmallVector<BlockArgument> argumentNotOperand;
if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
if (op->getNumResults() != 0) {
// This mark value of type 1.c liveness as may live, because the region
@@ -166,25 +165,6 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
blocks.push_back(&block);
}
}
-
- // In the block of the successor block argument of RegionBranchOpInterface,
- // there may be arguments of RegionBranchOpInterface, such as the IV of
- // scf.forOp. Explicitly set this argument to live.
- for (Region ®ion : op->getRegions()) {
- SmallVector<RegionSuccessor> successors;
- regionBranchOp.getSuccessorRegions(region, successors);
- for (RegionSuccessor successor : successors) {
- if (successor.isParent())
- continue;
- auto arguments = successor.getSuccessor()->getArguments();
- ValueRange regionInputs = successor.getSuccessorInputs();
- for (auto argument : arguments) {
- if (llvm::find(regionInputs, argument) == regionInputs.end()) {
- argumentNotOperand.push_back(argument);
- }
- }
- }
- }
} else if (isa<BranchOpInterface>(op)) {
// We cannot track all successor blocks of the branch operation(More
// specifically, it's the successor's successor). Additionally, different
@@ -244,24 +224,12 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
Liveness *operandLiveness = getLatticeElement(operand.get());
LDBG() << "Marking branch operand live: " << operand.get();
propagateIfChanged(operandLiveness, operandLiveness->markLive());
- for (BlockArgument argument : argumentNotOperand) {
- Liveness *argumentLiveness = getLatticeElement(argument);
- LDBG() << "Marking RegionBranchOp's argument live: " << argument;
- // TODO: this is overly conservative: we should be able to eliminate
- // unused values in a RegionBranchOpInterface operation but that may
- // requires removing operation results which is beyond current
- // capabilities of this pass right now.
- propagateIfChanged(argumentLiveness, argumentLiveness->markLive());
- }
}
// Now that we have checked for memory-effecting ops in the blocks of concern,
// we will simply visit the op with this non-forwarded operand to potentially
// mark it "live" due to type (1.a/3) liveness.
SmallVector<Liveness *, 4> operandLiveness;
- operandLiveness.push_back(getLatticeElement(operand.get()));
- for (BlockArgument argument : argumentNotOperand)
- operandLiveness.push_back(getLatticeElement(argument));
SmallVector<const Liveness *, 4> resultsLiveness;
for (const Value result : op->getResults())
resultsLiveness.push_back(getLatticeElement(result));
@@ -303,6 +271,26 @@ void LivenessAnalysis::visitCallOperand(OpOperand &operand) {
propagateIfChanged(operandLiveness, operandLiveness->markLive());
}
+void LivenessAnalysis::visitBranchRegionArgument(BlockArgument &blockArgument) {
+ Operation *parentOp = blockArgument.getOwner()->getParentOp();
+ LDBG() << "Visiting branch region argument: " << blockArgument
+ << "in op: " << OpWithFlags(parentOp, OpPrintingFlags().skipRegions());
+ Liveness *argumentLiveness = getLatticeElement(blockArgument);
+ SmallVector<Liveness *> parentResultsLiveness;
+ for (Value result : parentOp->getResults())
+ parentResultsLiveness.push_back(getLatticeElement(result));
+
+ for (Liveness *resultLattice : parentResultsLiveness) {
+ if (resultLattice->isLive) {
+ LDBG() << "Marking branch argument live: " << blockArgument;
+ propagateIfChanged(argumentLiveness, argumentLiveness->markLive());
+ return;
+ }
+ }
+ (void)visitOperation(parentOp, ArrayRef<Liveness *>{argumentLiveness},
+ parentResultsLiveness);
+}
+
void LivenessAnalysis::setToExitState(Liveness *lattice) {
LDBG() << "setToExitState for lattice: " << lattice;
if (lattice->isLive) {
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 8e63ae86753b4..d442135363392 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -599,7 +599,7 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
// All operands not forwarded to any successor. This set can be non-contiguous
// in the presence of multiple successors.
BitVector unaccounted(op->getNumOperands(), true);
-
+ SmallVector<BlockArgument> regionArguments;
for (RegionSuccessor &successor : successors) {
OperandRange operands = branch.getEntrySuccessorOperands(successor);
MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
@@ -609,12 +609,24 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
*getLatticeElementFor(getProgramPointAfter(op), input));
unaccounted.reset(operand.getOperandNumber());
}
+
+ if (successor.isParent())
+ continue;
+ auto arguments = successor.getSuccessor()->getArguments();
+ for (BlockArgument argument : arguments) {
+ if (llvm::find(inputs, argument) == inputs.end()) {
+ regionArguments.push_back(argument);
+ }
+ }
}
// All operands not forwarded to regions are typically parameters of the
// branch operation itself (for example the boolean for if/else).
for (int index : unaccounted.set_bits()) {
visitBranchOperand(op->getOpOperand(index));
}
+ for (BlockArgument argument : regionArguments) {
+ visitBranchRegionArgument(argument);
+ }
}
void AbstractSparseBackwardDataFlowAnalysis::
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index f2b0e71c9397f..5287ec03055dd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -405,6 +405,8 @@ class LayoutInfoPropagation
void visitCallOperand(OpOperand &operand) override {};
+ void visitBranchRegionArgument(BlockArgument &argument) override {};
+
void visitExternalCall(CallOpInterface call,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) override {
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 4bae85dcf4f7d..2597a35c898b0 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -687,3 +687,21 @@ func.func @op_block_have_dead_arg(%arg0: index, %arg1: index, %arg2: i1) {
// CHECK-NEXT: return
return
}
+
+// -----
+
+
+// CHECK-LABEL: func @affine_loop_no_use_iv_has_side_effect_op
+func.func @affine_loop_no_use_iv_has_side_effect_op() {
+ %c1 = arith.constant 1 : index
+ %alloc = memref.alloc() : memref<10xindex>
+ affine.for %arg0 = 0 to 79 {
+ memref.store %c1, %alloc[%c1] : memref<10xindex>
+ }
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10xindex>
+// CHECK: affine.for %[[VAL_0:.*]] = 0 to 79 {
+// CHECK: memref.store %[[C1]], %[[ALLOC]]{{\[}}%[[C1]]] : memref<10xindex>
+// CHECK: }
+ return
+}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index 0bdb7c25c3b5f..c7c2e68e9d95d 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -82,6 +82,8 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
void visitCallOperand(OpOperand &operand) override;
+ void visitBranchRegionArgument(BlockArgument &argument) override {}
+
void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
ArrayRef<const WrittenTo *> results) override;
|
matthias-springer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not familiar enough with the internals of the data flow analysis framework to give this a thorough review. Can someone else chime in?
| visitBranchOperand(op->getOpOperand(index)); | ||
| } | ||
| for (BlockArgument argument : regionArguments) { | ||
| visitBranchRegionArgument(argument); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like we are calling visitBranchRegionArgument only for block arguments that not "control flow block arguments". E.g., for an scf.for, this would be called for the IV but not for the loop-carried variables. Is that correct? And why is that the case?
Assuming that's correct, should this function maybe be called visitNonControlFlowArguments, similar to the SparseForwardDataFlowAnalysis?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like we are calling visitBranchRegionArgument only for block arguments that not "control flow block arguments". E.g., for an scf.for, this would be called for the IV but not for the loop-carried variables. Is that correct? And why is that the case?
Yes, as above I say, IV is slightly different for inits/regionIters to me, as it is a property. I think we need a way to access this SSA value.
Assuming that's correct, should this function maybe be called visitNonControlFlowArguments, similar to the SparseForwardDataFlowAnalysis?
Yes. But I wanted to see more opinions from everyone, so I didn't have time to make changes to everything.
|
@joker-eph @ftynse Could you help review the code? Thank you. |
Add visitBranchRegionArgument interface to SparseBackwardDataFlowAnalysis, because the current SparseBackwardAataflowBackwardAnalysis cannot access all SSA values, such as, the loop's IV. Now we can use isitBranchRegionArgument to visit it. Apply it in LivenessAnalysis/RemoveDeadValues, solved the issue of IV liveness in the loop.,please refer to the tests added in the PR.