Skip to content

[MLIR][XeGPU] Add distribution pattern for xegpu load & store matrix from sg to wi#183179

Merged
nbpatel merged 6 commits into
llvm:mainfrom
nbpatel:xegpu-load-store-matrix
Mar 5, 2026
Merged

[MLIR][XeGPU] Add distribution pattern for xegpu load & store matrix from sg to wi#183179
nbpatel merged 6 commits into
llvm:mainfrom
nbpatel:xegpu-load-store-matrix

Conversation

@nbpatel
Copy link
Copy Markdown
Contributor

@nbpatel nbpatel commented Feb 24, 2026

This PR adds distribution pattern for xegpu.load_matrix & xegpu.store_matrix ops for the new sg-to-wi pass

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 24, 2026

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

This PR adds distribution pattern for xegpu.load_matrix & xegpu.store_matrix ops for the new sg-to-wi pass


Full diff: https://github.com/llvm/llvm-project/pull/183179.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp (+134-2)
  • (modified) mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir (+64)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 3787fbb44e1b8..9433b56b7d882 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -489,6 +490,137 @@ struct SgToWiMultiDimReduction
   }
 };
 
+/// Helper to compute distributed coordinates for matrix ops.
+/// When not using subgroup_block_io, each workitem computes its own
+/// coordinates based on the layout and lane ID.
+static SmallVector<Value> computeDistributedCoordsForMatrixOp(
+    ConversionPatternRewriter &rewriter, Location loc,
+    xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
+    ValueRange origOffsets) {
+  Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                       /*upperBound=*/mlir::IntegerAttr());
+  auto maybeCoords =
+      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+  if (failed(maybeCoords))
+    return {};
+  assert(maybeCoords.value().size() == 1 &&
+         "Expected one set of distributed offsets");
+  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
+      getAsOpFoldResult(origOffsets));
+  return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
+}
+
+/// This pattern distributes a subgroup-level LoadMatrix op to workitem-level.
+struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
+  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layout = op.getLayoutAttr();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+
+    VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          op, "the matrix op payload must be a vector type");
+
+    auto loc = op.getLoc();
+    auto offsets = op.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(op, "the load op must have offsets");
+
+    FailureOr<VectorType> distPayloadTyOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(distPayloadTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute matrix op payload based on layout.");
+
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, loc, offsets);
+
+    SmallVector<Value> newCoords = offsetsAsValues;
+    if (!op.getSubgroupBlockIoAttr()) {
+      newCoords = computeDistributedCoordsForMatrixOp(
+          rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
+      if (newCoords.empty())
+        return rewriter.notifyMatchFailure(
+            op, "Failed to compute distributed coordinates.");
+    }
+
+    SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
+                                         ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+
+    auto newOp = xegpu::LoadMatrixOp::create(
+        rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
+        ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
+        xegpu::DistributeLayoutAttr{});
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// This pattern distributes a subgroup-level StoreMatrix op to workitem-level.
+struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
+  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layout = op.getLayoutAttr();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+
+    VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          op, "the matrix op payload must be a vector type");
+
+    auto loc = op.getLoc();
+    auto offsets = op.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(op, "the store op must have offsets");
+
+    FailureOr<VectorType> distPayloadTyOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(distPayloadTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute matrix op payload based on layout.");
+
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, loc, offsets);
+
+    SmallVector<Value> newCoords = offsetsAsValues;
+    if (!op.getSubgroupBlockIoAttr()) {
+      newCoords = computeDistributedCoordsForMatrixOp(
+          rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
+      if (newCoords.empty())
+        return rewriter.notifyMatchFailure(
+            op, "Failed to compute distributed coordinates.");
+    }
+
+    SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
+                                         ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+
+    xegpu::StoreMatrixOp::create(
+        rewriter, loc, TypeRange{},
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
+                    distPayloadTyOrFailure.value()),
+        adaptor.getMemDesc(), ValueRange(newCoords), newConstOffsetsAttr,
+        op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// This pattern rewrites a subgroup-level vector.multi_reduction op to a series
 /// of vector.extract_strided_slice, vector.reduction and
 /// vector.insert_strided_slice ops. This is used when the reduction dimension
@@ -730,8 +862,8 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
                SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
-               SgToWiVectorReduction, SgToWiMultiDimReduction>(
-      typeConverter, patterns.getContext());
+               SgToWiVectorReduction, SgToWiMultiDimReduction, SgToWiLoadMatrix,
+               SgToWiStoreMatrix>(typeConverter, patterns.getContext());
 }
 
 void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 1ec0879d4fb47..24b35228d924a 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -318,3 +318,67 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
   gpu.return
 }
 }
+
+// -----
+// load_matrix and store_matrix with coordinate computation (offsets [0,0])
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @load_store_matrix_1
+// CHECK: %[[LANE_ID1:.*]] = gpu.lane_id
+// CHECK: %[[R1:.*]] = arith.remui %[[LANE_ID1]], %{{.*}} : index
+// CHECK: %[[D1:.*]] = arith.divui %[[LANE_ID1]], %{{.*}} : index
+// CHECK: %[[R2:.*]] = arith.remui %[[D1]], %{{.*}} : index
+// CHECK: %[[ROW:.*]] = arith.remui %[[R2]], %{{.*}} : index
+// CHECK: %[[COL:.*]] = arith.remui %[[R1]], %{{.*}} : index
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[ROW]], %[[COL]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: %[[LANE_ID2:.*]] = gpu.lane_id
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+  %c0 = arith.constant 0 : index
+  %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+  xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+  gpu.return
+}
+}
+
+// -----
+// load_matrix and store_matrix with non-zero offsets [0,1]
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @load_store_matrix_2
+// CHECK: %[[LANE_ID1:.*]] = gpu.lane_id
+// CHECK: %[[R1:.*]] = arith.remui %[[LANE_ID1]], %{{.*}} : index
+// CHECK: %[[D1:.*]] = arith.divui %[[LANE_ID1]], %{{.*}} : index
+// CHECK: %[[R2:.*]] = arith.remui %[[D1]], %{{.*}} : index
+// CHECK: %[[MUL:.*]] = arith.muli %[[R2]], %{{.*}} : index
+// CHECK: %[[ROW:.*]] = arith.remui %[[MUL]], %{{.*}} : index
+// CHECK: %[[R3:.*]] = arith.remui %[[R1]], %{{.*}} : index
+// CHECK: %[[ADD:.*]] = arith.addi %[[R3]], %{{.*}} : index
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[ROW]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
+// CHECK: %[[LANE_ID2:.*]] = gpu.lane_id
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+  xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+  gpu.return
+}
+}
+
+// -----
+// load_matrix and store_matrix with subgroup_block_io (no coordinate computation)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @load_store_matrix_3
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index -> vector<1x2xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: vector<1x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index
+gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = xegpu.load_matrix %arg0[%c0, %c1] <{subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> :
+    !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>
+  xegpu.store_matrix %1, %arg0[%c0, %c1] <{subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> :
+    vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index
+  gpu.return
+}
+}

/// Helper to compute distributed coordinates for matrix ops.
/// When not using subgroup_block_io, each workitem computes its own
/// coordinates based on the layout and lane ID.
static SmallVector<Value> computeDistributedCoordsForMatrixOp(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is just a copy & paste of computeDistributedCoordinatesForMatrixOp(), please move it to XeGPULayoutImpl

Copy link
Copy Markdown
Contributor Author

@nbpatel nbpatel Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but the other one (in the old pass) will be removed eventually once we move to this pass

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Jianhui-Li are you ok with this?

Copy link
Copy Markdown
Contributor

@charithaintc charithaintc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Comment thread mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir Outdated
ConversionPatternRewriter &rewriter) const override {
auto layout = op.getLayoutAttr();
// If no layout, nothing to do.
if (!layout)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we could generalize it to a legality check for AnchorLayoutInterface ops and omit this check in patterns.

@nbpatel nbpatel merged commit c5c6588 into llvm:main Mar 5, 2026
10 checks passed
@nbpatel nbpatel deleted the xegpu-load-store-matrix branch March 6, 2026 20:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants