Skip to content

[MLIR][XeGPU] Add support for Convert Layout from Wg to Sg#178922

Merged
nbpatel merged 9 commits into
llvm:mainfrom
nbpatel:xegpu-refactor
Feb 13, 2026
Merged

[MLIR][XeGPU] Add support for Convert Layout from Wg to Sg#178922
nbpatel merged 9 commits into
llvm:mainfrom
nbpatel:xegpu-refactor

Conversation

@nbpatel
Copy link
Copy Markdown
Contributor

@nbpatel nbpatel commented Jan 30, 2026

No description provided.

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Jan 30, 2026

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/178922.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+4-6)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+105-25)
  • (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+78)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 2cbec50772b98..fefc8c9903497 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1593,9 +1593,8 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   }];
 
   let description = [{
-    This operation loads a 2D block of data from shared local memory (SLM) as specified
-    by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the
-    subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
+    This operation loads an nD block of data from shared local memory (SLM) as specified
+    by the provided nD `mem_desc`. Memory descriptors of any rank are supported.
 
     This operation serves as an anchor through which users assign a layout attribute
     to govern computation distribution.
@@ -1665,9 +1664,8 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
                           prop-dict attr-dict `` `:` type(operands)}];
   let description = [{
-    This operation stores a 2D `data` fragment into the shared local memory region
-    specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the
-    subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed.
+    This operation stores an nD `data` fragment into the shared local memory region
+    specified by an nD `mem_desc`. Memory descriptors of any rank are supported.
 
     This operation serves as an anchor through which users assign a layout attribute
     to govern computation distribution.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..c7226c7ebbd5d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -186,8 +186,8 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
       return success();
   }
 
-  if (mdescTy.getRank() != 2)
-    return emitError() << "mem_desc must be 2D.";
+  if (mdescTy.getRank() < 2)
+    return emitError() << "mem_desc must be 2D or greater.";
 
   ArrayRef<int64_t> dataShape = dataTy.getShape();
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 45a002b63abd6..8dbca952cc8c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -604,44 +604,124 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 struct WgToSgConvertLayoutOp
     : public OpConversionPattern<xegpu::ConvertLayoutOp> {
   using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
 
-    auto input = op.getInputLayout();
-    auto target = op.getTargetLayout();
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    auto inputLayout = op.getInputLayout();
+    auto targetLayout = op.getTargetLayout();
 
-    if (!input || !target || !input.isForWorkgroup() ||
-        !target.isForWorkgroup())
+    if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
+        !targetLayout.isForWorkgroup())
       return rewriter.notifyMatchFailure(
           op, "Input and target layouts must have subgroup layout");
 
-    SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
-    DenseI32ArrayAttr inputOrder = input.getOrder();
-    SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
-    DenseI32ArrayAttr targetOrder = target.getOrder();
-
-    // TODO: currently we only support for optimal case, where input and
-    // output has the same sg_layout and sg_data, so SLM is not involved.
-    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
-        inputOrder != targetOrder)
+    SmallVector<int64_t> inputSgLayout =
+        inputLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> targetSgLayout =
+        targetLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
+
+    auto hasUnitLeadingDims = [](ArrayRef<int64_t> shape) {
+      if (shape.size() <= 2)
+        return true;
+      for (size_t i = 0; i + 2 < shape.size(); ++i)
+        if (shape[i] != 1)
+          return false;
+      return true;
+    };
+
+    if (wgShape.size() > 2) {
+      if (!hasUnitLeadingDims(inputSgData) || !hasUnitLeadingDims(targetSgData))
+        return rewriter.notifyMatchFailure(
+            op, "rank > 2 requires unit leading dims for sg_data");
+    }
+
+    // Fast path: if sg_layout and sg_data are identical, no SLM needed
+    if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
+      inputLayout = inputLayout.dropSgLayoutAndData();
+      targetLayout = targetLayout.dropSgLayoutAndData();
+
+      SmallVector<Value> newOps(adaptor.getSource());
+      if (inputLayout && targetLayout) {
+        for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
+          auto newOp = xegpu::ConvertLayoutOp::create(
+              rewriter, loc, src.getType(), src, inputLayout, targetLayout);
+          newOps[i] = newOp;
+        }
+      }
+      rewriter.replaceOpWithMultiple(op, {newOps});
+      return success();
+    }
+
+    // SLM path: layouts differ, need cross-subgroup data redistribution
+    Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
+
+    SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
+
+    // Calculate SLM size requirements
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto bytesPerElement = bitWidth / 8;
+    auto slmSize = computeProduct(slmShape) * bytesPerElement;
+
+    // Allocate SLM
+    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
+                                               elemTy, nullptr);
+    auto memDesc =
+        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+                                          rewriter.getIndexType(), nullptr);
+
+    // STORE PHASE: Each subgroup stores in SLM using input layout
+    auto storeCoords = inputLayout.computeDistributedCoords(
+        rewriter, loc, sgId.getResult(), wgShape);
+    if (failed(storeCoords))
       return failure();
 
-    input = input.dropSgLayoutAndData();
-    target = target.dropSgLayoutAndData();
+    // Store to SLM
+    for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
+      SmallVector<OpFoldResult> storeMatrixOffsets;
+      for (Value coord : coords) {
+        storeMatrixOffsets.push_back(coord);
+      }
+      xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
+                                   storeMatrixOffsets, nullptr /*layout*/);
+    }
+
+    gpu::BarrierOp::create(rewriter, loc);
+
+    // LOAD PHASE: Each target subgroup loads from SLM using target layout
+    auto loadCoords = targetLayout.computeDistributedCoords(
+        rewriter, loc, sgId.getResult(), wgShape);
+    if (failed(loadCoords))
+      return failure();
+
+    VectorType loadType = VectorType::get(targetSgData, elemTy);
 
-    SmallVector<Value> newOps(adaptor.getSource());
-    if (input && target) {
-      // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
-      for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
-        auto newOp = xegpu::ConvertLayoutOp::create(
-            rewriter, op.getLoc(), src.getType(), src, input, target);
-        newOps[i] = newOp;
+    // Load vectors from SLM
+    SmallVector<Value> finalResults;
+    for (auto coords : *loadCoords) {
+      SmallVector<OpFoldResult> loadMatrixOffsets;
+      for (Value coord : coords) {
+        loadMatrixOffsets.push_back(coord);
       }
+      auto loadOp = xegpu::LoadMatrixOp::create(
+          rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
+          targetLayout.dropSgLayoutAndData());
+
+      finalResults.push_back(loadOp.getResult());
     }
-    rewriter.replaceOpWithMultiple(op, {newOps});
+
+    rewriter.replaceOpWithMultiple(op, {finalResults});
     return success();
   }
 };
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f2011ab86e9e9..e6376e3ecb4cd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -852,7 +852,7 @@ func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>)
 
 // -----
 func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
-  // expected-error@+1 {{mem_desc must be 2D}}
+  // expected-error@+1 {{mem_desc must be 2D or greater}}
   %data = xegpu.load_matrix %arg0[16]: !xegpu.mem_desc<64xf16> -> vector<16xf16>
   return
 }
@@ -873,7 +873,7 @@ func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, %
 
 // -----
 func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: vector<32xf16>) {
-  // expected-error@+1 {{mem_desc must be 2D.}}
+  // expected-error@+1 {{mem_desc must be 2D or greater}}
   xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 1fc2328d09046..d4b611c713674 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -839,6 +839,84 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: convert_layout_slm
+  // CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32>
+  gpu.func @convert_layout_slm(%arg0: memref<128x256xf32>) {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]],  %[[C16:.*]] : index
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[MUL_X:.*]] = arith.muli %[[SGIDX]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[OFF_Y]], %[[OFF_X]]] : memref<128x256xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>>
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> : !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<32x16xf32>
+    // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<131072xi8, 3>
+    // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<131072xi8, 3> -> !xegpu.mem_desc<128x256xf32>
+    // CHECK-DAG: %[[SGID_STORE:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID_STORE]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_Y_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<32x16xf32>, !xegpu.mem_desc<128x256xf32>, index, index
+    // CHECK-DAG: gpu.barrier
+    // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID_STORE]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_Y_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}>: !xegpu.mem_desc<128x256xf32>, index, index -> vector<16x32xf32>
+    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>>
+    %1 = xegpu.load_nd %0 {layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>} : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>> -> vector<128x256xf32>
+    %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>,
+                                   target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32], inst_data = [16, 16]>}> : vector<128x256xf32>
+    gpu.return
+  }
+
+  gpu.func @convert_layout_3D(%arg0: memref<?xf32>) {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<0> : vector<1x32x16xindex>
+    // CHECK-DAG: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<true> : vector<1x32x16xi1>
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST]]], %[[CST_0]] <{chunk_size = 1 : i64, layout = #xegpu.layout<inst_data = [1, 16, 16]>}> : memref<?xf32>, vector<1x32x16xindex>, vector<1x32x16xi1> -> vector<1x32x16xf32>
+    // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<1048576xi8, 3>
+    // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<1048576xi8, 3> -> !xegpu.mem_desc<8x128x256xf32>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[STORE_Z_TMP:.*]] = arith.divui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[STORE_Z:.*]] = arith.remui %[[STORE_Z_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_Z:.*]] = arith.remui %[[STORE_Z]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Z]], %[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<1x32x16xf32>, !xegpu.mem_desc<8x128x256xf32>, index, index, index
+    // CHECK-DAG: gpu.barrier
+    // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Z_TMP:.*]] = arith.divui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Z:.*]] = arith.remui %[[LOAD_Z_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_Z:.*]] = arith.remui %[[LOAD_Z]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Z]], %[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [1, 16, 16]>}>: !xegpu.mem_desc<8x128x256xf32>, index, index, index -> vector<1x16x32xf32>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<0> : vector<8x128x256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<true> : vector<8x128x256xi1>
+    %1 = xegpu.load %arg0[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} : memref<?xf32>, vector<8x128x256xindex>, vector<8x128x256xi1> -> vector<8x128x256xf32>
+    %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>,
+                                   target_layout = #xegpu.layout<sg_layout = [8, 8, 8], sg_data = [1, 16, 32], inst_data = [1, 16, 16]>}> : vector<8x128x256xf32>
+    gpu.return
+  }
+
   // CHECK-LABEL: distribute_nested_slice
   // CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
   // CHECK: %[[V1:.*]] = vector.broadcast %[[V0]] : vector<32x1x32x1xf32> to vector<32x16x32x16xf32>

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Jan 30, 2026

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/178922.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+4-6)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+105-25)
  • (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+78)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 2cbec50772b98..fefc8c9903497 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1593,9 +1593,8 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   }];
 
   let description = [{
-    This operation loads a 2D block of data from shared local memory (SLM) as specified
-    by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the
-    subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
+    This operation loads an nD block of data from shared local memory (SLM) as specified
+    by the provided nD `mem_desc`. Memory descriptors of any rank are supported.
 
     This operation serves as an anchor through which users assign a layout attribute
     to govern computation distribution.
@@ -1665,9 +1664,8 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
                           prop-dict attr-dict `` `:` type(operands)}];
   let description = [{
-    This operation stores a 2D `data` fragment into the shared local memory region
-    specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the
-    subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed.
+    This operation stores an nD `data` fragment into the shared local memory region
+    specified by an nD `mem_desc`. Memory descriptors of any rank are supported.
 
     This operation serves as an anchor through which users assign a layout attribute
     to govern computation distribution.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..c7226c7ebbd5d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -186,8 +186,8 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
       return success();
   }
 
-  if (mdescTy.getRank() != 2)
-    return emitError() << "mem_desc must be 2D.";
+  if (mdescTy.getRank() < 2)
+    return emitError() << "mem_desc must be 2D or greater.";
 
   ArrayRef<int64_t> dataShape = dataTy.getShape();
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 45a002b63abd6..8dbca952cc8c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -604,44 +604,124 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 struct WgToSgConvertLayoutOp
     : public OpConversionPattern<xegpu::ConvertLayoutOp> {
   using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
 
-    auto input = op.getInputLayout();
-    auto target = op.getTargetLayout();
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    auto inputLayout = op.getInputLayout();
+    auto targetLayout = op.getTargetLayout();
 
-    if (!input || !target || !input.isForWorkgroup() ||
-        !target.isForWorkgroup())
+    if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
+        !targetLayout.isForWorkgroup())
       return rewriter.notifyMatchFailure(
           op, "Input and target layouts must have subgroup layout");
 
-    SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
-    DenseI32ArrayAttr inputOrder = input.getOrder();
-    SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
-    DenseI32ArrayAttr targetOrder = target.getOrder();
-
-    // TODO: currently we only support for optimal case, where input and
-    // output has the same sg_layout and sg_data, so SLM is not involved.
-    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
-        inputOrder != targetOrder)
+    SmallVector<int64_t> inputSgLayout =
+        inputLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> targetSgLayout =
+        targetLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
+
+    auto hasUnitLeadingDims = [](ArrayRef<int64_t> shape) {
+      if (shape.size() <= 2)
+        return true;
+      for (size_t i = 0; i + 2 < shape.size(); ++i)
+        if (shape[i] != 1)
+          return false;
+      return true;
+    };
+
+    if (wgShape.size() > 2) {
+      if (!hasUnitLeadingDims(inputSgData) || !hasUnitLeadingDims(targetSgData))
+        return rewriter.notifyMatchFailure(
+            op, "rank > 2 requires unit leading dims for sg_data");
+    }
+
+    // Fast path: if sg_layout and sg_data are identical, no SLM needed
+    if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
+      inputLayout = inputLayout.dropSgLayoutAndData();
+      targetLayout = targetLayout.dropSgLayoutAndData();
+
+      SmallVector<Value> newOps(adaptor.getSource());
+      if (inputLayout && targetLayout) {
+        for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
+          auto newOp = xegpu::ConvertLayoutOp::create(
+              rewriter, loc, src.getType(), src, inputLayout, targetLayout);
+          newOps[i] = newOp;
+        }
+      }
+      rewriter.replaceOpWithMultiple(op, {newOps});
+      return success();
+    }
+
+    // SLM path: layouts differ, need cross-subgroup data redistribution
+    Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
+
+    SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
+
+    // Calculate SLM size requirements
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto bytesPerElement = bitWidth / 8;
+    auto slmSize = computeProduct(slmShape) * bytesPerElement;
+
+    // Allocate SLM
+    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
+                                               elemTy, nullptr);
+    auto memDesc =
+        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+                                          rewriter.getIndexType(), nullptr);
+
+    // STORE PHASE: Each subgroup stores in SLM using input layout
+    auto storeCoords = inputLayout.computeDistributedCoords(
+        rewriter, loc, sgId.getResult(), wgShape);
+    if (failed(storeCoords))
       return failure();
 
