Skip to content

Conversation

akroviakov
Copy link
Contributor

@akroviakov akroviakov commented Oct 2, 2025

This PR improves the warp distribution robustness by:

  1. Ensuring that during the warp result deduplication, results with no uses are not mapped to a non-existing index. Currently we map to newResultTypes.size(), but may opt out of inserting to it, leading to a later OOB error.
  2. Simplifying the scf.if and scf.for handling through the usage of moveRegionToNewWarpOpAndAppendReturns, which also performs warp result deduplication in-place. This allows avoiding cases where, for example, after sinking two scf.if that need the same escaping value, a higher-ranked sink-pattern tries to sink the escaping value producer (which is yielded twice at this point) prior to WarpOpDeadResult actually deduplicates the result, leading to sinking the same op twice (once per yield operand).

@akroviakov
Copy link
Contributor Author

@charithaintc @adam-smnk

@llvmbot
Copy link
Member

llvmbot commented Oct 2, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

This PR improves the warp distribution robustness by:

  1. Ensuring that during the warp result deduplication, results with no uses are not mapped to a non-existing index. Currently we map to newResultTypes.size(), but may opt out of inserting to it, leading to a later OOB error.
  2. Simplifying the scf.if and scf.for handling through the usage of moveRegionToNewWarpOpAndAppendReturns, which also performs warp result deduplication in-place. This allows avoiding cases where, for example, after sinking two scf.if that need the same escaping value, a higher-ranked sink-pattern tries to lower the escaping value producer (which is yielded twice at this point) prior to WarpOpDeadResult actually deduplicates the result, leading to sinking the same op twice (once per yield operand).

Full diff: https://github.com/llvm/llvm-project/pull/161647.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+20-39)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+19)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f7d18be..47aa1ca40fb03 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -934,11 +934,13 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
     //   3. skipping from the new result types / new yielded values any result
     //      that has no use or whose yielded value has already been seen.
     for (OpResult result : warpOp.getResults()) {
+      if (result.use_empty())
+        continue;
       Value yieldOperand = yield.getOperand(result.getResultNumber());
       auto it = dedupYieldOperandPositionMap.insert(
           std::make_pair(yieldOperand, newResultTypes.size()));
       dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
-      if (result.use_empty() || !it.second)
+      if (!it.second)
         continue;
       newResultTypes.push_back(result.getType());
       newYieldValues.push_back(yieldOperand);
@@ -1843,16 +1845,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
                               escapingValueDistTypesElse.end());
 
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
     for (auto [idx, val] :
          llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
-      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(val);
       newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
     }
-    // Create the new `WarpOp` with the updated yield values and types.
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    // Replace the old `WarpOp` with the new one that has additional yield
+    // values and types.
+    SmallVector<size_t> newIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
     // `ifOp` returns the result of the inner warp op.
     SmallVector<Type> newIfOpDistResTypes;
     for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1870,8 +1872,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
     auto newIfOp = scf::IfOp::create(
-        rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
-        static_cast<bool>(ifOp.thenBlock()),
+        rewriter, ifOp.getLoc(), newIfOpDistResTypes,
+        newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
         static_cast<bool>(ifOp.elseBlock()));
     auto encloseRegionInWarpOp =
         [&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1890,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
           for (size_t i = 0; i < escapingValues.size();
                ++i, ++warpResRangeStart) {
             innerWarpInputVals.push_back(
-                newWarpOp.getResult(warpResRangeStart));
+                newWarpOp.getResult(newIndices[warpResRangeStart]));
             escapeValToBlockArgIndex[escapingValues[i]] =
                 innerWarpInputTypes.size();
             innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1936,17 +1938,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
     // result.
     for (auto [origIdx, newIdx] : ifResultMapping)
-      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newIfOp.getResult(newIdx), newIfOp);
-    // Similarly, update any users of the `WarpOp` results that were not
-    // results of the `IfOp`.
-    for (auto [origIdx, newIdx] : origToNewYieldIdx)
-      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
-                                  newWarpOp.getResult(newIdx));
-    // Remove the original `WarpOp` and `IfOp`, they should not have any uses
-    // at this point.
-    rewriter.eraseOp(ifOp);
-    rewriter.eraseOp(warpOp);
     return success();
   }
 
