diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index e083507173d31..9114e37b0e42b 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1161,17 +1161,6 @@ struct WgToSgVectorShapeCastOp xegpu::DistributeLayoutAttr sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0)); - auto usedByBroadcastOp = [](vector::ShapeCastOp op) { - return llvm::all_of(op.getResult().getUsers(), [](Operation *user) { - return isa<vector::BroadcastOp>(user); - }); - }; - - if (!usedByBroadcastOp(op)) - return rewriter.notifyMatchFailure( - op, "ShapeCast ops that expand unit dimensions and are used by " - "non-broadcast operations are not supported."); - if (!sourceLayout.isSliceOf(layout)) return rewriter.notifyMatchFailure( op, "The ShapeCast op only expands dimensions, the input layout " diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 3bc43b780ade2..c3eb59adee2a6 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -960,4 +960,28 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: @shape_cast_used_by_elementwise + gpu.func @shape_cast_used_by_elementwise(%dst: memref<1x1x16xf32>) { + // Regression test: shape_cast expanding unit dimensions can be used by elementwise ops + // This previously failed with "ShapeCast ops that expand unit dimensions and are used by + // non-broadcast operations are not supported." 
+ + // CHECK: vector.step : vector<16xindex> + // CHECK: vector.shape_cast {{.*}} : vector<16xindex> to vector<1x1x16xindex> + // CHECK: arith.addi {{.*}} : vector<1x1x16xindex> + // CHECK: xegpu.store {{.*}} : vector<1x1x16xf32>, i64, vector<1x1x16xindex>, vector<1x1x16xi1> + %step = vector.step : vector<16xindex> + %shape_cast = vector.shape_cast %step : vector<16xindex> to vector<1x1x16xindex> + %cst = arith.constant dense<10> : vector<1x1x16xindex> + %add = arith.addi %shape_cast, %cst : vector<1x1x16xindex> + + %cst_val = arith.constant dense<1.0> : vector<1x1x16xf32> + %intptr = memref.extract_aligned_pointer_as_index %dst : memref<1x1x16xf32> -> index + %ptr = arith.index_cast %intptr : index to i64 + %mask = arith.constant dense<true> : vector<1x1x16xi1> + + xegpu.store %cst_val, %ptr[%add], %mask {layout = #xegpu.layout} : vector<1x1x16xf32>, i64, vector<1x1x16xindex>, vector<1x1x16xi1> + gpu.return + } + }