diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index beb9b60aa9d7a..95c20b1fabe58 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1270,15 +1270,15 @@ struct WgToSgVectorTransposeOp } }; -// This pattern distributes the vector.constant_mask ops to work at subgroup -// level. -struct WgToSgVectorConstantMaskOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { +// Distribute vector mask ops to work at subgroup level. +template +struct WgToSgVectorMaskOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + MaskOpType op, + typename OpConversionPattern::OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult()); if (!layout || !layout.isForWorkgroup()) @@ -1288,9 +1288,16 @@ struct WgToSgVectorConstantMaskOp VectorType type = op.getResult().getType(); auto wgShape = type.getShape(); - ArrayRef wgMaskDimSizes = op.getMaskDimSizes(); + SmallVector wgMaskDimSizes; + if constexpr (std::is_same_v) { + for (int64_t maskSize : op.getMaskDimSizes()) { + wgMaskDimSizes.push_back( + arith::ConstantIndexOp::create(rewriter, loc, maskSize)); + } + } else if constexpr (std::is_same_v) { + wgMaskDimSizes = llvm::to_vector(op.getOperands()); + } - // Get subgroup ID. Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); auto sgOffsets = @@ -1302,19 +1309,17 @@ struct WgToSgVectorConstantMaskOp VectorType resultType = VectorType::get(sgShape, type.getElementType()); // In each dimension, each subgroup computes its local mask size as: - // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d]) + // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d]) SmallVector newCreateMaskOps; for (auto offsetSet : *sgOffsets) { SmallVector maskOperands; - for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) { - Value wgMaskSizeVal = - arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize); + for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) { Value dimSizeVal = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); Value offset = offsetSet[i]; Value adjustedMaskSize = - arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset); + arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset); Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); Value nonNegative = arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero); @@ -1335,6 +1340,8 @@ struct WgToSgVectorConstantMaskOp } }; +using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp; +using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp; } // namespace namespace mlir { @@ -1350,7 +1357,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp, WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp, - WgToSgVectorConstantMaskOp>(patterns.getContext()); + WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>( + patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -1477,9 +1485,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addDynamicallyLegalOp< - vector::ShapeCastOp, vector::StepOp, vector::TransposeOp, - vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>( + target.addDynamicallyLegalOp( [=](Operation *op) -> bool { // Check for either a SliceAttr or LayoutAttr on the result. auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0)); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index 1cddccb5fbbd1..4fb50b3b28534 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -138,5 +138,13 @@ gpu.module @test_distribution { %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout} : vector<256x128xi1> gpu.return } + + gpu.func @vector_create_mask_2D() { + // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1> + // CHECK-NOT: vector.create_mask + %cst16 = arith.constant 16 : index + %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout} : vector<256x128xi1> + gpu.return + } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 574b365443a0a..48e93320093fd 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -583,6 +583,43 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: vector_create_mask_1D + gpu.func @vector_create_mask_1D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]] + // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]] + // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]] + // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index + // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index + // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1> + %cst8 = arith.constant 8 : index + %constant_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout} : vector<32xi1> + gpu.return + } + + // CHECK-LABEL: vector_create_mask_2D + gpu.func @vector_create_mask_2D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]] + // CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]] + // CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]] + // CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]] + // CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]] + // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index + // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index + // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index + // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index + // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index + // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1> + %cst16 = arith.constant 16 : index + %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout} : vector<256x128xi1> + gpu.return + } + // CHECK-LABEL: distribute_load_slice_attr gpu.func @distribute_load_slice_attr() { %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>