Skip to content

Conversation

nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Sep 30, 2025

The PR adds a change to use the layouts from the operands since store doesn't have a result

@llvmbot
Copy link
Member

llvmbot commented Sep 30, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

The PR adds a change to use the layouts from the operands since store doesn't have a result


Full diff: https://github.com/llvm/llvm-project/pull/161447.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+14-10)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+12-7)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a9296b184..784e5d68ce885 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset
       return failure();
 
     xegpu::DistributeLayoutAttr layout =
-        xegpu::getDistributeLayoutAttr(op.getValue());
+        xegpu::getDistributeLayoutAttr(op.getOperand(0));
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset
     auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
     for (auto [val, offs, mask] : llvm::zip(
              adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
-      xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
-                                    mask, chunkSizeAttr, op.getL1HintAttr(),
-                                    op.getL2HintAttr(), op.getL3HintAttr());
+      auto store = xegpu::StoreScatterOp::create(
+          rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
       // Update the layout attribute to drop sg_layout and sg_data.
-      if (auto newLayout = layout.dropSgLayoutAndData())
-        op->setAttr("layout", newLayout);
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty()) {
+        for (OpOperand &operand : store->getOpOperands()) {
+          // Skip for operand one (memref)
+          if (operand.getOperandNumber() == 1)
+            continue;
+          xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
+        }
+      }
     }
     rewriter.eraseOp(op);
     return success();
@@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
       [=](xegpu::StoreScatterOp op) -> bool {
-        // Check if the layout attribute is present on the result.
-        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
-        if (!layout)
-          return true;
+        auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0));
         return isLegal(layout);
       });
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 03c63861705d9..38392fd10b742 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -282,15 +282,20 @@ gpu.module @test_distribution {
   // CHECK-LABEL: @store_scatter
   // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
   gpu.func @store_scatter(%dest : memref<256xf16>) {
-    // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
-    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
-    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16>
+    // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1>
     // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+    // CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>,
+    // CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>}
     // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
-    %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
-    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
-    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
-    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+    %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1>
+    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, 
+                                             layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+                                             layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+                                             l1_hint = #xegpu.cache_hint<cached>}
       : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
     gpu.return
   }

@nbpatel nbpatel merged commit 68b143d into llvm:main Oct 2, 2025
12 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
The PR adds a change to use the layouts from the operands since store
doesn't have a result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants