From c4a2f64e4516d450b45061fc24851d1ae4545423 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Thu, 13 Nov 2025 09:19:12 +0000 Subject: [PATCH] [XeGPU] Use DistributeLayoutAttr instead of LayoutAttr for load gather/scatter ops Signed-off-by: dchigarev --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 8 ++++---- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 ++-- .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 4 ++-- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 12 +++++++----- .../Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 17 +++++++++++++++++ 5 files changed, 32 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 689ebd0d1179a..4c67856b559b1 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -844,7 +844,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> { OptionalAttr:$l1_hint, OptionalAttr:$l2_hint, OptionalAttr:$l3_hint, - OptionalAttr:$layout); + OptionalAttr:$layout); let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value); let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -903,7 +903,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> { "xegpu::CachePolicyAttr": $l1_hint, "xegpu::CachePolicyAttr": $l2_hint, "xegpu::CachePolicyAttr": $l3_hint, - "xegpu::LayoutAttr": $layout)> + "xegpu::DistributeLayoutAttr": $layout)> ]; let hasVerifier = 1; @@ -988,7 +988,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { OptionalAttr:$l1_hint, OptionalAttr:$l2_hint, OptionalAttr:$l3_hint, - OptionalAttr:$layout); + OptionalAttr:$layout); let extraClassDeclaration = extraBaseClassDeclaration#[{ Type getDestType() { @@ -1046,7 +1046,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { "xegpu::CachePolicyAttr": $l1_hint, "xegpu::CachePolicyAttr": $l2_hint, "xegpu::CachePolicyAttr": $l3_hint, - "xegpu::LayoutAttr": $layout)> + "xegpu::DistributeLayoutAttr": $layout)> ]; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 4dd10bedc6d84..85c9a966f0fe8 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -901,7 +901,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint, - xegpu::LayoutAttr layout) { + DistributeLayoutAttr layout) { auto loc = source.getLoc(); int64_t size = static_cast(offsets.size()); auto type = VectorType::get(size, builder.getIndexType()); @@ -985,7 +985,7 @@ void StoreScatterOp::build( OpBuilder &builder, OperationState &state, Value value, Value dest, ArrayRef offsets, Value mask, IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) { + xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) { auto loc = dest.getLoc(); int64_t size = static_cast(offsets.size()); auto type = VectorType::get(size, builder.getIndexType()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index c3bf9606693a8..330553564f81a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -678,7 +678,7 @@ struct UnrollLoadGatherOpWithOffset pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter); } - auto layout = dyn_cast_if_present(op.getLayoutAttr()); + auto layout = op.getLayoutAttr(); if (layout) layout = layout.dropInstData(); @@ -778,7 +778,7 @@ struct UnrollStoreScatterOpWithOffsets SmallVector convertedValues = pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); - auto layout = dyn_cast_if_present(op.getLayoutAttr()); + auto layout = op.getLayoutAttr(); if (layout) layout = layout.dropInstData(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 0a9ef0aa6df96..33d4b0457e5d3 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -889,8 +889,8 @@ struct WgToSgLoadGatherOpWithOffset return failure(); ArrayRef wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = dyn_cast_if_present( - xegpu::getDistributeLayoutAttr(op.getResult())); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -913,10 +913,12 @@ struct WgToSgLoadGatherOpWithOffset VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); for (auto [offsets, mask] : llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { + auto newLayout = layout.dropSgLayoutAndData(); auto newLoadOp = xegpu::LoadGatherOp::create( rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), - layout.dropSgLayoutAndData()); + newLayout); + xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout); newLoadOps.push_back(newLoadOp); } rewriter.replaceOpWithMultiple(op, {newLoadOps}); @@ -941,8 +943,8 @@ struct WgToSgStoreScatterOpWithOffset if (!valueType) return failure(); - xegpu::LayoutAttr layout = dyn_cast_if_present( - xegpu::getDistributeLayoutAttr(op.getOperand(0))); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getOperand(0)); if (!layout || !layout.isForWorkgroup()) return failure(); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 4fbb566cfbe73..5dde84e8e0bc2 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -547,4 +547,21 @@ gpu.module @test_distribution { %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout} : index to vector<4x1x1xindex> gpu.return } + + // CHECK-LABEL: distribute_load_slice_attr + gpu.func @distribute_load_slice_attr() { + %2 = memref.alloca() {alignment = 1024} : memref<4096xf32> + %offset = arith.constant {layout_result_0 = #xegpu.layout } dense<0> : vector<256xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout } dense<1> : vector<256xi1> + + // CHECK: %[[LOAD:.*]] = xegpu.load {{.*}} <{chunk_size = 1 : i64, layout = #xegpu.slice<#xegpu.layout, dims = [0]>}> + // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} : + // CHECK-SAME: memref<4096xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32> + %3 = xegpu.load %2[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]> } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32> + + // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[LOAD]] {layout_result_0 = #xegpu.layout} : vector<32xf32> to vector<32x32xf32> + %4 = vector.broadcast %3 {layout_result_0 = + #xegpu.layout} : vector<256xf32> to vector<256x256xf32> + gpu.return + } }