-    input = input.dropSgLayoutAndData();
-    target = target.dropSgLayoutAndData();
+    // Store to SLM
+    for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
+      SmallVector<OpFoldResult> storeMatrixOffsets;
+      for (Value coord : coords) {
+        storeMatrixOffsets.push_back(coord);
+      }
+      xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
+                                   storeMatrixOffsets, nullptr /*layout*/);
+    }
+
+    gpu::BarrierOp::create(rewriter, loc);
+
+    // LOAD PHASE: Each target subgroup loads from SLM using target layout
+    auto loadCoords = targetLayout.computeDistributedCoords(
+        rewriter, loc, sgId.getResult(), wgShape);
+    if (failed(loadCoords))
+      return failure();
+
+    VectorType loadType = VectorType::get(targetSgData, elemTy);
 
-    SmallVector<Value> newOps(adaptor.getSource());
-    if (input && target) {
-      // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
-      for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
-        auto newOp = xegpu::ConvertLayoutOp::create(
-            rewriter, op.getLoc(), src.getType(), src, input, target);
-        newOps[i] = newOp;
+    // Load vectors from SLM
+    SmallVector<Value> finalResults;
+    for (auto coords : *loadCoords) {
+      SmallVector<OpFoldResult> loadMatrixOffsets;
+      for (Value coord : coords) {
+        loadMatrixOffsets.push_back(coord);
       }
+      auto loadOp = xegpu::LoadMatrixOp::create(
+          rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
+          targetLayout.dropSgLayoutAndData());
+
+      finalResults.push_back(loadOp.getResult());
     }
-    rewriter.replaceOpWithMultiple(op, {newOps});
+
+    rewriter.replaceOpWithMultiple(op, {finalResults});
     return success();
   }
 };
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f2011ab86e9e9..e6376e3ecb4cd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -852,7 +852,7 @@ func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>)
 
 // -----
 func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
-  // expected-error@+1 {{mem_desc must be 2D}}
+  // expected-error@+1 {{mem_desc must be 2D or greater}}
   %data = xegpu.load_matrix %arg0[16]: !xegpu.mem_desc<64xf16> -> vector<16xf16>
   return
 }