@@ -2065,19 +2058,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                               escapingValueDistTypes.begin(),
                               escapingValueDistTypes.end());
     // Next, we insert all non-`ForOp` yielded values and their distributed
-    // types. We also create a mapping between the non-`ForOp` yielded value
-    // index and the corresponding new `WarpOp` yield value index (needed to
-    // update users later).
-    llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+    // types.
     for (auto [i, v] :
          llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
-      nonForResultMapping[i] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(v);
       newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
     }
     // Create the new `WarpOp` with the updated yield values and types.
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    SmallVector<size_t> newIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
 
     // Next, we create a new `ForOp` with the init args yielded by the new
     // `WarpOp`.
@@ -2086,7 +2076,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                                     // escaping values in the new `WarpOp`.
     SmallVector<Value> newForOpOperands;
     for (size_t i = 0; i < escapingValuesStartIdx; ++i)
-      newForOpOperands.push_back(newWarpOp.getResult(i));
+      newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
 
     // Create a new `ForOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
@@ -2110,7 +2100,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     for (size_t i = escapingValuesStartIdx;
          i < escapingValuesStartIdx + escapingValues.size(); ++i) {
-      innerWarpInput.push_back(newWarpOp.getResult(i));
+      innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
       argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
           innerWarpInputType.size();
       innerWarpInputType.push_back(
@@ -2146,20 +2136,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     if (!innerWarp.getResults().empty())
       scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
 
-    // Update the users of original `WarpOp` results that were coming from the
+    // Update the users of the new `WarpOp` results that were coming from the
     // original `ForOp` to the corresponding new `ForOp` result.
     for (auto [origIdx, newIdx] : forResultMapping)
-      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newForOp.getResult(newIdx), newForOp);
-    // Similarly, update any users of the `WarpOp` results that were not
-    // results of the `ForOp`.
-    for (auto [origIdx, newIdx] : nonForResultMapping)
-      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
-                                  newWarpOp.getResult(newIdx));
-    // Remove the original `WarpOp` and `ForOp`, they should not have any uses
-    // at this point.
-    rewriter.eraseOp(forOp);
-    rewriter.eraseOp(warpOp);
     // Update any users of escaping values that were forwarded to the
     // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
     newForOp.walk([&](Operation *op) {
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bb7639204022f..401cdd29b281c 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1925,3 +1925,22 @@ func.func @warp_scf_if_distribute(%pred : i1)  {
 //       CHECK-PROP:    "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
 //       CHECK-PROP:    return
 //       CHECK-PROP:  }
+
+// -----
+func.func @dedup_unused_result(%laneid : index) -> (vector<1xf32>) {
+  %r:3 = gpu.warp_execute_on_lane_0(%laneid)[32] ->
+    (vector<1xf32>, vector<2xf32>, vector<1xf32>) {
+    %2 = "some_def"() : () -> (vector<32xf32>)
+    %3 = "some_def"() : () -> (vector<64xf32>)
+    gpu.yield %2, %3, %2 : vector<32xf32>, vector<64xf32>, vector<32xf32>
+  }
+  %r0 = "some_use"(%r#2, %r#2) : (vector<1xf32>, vector<1xf32>) -> (vector<1xf32>)
+  return %r0 : vector<1xf32>
+}
+
+// CHECK-PROP: func @dedup_unused_result
+// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>)
+// CHECK-PROP:   %[[Y0:.*]] = "some_def"() : () -> vector<32xf32>
+// CHECK-PROP:   %[[Y1:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP:   gpu.yield %[[Y0]] : vector<32xf32>
+// CHECK-PROP: "some_use"(%[[R]], %[[R]]) : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants