[MLIR] [XeGPU] Add distribution pattern for vector.constant_mask from Wg To Sg #168118
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/168118.diff

3 Files Affected:
- mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
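In brief, summarizing the diff below: the patch adds a WgToSgVectorConstantMaskOp conversion pattern that rewrites a workgroup-level vector.constant_mask into per-subgroup vector.create_mask ops. For each subgroup, the original mask size in every dimension is reduced by that subgroup's distributed offset and clamped to the subgroup tile size, and vector.constant_mask is registered as dynamically legal in the WgToSg pass so that only ops still carrying a workgroup-level layout get converted.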
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0a9ef0aa6df96..81fd25a155129 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1283,6 +1283,74 @@ struct WgToSgVectorTransposeOp
}
};
+// This pattern distributes the vector.constant_mask ops to work at subgroup
+// level.
+struct WgToSgVectorConstantMaskOp
+ : public OpConversionPattern<vector::ConstantMaskOp> {
+ using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType type = op.getResult().getType();
+ auto wgShape = type.getShape();
+
+ ArrayRef<int64_t> originalMaskDimSizes = op.getMaskDimSizes();
+
+ // Get subgroup ID.
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType resultType = VectorType::get(sgShape, type.getElementType());
+
+ SmallVector<Value> newCreateMaskOps;
+ for (auto offsetSet : *sgOffsets) {
+ SmallVector<Value> maskOperands;
+
+ for (auto [i, originalMaskSize] : llvm::enumerate(originalMaskDimSizes)) {
+ Value originalMaskSizeVal =
+ arith::ConstantIndexOp::create(rewriter, loc, originalMaskSize);
+ Value dimSizeVal =
+ arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
+ Value offset = offsetSet[i];
+ // Compute: originalMaskSize - offset.
+ Value adjustedMaskSize =
+ arith::SubIOp::create(rewriter, loc, originalMaskSizeVal, offset);
+ // Clamp to [0, dimSize]: max(0, min(adjustedMaskSize,
+ // dimSize))
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value clampedLow =
+ arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+ Value clampedHigh =
+ arith::MinSIOp::create(rewriter, loc, clampedLow, dimSizeVal);
+ maskOperands.push_back(clampedHigh);
+ }
+
+ auto newCreateMaskOp =
+ vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
+ if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+ !layout.getEffectiveInstDataAsInt().empty())
+ xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newCreateMaskOps.push_back(newCreateMaskOp.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
@@ -1297,8 +1365,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
- WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
- patterns.getContext());
+ WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
+ WgToSgVectorConstantMaskOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -1425,9 +1493,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
- vector::TransposeOp, vector::BroadcastOp,
- vector::MultiDimReductionOp>(
+ target.addDynamicallyLegalOp<
+ vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
+ vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
[=](Operation *op) -> bool {
// Check for either a SliceAttr or LayoutAttr on the result.
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
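For readers skimming the diff: the core of the pattern is the per-dimension clamp that turns the workgroup-level mask size into a subgroup-level one. A minimal standalone sketch of that arithmetic (plain C++ with illustrative names, not code from the patch) is:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Scalar model of the arith ops the pattern emits per dimension:
// subtract the subgroup's element offset from the workgroup mask size
// (arith.subi), then clamp to [0, sgDimSize] (arith.maxsi, arith.minsi).
static int64_t subgroupMaskSize(int64_t wgMaskSize, int64_t sgOffset,
                                int64_t sgDimSize) {
  int64_t adjusted = wgMaskSize - sgOffset;       // arith.subi
  int64_t lower = std::max<int64_t>(adjusted, 0); // arith.maxsi
  return std::min(lower, sgDimSize);              // arith.minsi
}

int main() {
  // Numbers from the 1D test below: constant_mask [8] on vector<32xi1>,
  // sg_layout = [2], sg_data = [16].
  std::cout << subgroupMaskSize(8, /*sgOffset=*/0, 16) << "\n";  // 8
  std::cout << subgroupMaskSize(8, /*sgOffset=*/16, 16) << "\n"; // 0
}
```

So the subgroup covering elements [0, 16) keeps a true-prefix of 8, while the subgroup covering [16, 32) gets an all-false mask, which is exactly what the vector.create_mask operands in the tests encode.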
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 84ce80f477a55..a752d0aa5c541 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -130,5 +130,14 @@ gpu.module @test_distribution {
%trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
gpu.return
}
+
+ // CHECK-LABEL: vector_mask_2D
+ gpu.func @vector_mask_2D() {
+ %cst16 = arith.constant 16 : index
+ // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
+ // CHECK-NOT: vector.create_mask
+ %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+ gpu.return
+ }
}
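A quick sanity check on the CHECK-COUNT-4 above: the workgroup shape is 256x128 and each subgroup tile is 16x16, giving 16 x 8 = 128 tiles; with sg_layout = [8, 4] there are 32 subgroups, so round-robin distribution assigns 128 / 32 = 4 tiles per subgroup, hence exactly four vector.create_mask ops after conversion.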
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 4fbb566cfbe73..fa08ed1623501 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -547,4 +547,41 @@ gpu.module @test_distribution {
%broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
gpu.return
}
+
+ // CHECK-LABEL: vector_mask_1D
+ gpu.func @vector_mask_1D() {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
+ // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+ // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
+ // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
+ %cst8 = arith.constant 8 : index
+ %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+ gpu.return
+ }
+
+ // CHECK-LABEL: vector_mask_2D
+ gpu.func @vector_mask_2D() {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
+ // CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
+ // CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
+ // CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
+ // CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
+ // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
+ // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
+ // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C7:.*]] : index
+ // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
+ %cst16 = arith.constant 16 : index
+ %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+ gpu.return
+ }
}
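Working the non-round-robin 2D test through the same formula: the mask is [16, 16] on a 256x128 vector with sg_layout = [8, 4] and sg_data = [32, 32], so every subgroup's offsets are multiples of 32. The subgroup at offset (0, 0) gets min(max(16 - 0, 0), 32) = 16 in both dimensions, i.e. a 16x16 true region; every other subgroup has an offset of at least 32 in some dimension, so that dimension clamps to max(16 - 32, 0) = 0 and its mask is all false.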
silee2 left a comment: LGTM
charithaintc left a comment: LGTM
Jianhui-Li left a comment: LGTM