diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0a9ef0aa6df96..afab880d173c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1283,6 +1283,57 @@ struct WgToSgVectorTransposeOp
   }
 };
 
+/// Pattern for lowering vector.create_mask and vector.constant_mask ops to
+/// subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+  using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      MaskOpType op,
+      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newMaskOps;
+    for (int i = 0; i < count; ++i) {
+      Value newMaskOp;
+      if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+        newMaskOp = vector::CreateMaskOp::create(
+            rewriter, op.getLoc(), newResultType, op.getOperands());
+      } else if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+        newMaskOp = vector::ConstantMaskOp::create(
+            rewriter, op.getLoc(), newResultType, op.getMaskDimSizes());
+      } else {
+        return rewriter.notifyMatchFailure(op,
+                                           "Unsupported mask operation type");
+      }
+      xegpu::setDistributeLayoutAttr(cast<OpResult>(newMaskOp),
+                                     layout.dropSgLayoutAndData());
+
+      newMaskOps.push_back(newMaskOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newMaskOps});
+    return success();
+  }
+};
+
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+
 } // namespace
 
 namespace mlir {
@@ -1297,7 +1348,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
       WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
       WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-      WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+      WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
+      WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
       patterns.getContext());
 }
 } // namespace xegpu
@@ -1427,7 +1479,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
                                vector::TransposeOp, vector::BroadcastOp,
-                               vector::MultiDimReductionOp>(
+                               vector::MultiDimReductionOp,
+                               vector::ConstantMaskOp, vector::CreateMaskOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 84ce80f477a55..b587ecc726f4d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -130,5 +130,18 @@ gpu.module @test_distribution {
     %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16]>} : vector<256x128xf32> to vector<128x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    %cst16 = arith.constant 16 : index
+    // CHECK: %[[CST16:.*]] = arith.constant 16 : index
+    // CHECK-COUNT-4: vector.create_mask %[[CST16:.*]], %[[CST16]] : vector<16x16xi1>
+    // CHECK-NOT: vector.create_mask
+    // CHECK-COUNT-4: vector.constant_mask [16, 16] : vector<16x16xi1>
+    // CHECK-NOT: vector.constant_mask
+    %create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    gpu.return
+  }
 
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 4fbb566cfbe73..f254b82c6401f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -547,4 +547,24 @@ gpu.module @test_distribution {
    %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 1, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
    gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_1D
+  gpu.func @vector_mask_1D() {
+    %cst8 = arith.constant 8 : index
+    // CHECK: vector.create_mask {{.*}} : vector<16xi1>
+    %create_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [1], sg_data = [16]>} : vector<16xi1>
+    // CHECK: vector.constant_mask [8] : vector<16xi1>
+    %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    %cst16 = arith.constant 16 : index
+    // CHECK: vector.create_mask {{.*}}, {{.*}} : vector<32x32xi1>
+    %create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    // CHECK: vector.constant_mask [16, 16] : vector<32x32xi1>
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    gpu.return
+  }
 }