-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Vector] Add warp distribution for scf.if
#157119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Vector] Add warp distribution for scf.if
#157119
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Artem Kroviakov (akroviakov) Changes: This PR adds `scf.if` op distribution to the existing `VectorDistribute` patterns. Full diff: https://github.com/llvm/llvm-project/pull/157119.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c84eb2c9f8857..cf5928278aa64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1713,6 +1713,205 @@ struct WarpOpInsert : public WarpDistributionPattern {
}
};
+/// Sink `scf.if` out of `WarpExecuteOnLane0Op`. This is done only if the
+/// `scf.if` is the last operation in the warp region (right before the
+/// terminator). A new `WarpExecuteOnLane0Op` is created that additionally
+/// yields the values defined in the warp region and used inside the `scf.if`
+/// branches. A new `scf.if` is created after it, and each branch body is
+/// wrapped in its own inner `WarpExecuteOnLane0Op` that receives the escaping
+/// values of that branch as arguments.
+struct WarpOpScfIfOp : public WarpDistributionPattern {
+ WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp warpOpYield = warpOp.getTerminator();
+ // Only pick up `IfOp` if it is the last op in the region.
+ Operation *lastNode = warpOpYield->getPrevNode();
+ auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
+ if (!ifOp)
+ return failure();
+
+ // The current `WarpOp` can yield two types of values:
+ // 1. Not results of `IfOp`:
+ // Preserve them in the new `WarpOp`.
+ // Collect their yield index.
+ // 2. Results of `IfOp`:
+ // They are not part of the new `WarpOp` results.
+ // Map current warp's yield operand index to `IfOp` result idx.
+ SmallVector<Value> nonIfYieldValues;
+ SmallVector<unsigned> nonIfYieldIndices;
+ llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
+ llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
+ for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+ const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
+ if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
+ nonIfYieldValues.push_back(yieldOperand.get());
+ nonIfYieldIndices.push_back(yieldOperandIdx);
+ continue;
+ }
+ OpResult ifResult = cast<OpResult>(yieldOperand.get());
+ const unsigned ifResultIdx = ifResult.getResultNumber();
+ ifResultMapping[yieldOperandIdx] = ifResultIdx;
+ // If this `ifOp` result is of vector type and it is yielded by the
+ // `WarpOp`, we keep track of the distributed type for this result.
+ if (!isa<VectorType>(ifResult.getType()))
+ continue;
+ VectorType distType =
+ cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
+ ifResultDistTypes[ifResultIdx] = distType;
+ }
+
+ // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
+ // them. For every newly seen value, `inputTypes` receives its original type
+ // and `distTypes` its distributed type (same as the original type for
+ // non-vector values).
+ auto getEscapingValues = [&](Region &branch,
+ llvm::SmallSetVector<Value, 32> &values,
+ SmallVector<Type> &inputTypes,
+ SmallVector<Type> &distTypes) {
+ if (branch.empty())
+ return;
+ mlir::visitUsedValuesDefinedAbove(branch, [&](OpOperand *operand) {
+ Operation *parent = operand->get().getParentRegion()->getParentOp();
+ if (warpOp->isAncestor(parent)) {
+ if (!values.insert(operand->get()))
+ return;
+ Type distType = operand->get().getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(operand->get());
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ inputTypes.push_back(operand->get().getType());
+ distTypes.push_back(distType);
+ }
+ });
+ };
+ llvm::SmallSetVector<Value, 32> escapingValuesThen;
+ SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
+ SmallVector<Type> escapingValueDistTypesThen; // new warp returns
+ getEscapingValues(ifOp.getThenRegion(), escapingValuesThen,
+ escapingValueInputTypesThen, escapingValueDistTypesThen);
+ llvm::SmallSetVector<Value, 32> escapingValuesElse;
+ SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
+ SmallVector<Type> escapingValueDistTypesElse; // new warp returns
+ getEscapingValues(ifOp.getElseRegion(), escapingValuesElse,
+ escapingValueInputTypesElse, escapingValueDistTypesElse);
+
+ // A null type means that a distributed type could not be computed.
+ if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
+ llvm::is_contained(escapingValueDistTypesElse, Type{}))
+ return failure();
+
+ // `ifOp` returns the result of the inner warp op. Compute the distributed
+ // result types before any IR is modified, so that the pattern can still
+ // bail out if one of them cannot be computed.
+ SmallVector<Type> newIfOpDistResTypes;
+ for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
+ Type distType = cast<Value>(res).getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(cast<Value>(res));
+ distType = ifResultDistTypes.count(i)
+ ? ifResultDistTypes[i]
+ : getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ newIfOpDistResTypes.push_back(distType);
+ }
+ if (llvm::is_contained(newIfOpDistResTypes, Type{}))
+ return failure();
+
+ // The new `WarpOp` groups yields values in following order:
+ // 1. Escaping values then branch
+ // 2. Escaping values else branch
+ // 3. All non-`ifOp` yielded values.
+ SmallVector<Value> newWarpOpYieldValues{escapingValuesThen.begin(),
+ escapingValuesThen.end()};
+ newWarpOpYieldValues.append(escapingValuesElse.begin(),
+ escapingValuesElse.end());
+ SmallVector<Type> newWarpOpDistTypes = escapingValueDistTypesThen;
+ newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
+ escapingValueDistTypesElse.end());
+
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
+ for (auto [idx, val] :
+ llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
+ origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
+ newWarpOpYieldValues.push_back(val);
+ newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
+ }
+ // Create the new `WarpOp` with the updated yield values and types.
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+
+ // Create a new `IfOp` outside the new `WarpOp` region.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(),
+ newIfOpDistResTypes, ifOp.getCondition(),
+ static_cast<bool>(ifOp.thenBlock()),
+ static_cast<bool>(ifOp.elseBlock()));
+
+ // Wrap the body of one branch in an inner `WarpOp` placed in the matching
+ // branch of the new `IfOp`. `warpResRangeStart` is the index of the first
+ // `newWarpOp` result that holds an escaping value of this branch.
+ auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
+ llvm::SmallSetVector<Value, 32> &escapingValues,
+ SmallVector<Type> &escapingValueInputTypes,
+ size_t warpResRangeStart) {
+ OpBuilder::InsertionGuard g(rewriter);
+ if (!newIfBranch)
+ return;
+ rewriter.setInsertionPointToStart(newIfBranch);
+ llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
+ SmallVector<Value> innerWarpInputVals;
+ SmallVector<Type> innerWarpInputTypes;
+ for (size_t i = 0; i < escapingValues.size(); ++i) {
+ // The escaping values of this branch occupy a contiguous range of the
+ // new `WarpOp` results that starts at `warpResRangeStart`.
+ innerWarpInputVals.push_back(
+ newWarpOp.getResult(warpResRangeStart + i));
+ escapeValToBlockArgIndex[escapingValues[i]] =
+ innerWarpInputTypes.size();
+ innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
+ }
+ auto innerWarp = WarpExecuteOnLane0Op::create(
+ rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
+ newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInputVals,
+ innerWarpInputTypes);
+
+ innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
+ innerWarp.getWarpRegion().addArguments(
+ innerWarpInputTypes,
+ SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
+
+ // Replace the old `scf.yield` with a `gpu.yield` of the same operands and
+ // make the new `IfOp` branch yield the inner `WarpOp` results.
+ SmallVector<Value> yieldOperands;
+ for (Value operand : oldIfBranch->getTerminator()->getOperands())
+ yieldOperands.push_back(operand);
+ rewriter.eraseOp(oldIfBranch->getTerminator());
+
+ rewriter.setInsertionPointToEnd(innerWarp.getBody());
+ gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
+ rewriter.setInsertionPointAfter(innerWarp);
+ scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
+
+ // Update any users of escaping values that were forwarded to the
+ // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
+ innerWarp.walk([&](Operation *op) {
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto it = escapeValToBlockArgIndex.find(operand.get());
+ if (it == escapeValToBlockArgIndex.end())
+ continue;
+ operand.set(innerWarp.getBodyRegion().getArgument(it->second));
+ }
+ });
+ mlir::vector::moveScalarUniformCode(innerWarp);
+ };
+ processBranch(&ifOp.getThenRegion().front(),
+ &newIfOp.getThenRegion().front(), escapingValuesThen,
+ escapingValueInputTypesThen, /*warpResRangeStart=*/0);
+ if (!ifOp.getElseRegion().empty())
+ processBranch(&ifOp.getElseRegion().front(),
+ &newIfOp.getElseRegion().front(), escapingValuesElse,
+ escapingValueInputTypesElse,
+ /*warpResRangeStart=*/escapingValuesThen.size());
+ // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
+ // result.
+ for (auto [origIdx, newIdx] : ifResultMapping)
+ rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ newIfOp.getResult(newIdx), newIfOp);
+ // Similarly, update any users of the `WarpOp` results that were not
+ // results of the `IfOp`.
+ for (auto [origIdx, newIdx] : origToNewYieldIdx)
+ rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+ newWarpOp.getResult(newIdx));
+ // Remove the original `WarpOp` and `IfOp`, they should not have any uses
+ // at this point.
+ rewriter.eraseOp(ifOp);
+ rewriter.eraseOp(warpOp);
+ return success();
+ }
+
+private:
+ /// Callback returning the distribution map of a vector value; used to derive
+ /// the distributed type of escaping values and of `scf.if` results.
+ DistributionMapFn distributionMapFn;
+};
+
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
@@ -2068,6 +2267,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
+ patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
+ benefit);
}
void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 8750582ef1e1f..bb7639204022f 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1856,3 +1856,72 @@ func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memre
// CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
// CHECK-PROP-NOT: vector.broadcast
// CHECK-PROP: vector.step : vector<64xindex>
+
+// -----
+
+// Distribution of an `scf.if` without results and without an `else` branch.
+// The `vector.step` value is defined in the warp region and is used by the
+// `vector.store` inside the `then` branch.
+func.func @warp_scf_if_no_yield_distribute(%buffer: memref<128xindex>, %pred : i1) {
+ %laneid = gpu.lane_id
+ %c0 = arith.constant 0 : index
+
+ gpu.warp_execute_on_lane_0(%laneid)[32] {
+ %seq = vector.step : vector<32xindex>
+ scf.if %pred {
+ vector.store %seq, %buffer[%c0] : memref<128xindex>, vector<32xindex>
+ }
+ gpu.yield
+ }
+ return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_scf_if_no_yield_distribute(
+// CHECK-PROP-SAME: %[[ARG0:.+]]: memref<128xindex>, %[[ARG1:.+]]: i1
+// CHECK-PROP: scf.if %[[ARG1]] {
+// CHECK-PROP: gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<1xindex>) {
+// CHECK-PROP: ^bb0(%[[ARG2:.+]]: vector<32xindex>):
+// CHECK-PROP: vector.store %[[ARG2]], %[[ARG0]][%{{.*}}] : memref<128xindex>, vector<32xindex>
+
+// -----
+
+// Distribution of an `scf.if` with `then` and `else` branches whose vector
+// result is yielded by the warp op. Each branch uses a different value defined
+// in the warp region: `%seq1` in `then`, `%seq2` in `else`.
+func.func @warp_scf_if_distribute(%pred : i1) {
+ %laneid = gpu.lane_id
+ %c0 = arith.constant 0 : index
+
+ %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> vector<1xf32> {
+ %seq1 = vector.step : vector<32xindex>
+ %seq2 = arith.constant dense<2> : vector<32xindex>
+ %0 = scf.if %pred -> (vector<32xf32>) {
+ %1 = "some_op"(%seq1) : (vector<32xindex>) -> (vector<32xf32>)
+ scf.yield %1 : vector<32xf32>
+ } else {
+ %2 = "other_op"(%seq2) : (vector<32xindex>) -> (vector<32xf32>)
+ scf.yield %2 : vector<32xf32>
+ }
+ gpu.yield %0 : vector<32xf32>
+ }
+ "some_use"(%0) : (vector<1xf32>) -> ()
+
+ return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_scf_if_distribute(
+// CHECK-PROP-SAME: %[[ARG0:.+]]: i1
+// CHECK-PROP: %[[SEQ2:.+]] = arith.constant dense<2> : vector<32xindex>
+// CHECK-PROP: %[[LANE_ID:.+]] = gpu.lane_id
+// CHECK-PROP: %[[SEQ1:.+]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
+// CHECK-PROP: %[[IF_YIELD_DIST:.+]] = scf.if %[[ARG0]] -> (vector<1xf32>) {
+// CHECK-PROP: %[[THEN_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] args(%[[SEQ1]] : vector<1xindex>) -> (vector<1xf32>) {
+// CHECK-PROP: ^bb0(%[[ARG1:.+]]: vector<32xindex>):
+// CHECK-PROP: %{{.*}} = "some_op"(%[[ARG1]]) : (vector<32xindex>) -> vector<32xf32>
+// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[THEN_DIST]] : vector<1xf32>
+// CHECK-PROP: } else {
+// CHECK-PROP: %[[ELSE_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] -> (vector<1xf32>) {
+// CHECK-PROP: %{{.*}} = "other_op"(%[[SEQ2]]) : (vector<32xindex>) -> vector<32xf32>
+// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[ELSE_DIST]] : vector<1xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
+// CHECK-PROP: return
+// CHECK-PROP: }
|
bb4aca0
to
b356d11
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall this looks great to me. awesome work.
I added some comments. I will do another pass after your response.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Please address the remaining comments before merge.
This PR adds `scf.if` op distribution to the existing `VectorDistribute` patterns.
The logic mostly follows that of `scf.for`: move the op outside, wrap each branch with `gpu.warp_execute_on_lane_0`.
A notable difference to `scf.for` is that each branch has its own set of escaping values, and `scf.if` itself does not have block arguments.