[MLIR][Vector] Fix WarpOpScfForOp and WarpOpScfIfOp leaving invalid ops after region moves#188951
Conversation
|
@llvm/pr-subscribers-mlir-vector Author: Mehdi Amini (joker-eph) ChangesWarpOpScfForOp::matchAndRewrite called mergeBlocks() to move forOp's body block into the inner WarpOp. mergeBlocks() erases the source block, leaving forOp with an empty body region (0 blocks). Since scf.for requires exactly 1 body block, IR verification fails with "region with 1 blocks" after the pattern succeeds. Additionally, when forOp had no init args, the pattern was missing the scf.yield terminator in the new ForOp. WarpOpScfIfOp::matchAndRewrite had the same issue: takeBody() emptied the ifOp's then/else regions, leaving scf.if with 0 blocks. Fix:
Assisted-by: Claude Code Full diff: https://github.com/llvm/llvm-project/pull/188951.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b4d500212c770..31d875e3a67de 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
@@ -2000,6 +2001,21 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
for (auto [origIdx, newIdx] : ifResultMapping)
rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newIfOp.getResult(newIdx), newIfOp);
+
+ // The original `ifOp` was left inside `newWarpOp` with empty then/else
+ // regions (their blocks were moved into the inner WarpOps by takeBody).
+ // Replace its results with poison and erase it to restore IR validity.
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(ifOp);
+ for (OpResult result : ifOp.getResults()) {
+ Value poison =
+ ub::PoisonOp::create(rewriter, ifOp.getLoc(), result.getType());
+ rewriter.replaceAllUsesWith(result, poison);
+ }
+ rewriter.eraseOp(ifOp);
+ }
+
return success();
}
@@ -2215,6 +2231,21 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
for (auto [origIdx, newIdx] : forResultMapping)
rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
+
+ // The original `ForOp` was left inside `newWarpOp` with an empty body
+ // region (its body block was moved into `innerWarp` by `mergeBlocks`).
+ // Replace its results with poison and erase it to restore IR validity.
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(forOp);
+ for (OpResult result : forOp.getResults()) {
+ Value poison =
+ ub::PoisonOp::create(rewriter, forOp.getLoc(), result.getType());
+ rewriter.replaceAllUsesWith(result, poison);
+ }
+ rewriter.eraseOp(forOp);
+ }
+
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
|
|
@llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) ChangesWarpOpScfForOp::matchAndRewrite called mergeBlocks() to move forOp's body block into the inner WarpOp. mergeBlocks() erases the source block, leaving forOp with an empty body region (0 blocks). Since scf.for requires exactly 1 body block, IR verification fails with "region with 1 blocks" after the pattern succeeds. Additionally, when forOp had no init args, the pattern was missing the scf.yield terminator in the new ForOp. WarpOpScfIfOp::matchAndRewrite had the same issue: takeBody() emptied the ifOp's then/else regions, leaving scf.if with 0 blocks. Fix:
Assisted-by: Claude Code Full diff: https://github.com/llvm/llvm-project/pull/188951.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b4d500212c770..31d875e3a67de 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
@@ -2000,6 +2001,21 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
for (auto [origIdx, newIdx] : ifResultMapping)
rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newIfOp.getResult(newIdx), newIfOp);
+
+ // The original `ifOp` was left inside `newWarpOp` with empty then/else
+ // regions (their blocks were moved into the inner WarpOps by takeBody).
+ // Replace its results with poison and erase it to restore IR validity.
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(ifOp);
+ for (OpResult result : ifOp.getResults()) {
+ Value poison =
+ ub::PoisonOp::create(rewriter, ifOp.getLoc(), result.getType());
+ rewriter.replaceAllUsesWith(result, poison);
+ }
+ rewriter.eraseOp(ifOp);
+ }
+
return success();
}
@@ -2215,6 +2231,21 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
for (auto [origIdx, newIdx] : forResultMapping)
rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
+
+ // The original `ForOp` was left inside `newWarpOp` with an empty body
+ // region (its body block was moved into `innerWarp` by `mergeBlocks`).
+ // Replace its results with poison and erase it to restore IR validity.
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(forOp);
+ for (OpResult result : forOp.getResults()) {
+ Value poison =
+ ub::PoisonOp::create(rewriter, forOp.getLoc(), result.getType());
+ rewriter.replaceAllUsesWith(result, poison);
+ }
+ rewriter.eraseOp(forOp);
+ }
+
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
|
…ps after region moves WarpOpScfForOp::matchAndRewrite called mergeBlocks() to move forOp's body block into the inner WarpOp. mergeBlocks() erases the source block, leaving forOp with an empty body region (0 blocks). Since scf.for requires exactly 1 body block, IR verification fails with "region with 1 blocks" after the pattern succeeds. Additionally, when forOp had no init args, the pattern was missing the scf.yield terminator in the new ForOp. WarpOpScfIfOp::matchAndRewrite had the same issue: takeBody() emptied the ifOp's then/else regions, leaving scf.if with 0 blocks. Fix: - Restore the conditional scf.yield creation (only when newForOp has results). - After merging/taking the regions, replace the remaining op's results with ub.poison and erase the now-invalid op from the new WarpOp's body. Assisted-by: Claude Code Fix a failure present with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.
5f0ca70 to
d6f25e1
Compare
|
Ping @matthias-springer , @banach-space @dcaballe ? |
|
LGTM but I think @Groverkss is the right one to approve this change |
matthias-springer
left a comment
There was a problem hiding this comment.
Vector warp distribution started as an experiment for GPU codegen in IREE. I believe this approach did not turn out to be flexible enough and was abandoned. Is anyone using this piece of code? Maybe it can be deleted entirely.
|
Excellent question @matthias-springer , let's poke @akroviakov and @charithaintc @AGindinson who touched on this recently. Is this useful or is it a candidate for deletion? |
|
We (@Jianhui-Li @charithaintc) have no objection to its removal. |
|
@matthias-springer @joker-eph As we've discussed with @Groverkss in #183830, dropping this entirely would be the right long-term direction. However, I would like to address iree-org/iree#23624 first, because otherwise the IREE SPIR-V pipeline is reliant on the patterns, and replacing them hastily would cause end-to-end regressions. I would see to it to drop the warp distribution patterns afterwards, if this is fine with everyone. |
|
We recently (last year?) moved away from upstream Vector warp distribution in IREE for some backends. We need to replace it for SPIR-V as pointed out by @AGindinson . I think I have seen the XeGPU folks working on this part now. I'm happy to review this though, but I agree eventually this path should be dropped. |
| for (auto [origIdx, newIdx] : ifResultMapping) | ||
| rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), | ||
| newIfOp.getResult(newIdx), newIfOp); | ||
|
|
||
| // The original `ifOp` was left inside `newWarpOp` with empty then/else | ||
| // regions (their blocks were moved into the inner WarpOps by takeBody). | ||
| // Clear remaining uses and erase it to restore IR validity. Directly | ||
| // update newWarpOp's yield operands instead of using replaceAllUsesWith, | ||
| // to avoid triggering notifyOperandReplaced on the now-invalid ifOp. | ||
| { | ||
| OpBuilder::InsertionGuard guard(rewriter); | ||
| rewriter.setInsertionPoint(ifOp); | ||
| Operation *yield = newWarpOp.getTerminator(); | ||
| rewriter.modifyOpInPlace(yield, [&]() { | ||
| for (auto [origIdx, ifResultIdx] : ifResultMapping) { | ||
| Value poison = ub::PoisonOp::create( | ||
| rewriter, ifOp.getLoc(), ifOp.getResult(ifResultIdx).getType()); | ||
| yield->setOperand(origIdx, poison); | ||
| } | ||
| }); | ||
| rewriter.eraseOp(ifOp); | ||
| } |
There was a problem hiding this comment.
Then bit before already replaces all uses. do we need to replace all these uses? Can we not just erase the op?
There was a problem hiding this comment.
Maybe I'm wrong, but in the previous loop iteration we replace all warp op results with the newIfOp results, and then we are using the same mapping to convert them to poison? I'm a bit confused what is happening here.
There was a problem hiding this comment.
Here is the state of the IR at this point, we go from:
%1 = gpu.warp_execute_on_lane_0(%0)[32] -> (vector<1xf32>) {
%2 = vector.step : vector<32xindex>
%3 = scf.if %arg0 -> (vector<32xf32>) {
%4 = "some_op"(%2) : (vector<32xindex>) -> vector<32xf32>
scf.yield %4 : vector<32xf32>
} else {
%4 = "other_op"(%cst) : (vector<32xindex>) -> vector<32xf32>
scf.yield %4 : vector<32xf32>
}
gpu.yield %3 : vector<32xf32>
}
To:
%3:3 = "gpu.warp_execute_on_lane_0"(%2) <{warp_size = 32 : i64}> ({
%9 = "vector.step"() : () -> vector<32xindex>
%10 = "scf.if"(%arg0) ({
}, {
}) : (i1) -> vector<32xf32>
"gpu.yield"(%10, %arg0, %9) : (vector<32xf32>, i1, vector<32xindex>) -> ()
}) : (index) -> (vector<1xf32>, i1, vector<1xindex>)
| // The original `ForOp` was left inside `newWarpOp` with an empty body | ||
| // region (its body block was moved into `innerWarp` by `mergeBlocks`). | ||
| // Clear remaining uses and erase it to restore IR validity. | ||
| for (OpResult result : forOp.getResults()) { | ||
| if (forResultsMapped.test(result.getResultNumber())) | ||
| rewriter.replaceAllUsesWith( | ||
| result, forOp.getInitArgs()[result.getResultNumber()]); | ||
| } |
There was a problem hiding this comment.
Do we need to clear remaining uses?
There was a problem hiding this comment.
(Same concept as above, let me know if you need the example)
Hi, @joker-eph we have moved away from this approach due to slow compilation times. |
|
Since we can't just yet remove it from the codebase, seems like valuable to land this PR. I trust you all will follow-up with deleting the pass in the future? |
Yes, we have a tracking issue in IREE to remove this., I will try to get to this quick and then send a follow up to delete this in future. |
WarpOpScfForOp::matchAndRewrite called mergeBlocks() to move forOp's body block into the inner WarpOp. mergeBlocks() erases the source block, leaving forOp with an empty body region (0 blocks). Since scf.for requires exactly 1 body block, IR verification fails with "region with 1 blocks" after the pattern succeeds. Additionally, when forOp had no init args, the pattern was missing the scf.yield terminator in the new ForOp.
WarpOpScfIfOp::matchAndRewrite had the same issue: takeBody() emptied the ifOp's then/else regions, leaving scf.if with 0 blocks.
Fix:
Assisted-by: Claude Code
Fix a failure present with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.