diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d7592fed6d186..9413a9296b184 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1027,6 +1027,70 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
+/// Pattern for lowering vector.multi_reduction op to subgroup level.
+/// Current limitation: the sg_layout in each reduced dimension must be 1,
+/// so that the reduction is local to the subgroup and no cross-subgroup
+/// communication is needed.
+/// TODO: Add cases to handle more general situations which require SLM access.
+struct WgToSgMultiDimReductionOp
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType srcType = op.getSourceVectorType();
+    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!dstType)
+      return failure();
+
+    auto srcShape = srcType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    auto reductionDims = llvm::to_vector(op.getReductionDims());
+
+    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
+                                        .getParent()
+                                        .getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
+                                      .getParent()
+                                      .getEffectiveSgDataAsInt();
+
+    // Check that the sgLayout in each reduced dimension is 1 and that
+    // each subgroup gets the entire slice to reduce.
+    for (int64_t dim : reductionDims) {
+      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
+        return rewriter.notifyMatchFailure(
+            op,
+            "sgLayout in each reduced dimension must be 1 and sgData in the "
+            "reduced dim must match srcShape in that dim");
+    }
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+
+    VectorType newDstType =
+        VectorType::get({sgShape}, dstType.getElementType());
+
+    SmallVector<Value> newReductions;
+    for (auto sgSrc : adaptor.getSource()) {
+      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
+          op.getReductionDims());
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newOp->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newReductions.push_back(newOp.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newReductions});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1040,8 +1104,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
       WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
       WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-      WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(
-      patterns.getContext());
+      WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
+      WgToSgMultiDimReductionOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1195,6 +1259,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) &&
isLegal(op.getTargetLayout()); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index 6ff7a94d678a3..dce73dee507e1 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -82,4 +82,20 @@ gpu.module @test_distribution { : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32> gpu.return } + + // CHECK-LABEL: vector_reduce_dim_1 + gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32> + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} dense<1.0> : vector<256xf32> + %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32> + -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc[0, 0] + : !xegpu.tensor_desc<256x64xf32, #xegpu.layout> + -> vector<256x64xf32> + // CHECK-COUNT-2: vector.multi_reduction , {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32> + // CHECK-NOT: vector.multi_reduction + %reduce = vector.multi_reduction , %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} [1] + : vector<256x64xf32> to vector<256xf32> + gpu.return + } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 3478a9b91da5f..48fc633974e63 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -367,6 +367,46 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: @vector_reduce_dim_0 + gpu.func @vector_reduce_dim_0(%src: memref<4x128xf32>) { + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} dense<1.0> : vector<128xf32> + %tdesc = xegpu.create_nd_tdesc %src : memref<4x128xf32> + -> !xegpu.tensor_desc<4x128xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc[0, 0] + : !xegpu.tensor_desc<4x128xf32, #xegpu.layout> + -> vector<4x128xf32> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [0] : vector<4x4xf32> to vector<4xf32> + %reduce = vector.multi_reduction , %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} [0] + : vector<4x128xf32> to vector<128xf32> + gpu.return + } + + // CHECK-LABEL: @vector_reduce_dim_1 + gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) { + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} dense<1.0> : vector<256xf32> + %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32> + -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc[0, 0] + : !xegpu.tensor_desc<256x64xf32, #xegpu.layout> + -> vector<256x64xf32> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<16x64xf32> to vector<16xf32> + %reduce = vector.multi_reduction , %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} [1] + : vector<256x64xf32> to vector<256xf32> + gpu.return + } + + // CHECK-LABEL: @vector_reduce_4D + gpu.func @vector_reduce_4D(%src: ui64) { + %cst_acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [3]>} dense<0.0> : vector<4x2x6xf16> + %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<4x2x6x32xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<4x2x6x32xi1> + %load = xegpu.load %src[%offset], %mask {layout_result_0 = #xegpu.layout} : ui64, vector<4x2x6x32xindex>, vector<4x2x6x32xi1> -> 
vector<4x2x6x32xf16> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [3] : vector<1x1x1x32xf16> to vector<1x1x1xf16> + %reduce = vector.multi_reduction , %load, %cst_acc {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [3]>} [3] + : vector<4x2x6x32xf16> to vector<4x2x6xf16> + gpu.return + } + // CHECK-LABEL: vector_step_op gpu.func @vector_step_op_slice_attr() { //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index