diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 24e909548fe0b..f9aa28d5203db 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -113,9 +113,12 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef shape, if (layout.size() != shape.size()) return std::nullopt; auto ratio = computeShapeRatio(shape, layout); - if (!ratio.has_value()) + if (ratio.has_value()) { + newShape = ratio.value(); + } else if (!rr || !computeShapeRatio(layout, shape).has_value()) { return std::nullopt; - newShape = ratio.value(); + } + // Round-robin case (rr is set and computeShapeRatio(layout, shape) has a value): newShape is left unchanged } if (data.size()) { diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 742d11f8052ec..52acde4dffc2e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -527,4 +527,11 @@ gpu.module @test_distribution { %cst_1 = arith.constant {layout_result_0 = #xegpu.layout} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex> gpu.return } + + // CHECK-LABEL: scalar_broadcast + gpu.func @scalar_broadcast(%arg0: index) { + // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1xindex> + %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout} : index to vector<4x1x1xindex> + gpu.return + } }