diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 0aead9172858f..d04933423ecd0 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -492,7 +492,9 @@ struct WgToSgVectorBroadcastOp if (!layout || !layout.isForWorkgroup()) return failure(); - SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; + SmallVector sgShape; + int count; + std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); VectorType newResultType = VectorType::get(sgShape, resultType.getElementType()); @@ -500,11 +502,15 @@ struct WgToSgVectorBroadcastOp return failure(); SmallVector newBroadcastOps; - for (auto operand : adaptor.getOperands().front()) { - auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), - newResultType, operand); - - newBroadcastOps.push_back(newBroadcast.getResult()); + auto distSource = adaptor.getOperands().front(); + int numDistributions = count / distSource.size(); + for (int i = 0; i < numDistributions; ++i) { + for (auto operand : distSource) { + auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), + newResultType, operand); + + newBroadcastOps.push_back(newBroadcast.getResult()); + } } rewriter.replaceOpWithMultiple(op, {newBroadcastOps}); return success(); @@ -816,8 +822,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern { // Splat: single value for all subgroups Attribute singleVal = vecAttr.getSplatValue(); auto sgAttr = DenseElementsAttr::get(newType, singleVal); - auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); - rewriter.replaceOp(op, cstOp); + SmallVector newConstOps; + for (int i = 0; i < count; ++i) { + auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); + newConstOps.push_back(cstOp); + } + rewriter.replaceOpWithMultiple(op, {newConstOps}); return success(); } else if (sgShape == wgShape) { // if the entire vector is shared by all // subgroups, don't distribute diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index e89cb52ee02f5..e4bf3b6c3bf1d 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -116,7 +116,7 @@ gpu.module @test_round_robin_assignment { %load = xegpu.load_nd %tdesc {layout = #xegpu.layout} : !xegpu.tensor_desc<128x1xf32, #xegpu.layout> -> vector<128x1xf32> - // CHECK-COUNT-2: vector.broadcast {{.*}} : vector<16x1xf32> to vector<16x32xf32> + // CHECK-COUNT-4: vector.broadcast {{.*}} : vector<16x1xf32> to vector<16x32xf32> // CHECK-NOT: vector.broadcast %broadcast = vector.broadcast %load {layout_result_0 = #xegpu.layout} diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index 068dd6d865ead..320a2fb1f72ac 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -165,4 +165,30 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: splat_constant + gpu.func @splat_constant() { + // CHECK-COUNT-2: %[[CST:.*]] = arith.constant dense<0> : vector<4xindex> + %cst_2 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} dense<0> : vector<8xindex> + gpu.return + } + + // CHECK-LABEL: gpu.func @step_broadcast + gpu.func @step_broadcast() { + // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index + // CHECK: %[[REM:.*]] = arith.remui %[[SGID]], %[[C16]] : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[STEP:.*]] = vector.step : vector<4xindex> + // CHECK: %[[BCST0:.*]] = vector.broadcast %[[C0:.*]] : index to vector<4xindex> + // CHECK: %[[ADD0:.*]] = arith.addi %[[STEP]], %[[BCST0]] : vector<4xindex> + // CHECK: %[[BCST4:.*]] = vector.broadcast %[[C4:.*]] : index to vector<4xindex> + // CHECK: %[[ADD4:.*]] = arith.addi %[[STEP]], %[[BCST4]] : vector<4xindex> + // CHECK: %[[RES0:.*]] = vector.broadcast %[[ADD0]] : vector<4xindex> to vector<16x4xindex> + // CHECK: %[[RES1:.*]] = vector.broadcast %[[ADD4]] : vector<4xindex> to vector<16x4xindex> + %2 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} : vector<8xindex> + %bcast = vector.broadcast %2 {layout_result_0 = #xegpu.layout} : vector<8xindex> to vector<256x8xindex> + gpu.return + } + }