Conversation

akroviakov
Contributor

This PR adds scf.if op distribution to the existing VectorDistribute patterns.
The logic mostly follows that of scf.for: the op is moved outside the warp op, and each branch is wrapped in its own gpu.warp_execute_on_lane_0.
A notable difference from scf.for is that each branch has its own set of escaping values, and scf.if itself has no block arguments.
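Roughly, the rewrite looks like the following simplified sketch (illustrative only; the ops, constants, and types are placeholders modeled on the tests in this PR, uniform-code hoisting is not shown, and the precise output is what the CHECK lines in vector-warp-distribute.mlir expect):

Before:

  %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
    %seq = vector.step : vector<32xindex>
    %r = scf.if %pred -> (vector<32xf32>) {
      %1 = "some_op"(%seq) : (vector<32xindex>) -> vector<32xf32>
      scf.yield %1 : vector<32xf32>
    } else {
      %cst = arith.constant dense<0.0> : vector<32xf32>
      scf.yield %cst : vector<32xf32>
    }
    gpu.yield %r : vector<32xf32>
  }

After: %seq escapes into the then branch, so the new outer warp op yields its distributed form, and the inner warp op receives the full vector back as a block argument.

  %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xindex>) {
    %seq = vector.step : vector<32xindex>
    gpu.yield %seq : vector<32xindex>
  }
  %0 = scf.if %pred -> (vector<1xf32>) {
    %t = gpu.warp_execute_on_lane_0(%laneid)[32] args(%w : vector<1xindex>) -> (vector<1xf32>) {
    ^bb0(%arg: vector<32xindex>):
      %1 = "some_op"(%arg) : (vector<32xindex>) -> vector<32xf32>
      gpu.yield %1 : vector<32xf32>
    }
    scf.yield %t : vector<1xf32>
  } else {
    %e = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
      %cst = arith.constant dense<0.0> : vector<32xf32>
      gpu.yield %cst : vector<32xf32>
    }
    scf.yield %e : vector<1xf32>
  }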

@llvmbot
Member

llvmbot commented Sep 5, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Full diff: https://github.com/llvm/llvm-project/pull/157119.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+201)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+69)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c84eb2c9f8857..cf5928278aa64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1713,6 +1713,205 @@ struct WarpOpInsert : public WarpDistributionPattern {
   }
 };
 
+struct WarpOpScfIfOp : public WarpDistributionPattern {
+  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp warpOpYield = warpOp.getTerminator();
+    // Only pick up `IfOp` if it is the last op in the region.
+    Operation *lastNode = warpOpYield->getPrevNode();
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
+    if (!ifOp)
+      return failure();
+
+    // The current `WarpOp` can yield two types of values:
+    // 1. Not results of `IfOp`:
+    //     Preserve them in the new `WarpOp`.
+    //     Collect their yield index.
+    // 2. Results of `IfOp`:
+    //     They are not part of the new `WarpOp` results.
+    //     Map current warp's yield operand index to `IfOp` result idx.
+    SmallVector<Value> nonIfYieldValues;
+    SmallVector<unsigned> nonIfYieldIndices;
+    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
+    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
+    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
+      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
+        nonIfYieldValues.push_back(yieldOperand.get());
+        nonIfYieldIndices.push_back(yieldOperandIdx);
+        continue;
+      }
+      OpResult ifResult = cast<OpResult>(yieldOperand.get());
+      const unsigned ifResultIdx = ifResult.getResultNumber();
+      ifResultMapping[yieldOperandIdx] = ifResultIdx;
+      // If this `ifOp` result has vector type and is yielded by the
+      // `WarpOp`, keep track of its distributed type.
+      if (!isa<VectorType>(ifResult.getType()))
+        continue;
+      VectorType distType =
+          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
+      ifResultDistTypes[ifResultIdx] = distType;
+    }
+
+    // Collect `WarpOp`-defined values used in `ifOp`; the new `WarpOp`
+    // returns them.
+    auto getEscapingValues = [&](Region &branch,
+                                 llvm::SmallSetVector<Value, 32> &values,
+                                 SmallVector<Type> &inputTypes,
+                                 SmallVector<Type> &distTypes) {
+      if (branch.empty())
+        return;
+      mlir::visitUsedValuesDefinedAbove(branch, [&](OpOperand *operand) {
+        Operation *parent = operand->get().getParentRegion()->getParentOp();
+        if (warpOp->isAncestor(parent)) {
+          if (!values.insert(operand->get()))
+            return;
+          Type distType = operand->get().getType();
+          if (auto vecType = dyn_cast<VectorType>(distType)) {
+            AffineMap map = distributionMapFn(operand->get());
+            distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+          }
+          inputTypes.push_back(operand->get().getType());
+          distTypes.push_back(distType);
+        }
+      });
+    };
+    llvm::SmallSetVector<Value, 32> escapingValuesThen;
+    SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
+    SmallVector<Type> escapingValueDistTypesThen;  // new warp returns
+    getEscapingValues(ifOp.getThenRegion(), escapingValuesThen,
+                      escapingValueInputTypesThen, escapingValueDistTypesThen);
+    llvm::SmallSetVector<Value, 32> escapingValuesElse;
+    SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
+    SmallVector<Type> escapingValueDistTypesElse;  // new warp returns
+    getEscapingValues(ifOp.getElseRegion(), escapingValuesElse,
+                      escapingValueInputTypesElse, escapingValueDistTypesElse);
+
+    if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
+        llvm::is_contained(escapingValueDistTypesElse, Type{}))
+      return failure();
+
+    // The new `WarpOp` yields values in the following order:
+    // 1. Escaping values of the then branch.
+    // 2. Escaping values of the else branch.
+    // 3. All non-`ifOp` yielded values.
+    SmallVector<Value> newWarpOpYieldValues{escapingValuesThen.begin(),
+                                            escapingValuesThen.end()};
+    newWarpOpYieldValues.append(escapingValuesElse.begin(),
+                                escapingValuesElse.end());
+    SmallVector<Type> newWarpOpDistTypes = escapingValueDistTypesThen;
+    newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
+                              escapingValueDistTypesElse.end());
+
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
+    for (auto [idx, val] :
+         llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
+      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(val);
+      newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
+    }
+    // Create the new `WarpOp` with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+
+    // `ifOp` returns the result of the inner warp op.
+    SmallVector<Type> newIfOpDistResTypes;
+    for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
+      Type distType = cast<Value>(res).getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(cast<Value>(res));
+        distType = ifResultDistTypes.count(i)
+                       ? ifResultDistTypes[i]
+                       : getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newIfOpDistResTypes.push_back(distType);
+    }
+    // Create a new `IfOp` outside the new `WarpOp` region.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(),
+                                     newIfOpDistResTypes, ifOp.getCondition(),
+                                     static_cast<bool>(ifOp.thenBlock()),
+                                     static_cast<bool>(ifOp.elseBlock()));
+
+    auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
+                             llvm::SmallSetVector<Value, 32> &escapingValues,
+                             SmallVector<Type> &escapingValueInputTypes) {
+      OpBuilder::InsertionGuard g(rewriter);
+      if (!newIfBranch)
+        return;
+      rewriter.setInsertionPointToStart(newIfBranch);
+      llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
+      SmallVector<Value> innerWarpInputVals;
+      SmallVector<Type> innerWarpInputTypes;
+      for (size_t i = 0; i < escapingValues.size(); ++i) {
+        innerWarpInputVals.push_back(newWarpOp.getResult(i));
+        escapeValToBlockArgIndex[escapingValues[i]] =
+            innerWarpInputTypes.size();
+        innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
+      }
+      auto innerWarp = WarpExecuteOnLane0Op::create(
+          rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
+          newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInputVals,
+          innerWarpInputTypes);
+
+      innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
+      innerWarp.getWarpRegion().addArguments(
+          innerWarpInputTypes,
+          SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
+
+      SmallVector<Value> yieldOperands;
+      for (Value operand : oldIfBranch->getTerminator()->getOperands())
+        yieldOperands.push_back(operand);
+      rewriter.eraseOp(oldIfBranch->getTerminator());
+
+      rewriter.setInsertionPointToEnd(innerWarp.getBody());
+      gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
+      rewriter.setInsertionPointAfter(innerWarp);
+      scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
+
+      // Update any users of escaping values that were forwarded to the
+      // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
+      innerWarp.walk([&](Operation *op) {
+        for (OpOperand &operand : op->getOpOperands()) {
+          auto it = escapeValToBlockArgIndex.find(operand.get());
+          if (it == escapeValToBlockArgIndex.end())
+            continue;
+          operand.set(innerWarp.getBodyRegion().getArgument(it->second));
+        }
+      });
+      mlir::vector::moveScalarUniformCode(innerWarp);
+    };
+    processBranch(&ifOp.getThenRegion().front(),
+                  &newIfOp.getThenRegion().front(), escapingValuesThen,
+                  escapingValueInputTypesThen);
+    if (!ifOp.getElseRegion().empty())
+      processBranch(&ifOp.getElseRegion().front(),
+                    &newIfOp.getElseRegion().front(), escapingValuesElse,
+                    escapingValueInputTypesElse);
+    // Update users of `WarpOp` results that forward `IfOp` results to use
+    // the new `IfOp` results directly.
+    for (auto [origIdx, newIdx] : ifResultMapping)
+      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+                                    newIfOp.getResult(newIdx), newIfOp);
+    // Similarly, update any users of the `WarpOp` results that were not
+    // results of the `IfOp`.
+    for (auto [origIdx, newIdx] : origToNewYieldIdx)
+      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+                                  newWarpOp.getResult(newIdx));
+    // Remove the original `WarpOp` and `IfOp`; they should not have any
+    // uses at this point.
+    rewriter.eraseOp(ifOp);
+    rewriter.eraseOp(warpOp);
+    return success();
+  }
+
+private:
+  DistributionMapFn distributionMapFn;
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't
 /// change the order of execution. This creates a new scf.for region after the
@@ -2068,6 +2267,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                                benefit);
+  patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
+                              benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 8750582ef1e1f..bb7639204022f 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1856,3 +1856,72 @@ func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memre
 // CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
 // CHECK-PROP-NOT: vector.broadcast
 // CHECK-PROP: vector.step : vector<64xindex>
+
+// -----
+
+func.func @warp_scf_if_no_yield_distribute(%buffer: memref<128xindex>, %pred : i1)  {
+  %laneid = gpu.lane_id
+  %c0 = arith.constant 0 : index
+
+  gpu.warp_execute_on_lane_0(%laneid)[32] {
+    %seq = vector.step : vector<32xindex>
+    scf.if %pred {
+      vector.store %seq, %buffer[%c0] : memref<128xindex>, vector<32xindex>
+    }
+    gpu.yield
+  }
+  return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_scf_if_no_yield_distribute(
+//  CHECK-PROP-SAME:   %[[ARG0:.+]]: memref<128xindex>, %[[ARG1:.+]]: i1
+//       CHECK-PROP:   scf.if %[[ARG1]] {
+//       CHECK-PROP:   gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<1xindex>) {
+//       CHECK-PROP:   ^bb0(%[[ARG2:.+]]: vector<32xindex>):
+//       CHECK-PROP:   vector.store %[[ARG2]], %[[ARG0]][%{{.*}}] : memref<128xindex>, vector<32xindex>
+
+// -----
+
+func.func @warp_scf_if_distribute(%pred : i1)  {
+  %laneid = gpu.lane_id
+  %c0 = arith.constant 0 : index
+
+  %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> vector<1xf32> {
+    %seq1 = vector.step : vector<32xindex>
+    %seq2 = arith.constant dense<2> : vector<32xindex>
+    %0 = scf.if %pred -> (vector<32xf32>) {
+      %1 = "some_op"(%seq1) : (vector<32xindex>) -> (vector<32xf32>)
+      scf.yield %1 : vector<32xf32>
+    } else {
+      %2 = "other_op"(%seq2) : (vector<32xindex>) -> (vector<32xf32>)
+      scf.yield %2 : vector<32xf32>
+    }
+    gpu.yield %0 : vector<32xf32>
+  }
+  "some_use"(%0) : (vector<1xf32>) -> ()
+
+  return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_scf_if_distribute(
+//  CHECK-PROP-SAME:    %[[ARG0:.+]]: i1
+//       CHECK-PROP:    %[[SEQ2:.+]] = arith.constant dense<2> : vector<32xindex>
+//       CHECK-PROP:    %[[LANE_ID:.+]] = gpu.lane_id
+//       CHECK-PROP:    %[[SEQ1:.+]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
+//       CHECK-PROP:    %[[IF_YIELD_DIST:.+]] = scf.if %[[ARG0]] -> (vector<1xf32>) {
+//       CHECK-PROP:    %[[THEN_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] args(%[[SEQ1]] : vector<1xindex>) -> (vector<1xf32>) {
+//       CHECK-PROP:        ^bb0(%[[ARG1:.+]]: vector<32xindex>):
+//       CHECK-PROP:        %{{.*}} = "some_op"(%[[ARG1]]) : (vector<32xindex>) -> vector<32xf32>
+//       CHECK-PROP:        gpu.yield %{{.*}} : vector<32xf32>
+//       CHECK-PROP:      }
+//       CHECK-PROP:      scf.yield %[[THEN_DIST]] : vector<1xf32>
+//       CHECK-PROP:    } else {
+//       CHECK-PROP:      %[[ELSE_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] -> (vector<1xf32>) {
+//       CHECK-PROP:        %{{.*}} = "other_op"(%[[SEQ2]]) : (vector<32xindex>) -> vector<32xf32>
+//       CHECK-PROP:        gpu.yield %{{.*}} : vector<32xf32>
+//       CHECK-PROP:      }
+//       CHECK-PROP:      scf.yield %[[ELSE_DIST]] : vector<1xf32>
+//       CHECK-PROP:    }
+//       CHECK-PROP:    "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
+//       CHECK-PROP:    return
+//       CHECK-PROP:  }

@Garra1980

cc @charithaintc

@akroviakov force-pushed the akroviak/scf-if-distribute branch from bb4aca0 to b356d11 on September 5, 2025 at 16:55
@charithaintc self-requested a review on September 8, 2025 at 22:58
Contributor

@charithaintc left a comment

overall this looks great to me. awesome work.

I added some comments. I will do another pass after your response.

Contributor

@charithaintc left a comment

LGTM. Please address the remaining comments before merge.

@charithaintc self-requested a review on September 10, 2025 at 19:37
@charithaintc merged commit 7f007b5 into llvm:main on September 10, 2025.
7 of 9 checks passed.