@@ -873,7 +873,7 @@ func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, %
 
 // -----
 func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: vector<32xf16>) {
-  // expected-error@+1 {{mem_desc must be 2D.}}
+  // expected-error@+1 {{mem_desc must be 2D or greater}}
   xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 1fc2328d09046..d4b611c713674 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -839,6 +839,84 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: convert_layout_slm
+  // CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32>
+  gpu.func @convert_layout_slm(%arg0: memref<128x256xf32>) {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]],  %[[C16:.*]] : index
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[MUL_X:.*]] = arith.muli %[[SGIDX]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[OFF_Y]], %[[OFF_X]]] : memref<128x256xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>>
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> : !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<32x16xf32>
+    // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<131072xi8, 3>
+    // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<131072xi8, 3> -> !xegpu.mem_desc<128x256xf32>
+    // CHECK-DAG: %[[SGID_STORE:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID_STORE]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_Y_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<32x16xf32>, !xegpu.mem_desc<128x256xf32>, index, index
+    // CHECK-DAG: gpu.barrier
+    // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID_STORE]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_Y_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}>: !xegpu.mem_desc<128x256xf32>, index, index -> vector<16x32xf32>
+    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>>
+    %1 = xegpu.load_nd %0 {layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>} : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>> -> vector<128x256xf32>
+    %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>,
+                                   target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32], inst_data = [16, 16]>}> : vector<128x256xf32>
+    gpu.return
+  }
+
+  gpu.func @convert_layout_3D(%arg0: memref<?xf32>) {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<0> : vector<1x32x16xindex>
+    // CHECK-DAG: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<true> : vector<1x32x16xi1>
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST]]], %[[CST_0]] <{chunk_size = 1 : i64, layout = #xegpu.layout<inst_data = [1, 16, 16]>}> : memref<?xf32>, vector<1x32x16xindex>, vector<1x32x16xi1> -> vector<1x32x16xf32>
+    // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<1048576xi8, 3>
+    // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<1048576xi8, 3> -> !xegpu.mem_desc<8x128x256xf32>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[STORE_Z_TMP:.*]] = arith.divui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[STORE_Z:.*]] = arith.remui %[[STORE_Z_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_Z:.*]] = arith.remui %[[STORE_Z]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Z]], %[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<1x32x16xf32>, !xegpu.mem_desc<8x128x256xf32>, index, index, index
+    // CHECK-DAG: gpu.barrier
+    // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Z_TMP:.*]] = arith.divui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_Z:.*]] = arith.remui %[[LOAD_Z_TMP]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_Z:.*]] = arith.remui %[[LOAD_Z]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+    // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Z]], %[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [1, 16, 16]>}>: !xegpu.mem_desc<8x128x256xf32>, index, index, index -> vector<1x16x32xf32>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<0> : vector<8x128x256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<true> : vector<8x128x256xi1>
+    %1 = xegpu.load %arg0[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} : memref<?xf32>, vector<8x128x256xindex>, vector<8x128x256xi1> -> vector<8x128x256xf32>
+    %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>,
+                                   target_layout = #xegpu.layout<sg_layout = [8, 8, 8], sg_data = [1, 16, 32], inst_data = [1, 16, 16]>}> : vector<8x128x256xf32>
+    gpu.return
+  }
+
   // CHECK-LABEL: distribute_nested_slice
   // CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
   // CHECK: %[[V1:.*]] = vector.broadcast %[[V0]] : vector<32x1x32x1xf32> to vector<32x16x32x16xf32>

}

// Fast path: if sg_layout and sg_data are identical, no SLM needed
if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use isEqualTo()
or consider adding isCompatible() to xegpuattr.td and use it here.

targetLayout.getEffectiveSgLayoutAsInt();
SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();

auto hasUnitLeadingDims = [](ArrayRef<int64_t> shape) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this check is necessary. We should enhance the inst_data layout and subgroup distribution to get it supported.

Copy link
Copy Markdown
Contributor

@charithaintc charithaintc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % comments


// Allocate SLM
auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we consider moving this (alloc) to begin of the function?

alloc'ing inside control-flow structures is less ideal.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure I understand the question..where is the cf here?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean allocating memory inside loops would be not great.

}
rewriter.replaceOpWithMultiple(op, {newOps});

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we insert another barrier here?

if not a thread may store to the same slm before consumer is loading it back.

@nbpatel nbpatel temporarily deployed to main-branch-only February 12, 2026 17:25 — with GitHub Actions Inactive

// Fast path: if sg_layout and sg_data are identical, no SLM needed
if (llvm::equal(inputSgLayout, targetSgLayout) &&
llvm::equal(inputSgData, targetSgData)) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant isEqualTo() in XeGPUAttrs.td

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but won't that check for inst_data and other fields as well?

@nbpatel nbpatel temporarily deployed to main-branch-only February 12, 2026 22:17 — with GitHub Actions Inactive
@nbpatel nbpatel temporarily deployed to main-branch-only February 13, 2026 01:25 — with GitHub Actions Inactive
@nbpatel nbpatel temporarily deployed to main-branch-only February 13, 2026 16:19 — with GitHub Actions Inactive
@nbpatel nbpatel merged commit d521863 into llvm:main Feb 13, 2026
10 checks passed
@nbpatel nbpatel deleted the xegpu-refactor branch April 15, 2026 23:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants