Skip to content

Conversation

@dchigarev
Copy link
Contributor

@dchigarev dchigarev commented Nov 13, 2025

The PR changes the layout attribute type for xegpu::LoadGatherOp/StoreScatterOp from LayoutAttr to DistributeLayoutAttr to also support xegpu.slice layouts.

Initially we wanted to restrict slice layouts from the attribute, but now it turns out there are actually valid use cases for that:

gpu.func @distribute_load_slice_attr() {
  %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
  %offset =  arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<0> : vector<256xindex>
  %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<1> : vector<256xi1>

  %3 = xegpu.load %2[%offset], %mask <{chunk_size = 1, layout = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]>>} {
      layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]> 
  } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32>

  %4 = vector.broadcast %3 {layout_result_0 =
      #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
  gpu.return
}

…r/scatter ops

Signed-off-by: dchigarev <dmitry.chigarev@intel.com>
@dchigarev dchigarev marked this pull request as ready for review November 13, 2025 09:28
@dchigarev
Copy link
Contributor Author

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Dmitry Chigarev (dchigarev)

Changes

The PR changes the layout attribute type for xegpu::LoadGatherOp/StoreScatterOp from LayoutAttr to DistributeLayoutAttr to also support xegpu.slice layouts.


Full diff: https://github.com/llvm/llvm-project/pull/167850.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+4-4)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+2-2)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+7-5)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+17)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 689ebd0d1179a..4c67856b559b1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -844,7 +844,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
-      OptionalAttr<XeGPU_LayoutAttr>:$layout);
+      OptionalAttr<DistributeLayoutAttr>:$layout);
   let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -903,7 +903,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint,
-                    "xegpu::LayoutAttr": $layout)>
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
@@ -988,7 +988,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
-      OptionalAttr<XeGPU_LayoutAttr>:$layout);
+      OptionalAttr<DistributeLayoutAttr>:$layout);
 
   let extraClassDeclaration = extraBaseClassDeclaration#[{
     Type getDestType() {
@@ -1046,7 +1046,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint,
-                    "xegpu::LayoutAttr": $layout)>
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 4dd10bedc6d84..85c9a966f0fe8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -901,7 +901,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                          IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                          xegpu::CachePolicyAttr l2_hint,
                          xegpu::CachePolicyAttr l3_hint,
-                         xegpu::LayoutAttr layout) {
+                         DistributeLayoutAttr layout) {
   auto loc = source.getLoc();
   int64_t size = static_cast<int64_t>(offsets.size());
   auto type = VectorType::get(size, builder.getIndexType());
@@ -985,7 +985,7 @@ void StoreScatterOp::build(
     OpBuilder &builder, OperationState &state, Value value, Value dest,
     ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
     xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
-    xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) {
+    xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
   auto loc = dest.getLoc();
   int64_t size = static_cast<int64_t>(offsets.size());
   auto type = VectorType::get(size, builder.getIndexType());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c3bf9606693a8..330553564f81a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -678,7 +678,7 @@ struct UnrollLoadGatherOpWithOffset
           pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
     }
 
-    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+    auto layout = op.getLayoutAttr();
     if (layout)
       layout = layout.dropInstData();
 
@@ -778,7 +778,7 @@ struct UnrollStoreScatterOpWithOffsets
     SmallVector<Value> convertedValues =
         pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
 
-    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+    auto layout = op.getLayoutAttr();
     if (layout)
       layout = layout.dropInstData();
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0a9ef0aa6df96..33d4b0457e5d3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -889,8 +889,8 @@ struct WgToSgLoadGatherOpWithOffset
       return failure();
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
-    xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
-        xegpu::getDistributeLayoutAttr(op.getResult()));
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -913,10 +913,12 @@ struct WgToSgLoadGatherOpWithOffset
     VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
     for (auto [offsets, mask] :
          llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+      auto newLayout = layout.dropSgLayoutAndData();
       auto newLoadOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
           op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
-          layout.dropSgLayoutAndData());
+          newLayout);
+      xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout);
       newLoadOps.push_back(newLoadOp);
     }
     rewriter.replaceOpWithMultiple(op, {newLoadOps});
@@ -941,8 +943,8 @@ struct WgToSgStoreScatterOpWithOffset
     if (!valueType)
       return failure();
 
-    xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
-        xegpu::getDistributeLayoutAttr(op.getOperand(0)));
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getOperand(0));
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 4fbb566cfbe73..5dde84e8e0bc2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -547,4 +547,21 @@ gpu.module @test_distribution {
     %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: distribute_load_slice_attr
+  gpu.func @distribute_load_slice_attr() {
+    %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
+    %offset =  arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<1> : vector<256xi1>
+
+    // CHECK: %[[LOAD:.*]] = xegpu.load {{.*}} <{chunk_size = 1 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>}>
+    // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>} :
+    // CHECK-SAME: memref<4096xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+    %3 = xegpu.load %2[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]> } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32>
+
+    // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[LOAD]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<32xf32> to vector<32x32xf32>
+    %4 = vector.broadcast %3 {layout_result_0 =
+        #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
+    gpu.return
+  }
 }

@dchigarev dchigarev changed the title [XeGPU] Use DistributeLayoutAttr instead of LayoutAttr for load gather/scatter ops [mlir][XeGPU] Use DistributeLayoutAttr instead of LayoutAttr for load gather/scatter ops Nov 13, 2025
Copy link
Contributor

@nbpatel nbpatel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@charithaintc charithaintc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. thanks for the change.

@charithaintc charithaintc merged commit cd5d5b3 into llvm:main Nov 17, 2025
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants