diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 33d4b0457e5d3..c6ace1802bc43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1285,6 +1285,71 @@ struct WgToSgVectorTransposeOp
   }
 };
 
+// This pattern distributes vector.constant_mask ops to work at subgroup
+// level.
+struct WgToSgVectorConstantMaskOp
+    : public OpConversionPattern<vector::ConstantMaskOp> {
+  using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    auto wgShape = type.getShape();
+
+    ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
+
+    // Get subgroup ID.
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+    auto sgOffsets =
+        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
+    if (failed(sgOffsets))
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType resultType = VectorType::get(sgShape, type.getElementType());
+
+    // In each dimension, each subgroup computes its local mask size as:
+    // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
+    SmallVector<Value> newCreateMaskOps;
+    for (auto offsetSet : *sgOffsets) {
+      SmallVector<Value> maskOperands;
+
+      for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
+        Value wgMaskSizeVal =
+            arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
+        Value dimSizeVal =
+            arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
+        Value offset = offsetSet[i];
+        Value adjustedMaskSize =
+            arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
+        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+        Value nonNegative =
+            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+        Value sgMaskSize =
+            arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
+        maskOperands.push_back(sgMaskSize);
+      }
+
+      auto newCreateMaskOp =
+          vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
+      xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
+      newCreateMaskOps.push_back(newCreateMaskOp.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1299,8 +1364,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
       WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
       WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-      WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
-      patterns.getContext());
+      WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
+      WgToSgVectorConstantMaskOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1427,9 +1492,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
-                               vector::TransposeOp, vector::BroadcastOp,
-                               vector::MultiDimReductionOp>(
+  target.addDynamicallyLegalOp<
+      vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
+      vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
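Aside: the per-subgroup mask size computed by this pattern is plain clamp arithmetic in each dimension, min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d]). A minimal standalone sketch of that computation (plain C++ outside MLIR; `localMaskSizes` and all other names are illustrative, not part of the patch):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative helper (not from the patch): given the workgroup-level mask
// sizes, one subgroup's element offsets, and the per-subgroup tile shape,
// compute the operands the pattern feeds to vector.create_mask.
static std::vector<int64_t>
localMaskSizes(const std::vector<int64_t> &wgMaskSizes,
               const std::vector<int64_t> &sgOffsets,
               const std::vector<int64_t> &sgShape) {
  std::vector<int64_t> sizes;
  for (size_t d = 0; d < wgMaskSizes.size(); ++d) {
    // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
    int64_t clamped = std::max<int64_t>(wgMaskSizes[d] - sgOffsets[d], 0);
    sizes.push_back(std::min(clamped, sgShape[d]));
  }
  return sizes;
}

int main() {
  // Mirrors the vector_mask_2D test below: wg mask [16, 16] over
  // vector<256x128xi1> with 32x32 subgroup tiles. The subgroup at offset
  // (0, 0) keeps a 16x16 mask; the one at (32, 0) is past the mask in dim 0.
  for (auto offs : {std::vector<int64_t>{0, 0}, std::vector<int64_t>{32, 0}}) {
    auto sizes = localMaskSizes({16, 16}, offs, {32, 32});
    std::cout << sizes[0] << " x " << sizes[1] << "\n"; // "16 x 16", "0 x 16"
  }
  return 0;
}
```

A subgroup whose tile begins at or past the workgroup mask bound thus gets a local size of 0 in that dimension, which vector.create_mask renders as an all-false dimension.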
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 84ce80f477a55..1cddccb5fbbd1 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -130,5 +130,13 @@ gpu.module @test_distribution {
     %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout} : vector<256x128xf32> to vector<128x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
+    // CHECK-NOT: vector.create_mask
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 5dde84e8e0bc2..574b365443a0a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -548,6 +548,41 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: vector_mask_1D
+  gpu.func @vector_mask_1D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
+    // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+    // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
+    // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
+    %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
+    // CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
+    // CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
+    // CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
+    // CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
+    // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
+    // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
+    // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    gpu.return
+  }
+
   // CHECK-LABEL: distribute_load_slice_attr
   gpu.func @distribute_load_slice_attr() {
     %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
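Closing note: the 2D CHECK chains above (remu/divu on the subgroup id, then mul and remu per dimension) are the offsets produced by computeDistributedCoords. A hedged sketch of that decomposition (plain C++; assumes sg_layout = [8, 4] and sg_data = [32, 32] as in the non-round-robin test, all names illustrative):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int64_t sgLayout[2] = {8, 4};  // subgroups per dim (sg_layout)
  const int64_t sgData[2] = {32, 32};  // elements per subgroup (sg_data)
  const int64_t wgShape[2] = {256, 128};

  for (int64_t sgId = 0; sgId < sgLayout[0] * sgLayout[1]; ++sgId) {
    // Matches the CHECK lines: x = remu(sgId, 4); y = remu(divu(sgId, 4), 8).
    int64_t sgIdX = sgId % sgLayout[1];
    int64_t sgIdY = (sgId / sgLayout[1]) % sgLayout[0];
    // Scale to element offsets and wrap around the workgroup shape (remu).
    int64_t row = (sgIdY * sgData[0]) % wgShape[0];
    int64_t col = (sgIdX * sgData[1]) % wgShape[1];
    if (sgId < 2) // print a couple of examples
      std::cout << "sg " << sgId << " -> (" << row << ", " << col << ")\n";
  }
  return 0;
}
```

With sg_layout = [8, 4] and sg_data = [32, 32], the 32 subgroups tile the 256x128 shape exactly once, so each emits a single create_mask; in the round-robin test the per-subgroup tile is 16x16, leaving each subgroup four tiles, hence the CHECK-COUNT-4 there.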