Skip to content

Conversation

@nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Nov 18, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Nov 18, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/168618.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+8)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+25-40)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 3f27d690f949b..c464c156e1fad 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -635,6 +635,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
       parent = parent.dropSgLayoutAndData();
+      if (!parent)
+        return nullptr;
+      if(!parent.getInstData() && !parent.getLaneLayout())
+        return nullptr;
       return SliceAttr::get(getContext(), parent, attr.getDims());
     }
 
@@ -642,6 +646,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
       parent = parent.dropInstData();
+      if (!parent)
+        return nullptr;
+      if (!parent.getSgLayout() && !parent.getLaneLayout())
+        return nullptr;
       return SliceAttr::get(getContext(), parent, attr.getDims());
     }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 33d4b0457e5d3..8d4e1d11e2873 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -489,10 +489,8 @@ struct WgToSgVectorBroadcastOp
     for (auto operand : adaptor.getOperands().front()) {
       auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                       newResultType, operand);
-      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-          !layout.getEffectiveInstDataAsInt().empty())
-        xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
-                                       layout.dropSgLayoutAndData());
+      xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+                                     layout.dropSgLayoutAndData());
 
       newBroadcastOps.push_back(newBroadcast.getResult());
     }
@@ -738,12 +736,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     Location loc = op.getLoc();
     auto eltType = vecType.getElementType();
 
-    auto setLayoutIfNeeded = [&](Value val) {
-      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-          !layout.getEffectiveInstDataAsInt().empty()) {
-        xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
-                                       layout.dropSgLayoutAndData());
-      }
+    auto setLayout = [&](Value val) {
+      xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+                                     layout.dropSgLayoutAndData());
     };
 
     if (vecAttr.isSplat()) {
@@ -751,14 +746,14 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       Attribute singleVal = vecAttr.getSplatValue<Attribute>();
       auto sgAttr = DenseElementsAttr::get(newType, singleVal);
       auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
-      setLayoutIfNeeded(cstOp->getResult(0));
+      setLayout(cstOp->getResult(0));
       rewriter.replaceOp(op, cstOp);
       return success();
     } else if (sgShape == wgShape) { // if the entire vector is shared by all
                                      // subgroups, don't distribute
       auto newConstOp =
           arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
-      setLayoutIfNeeded(newConstOp->getResult(0));
+      setLayout(newConstOp->getResult(0));
       rewriter.replaceOp(op, newConstOp);
       return success();
     } else {
@@ -860,9 +855,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
             rewriter, loc, baseConstVec.getType(), mulOffset);
         auto finalConst =
             arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
-        setLayoutIfNeeded(baseConstVec);
-        setLayoutIfNeeded(bcastOffset);
-        setLayoutIfNeeded(finalConst);
+        setLayout(baseConstVec);
+        setLayout(bcastOffset);
+        setLayout(finalConst);
         newConstOps.push_back(finalConst);
       }
       rewriter.replaceOpWithMultiple(op, {newConstOps});
@@ -969,14 +964,11 @@ struct WgToSgStoreScatterOpWithOffset
           op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
           layout.dropSgLayoutAndData());
       // Update the layout attribute to drop sg_layout and sg_data.
-      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-          !layout.getEffectiveInstDataAsInt().empty()) {
-        for (OpOperand &operand : store->getOpOperands()) {
-          // Skip for operand one (memref)
-          if (operand.getOperandNumber() == 1)
-            continue;
-          xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
-        }
+      for (OpOperand &operand : store->getOpOperands()) {
+        // Skip for operand one (memref)
+        if (operand.getOperandNumber() == 1)
+          continue;
+        xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
       }
     }
     rewriter.eraseOp(op);
@@ -1069,15 +1061,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
           vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
       auto finalSteps =
           arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
-      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-          !layout.getEffectiveInstDataAsInt().empty()) {
-        xegpu::setDistributeLayoutAttr(steps->getResult(0),
-                                       layout.dropSgLayoutAndData());
-        xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
-                                       layout.dropSgLayoutAndData());
-        xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
-                                       layout.dropSgLayoutAndData());
-      }
+      xegpu::setDistributeLayoutAttr(steps->getResult(0),
+                                     layout.dropSgLayoutAndData());
+      xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
+                                     layout.dropSgLayoutAndData());
+      xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
+                                     layout.dropSgLayoutAndData());
       newOps.push_back(finalSteps);
     }
 
@@ -1145,10 +1134,8 @@ struct WgToSgVectorShapeCastOp
     for (auto src : adaptor.getSource()) {
       auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
                                                       newResultType, src);
-      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-          !layout.getEffectiveInstDataAsInt().empty())
-        xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
-                                       layout.dropSgLayoutAndData());
+      xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+                                     layout.dropSgLayoutAndData());
       newShapeCastOps.push_back(newShapeCast.getResult());
     }
 
@@ -1209,10 +1196,8 @@ struct WgToSgMultiDimReductionOp
       auto newOp = vector::MultiDimReductionOp::create(
           rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
           adaptor.getAcc()[0], op.getReductionDims());
-      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-          !layout.getEffectiveInstDataAsInt().empty())
-        xegpu::setDistributeLayoutAttr(newOp->getResult(0),
-                                       layout.dropSgLayoutAndData());
+      xegpu::setDistributeLayoutAttr(newOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
       newReductions.push_back(newOp.getResult());
     }
 

@github-actions
Copy link

🐧 Linux x64 Test Results

  • 7083 tests passed
  • 594 tests skipped

@nbpatel nbpatel merged commit 778e104 into llvm:main Nov 21, 2025
10 of 11 checks passed
aadeshps-mcw pushed a commit to aadeshps-mcw/llvm-project that referenced this pull request Nov 26, 2025
Priyanshu3820 pushed a commit to Priyanshu3820/llvm-project that referenced this pull request Nov 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants