
Conversation

nbpatel
Contributor

@nbpatel nbpatel commented Sep 17, 2025

This PR adds unrolling/blocking patterns for load_gather and store_scatter ops with offsets.
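
For reference, the transformation these patterns perform can be pictured with the load_with_offsets test added in this patch. A minimal before/after sketch, assuming an inst_data = [16] blocking and using hypothetical names %offsets/%mask for the original 32-lane operands and %off0/%off1, %m0/%m1 for their unrolled 16-lane slices:

    // Before: one 32-lane gather carrying an inst_data = [16] layout.
    %ld = xegpu.load %src[%offsets], %mask
          {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>}
          : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>

    // After: two 16-lane gathers on the unrolled offset/mask slices; their
    // results are reassembled into the original vector<32xf32>.
    %ld0 = xegpu.load %src[%off0], %m0 <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
           : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
    %ld1 = xegpu.load %src[%off1], %m1 <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
           : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>

store_scatter with offsets is unrolled the same way, with the stored value split alongside the offsets and mask.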

@nbpatel nbpatel marked this pull request as ready for review September 18, 2025 22:43
@llvmbot
Member

llvmbot commented Sep 18, 2025

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR adds unrolling/blocking patterns for load_gather and store_scatter ops with offsets.


Patch is 20.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159453.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+191-1)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir (+91)
  • (modified) mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp (+58-40)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 29c9fcdfebcdb..d7585fa5df8b3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -537,6 +537,195 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
   }
 };
 
+/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
+/// load).
+/// It unrolls the offsets and mask operands accordingly, and creates multiple
+/// LoadGatherOp with the unrolled operands.
+struct UnrollLoadGatherOpWithOffset
+    : public UnrollPattern<xegpu::LoadGatherOp> {
+  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
+    Value offsets = op.getOffsets();
+    Value mask = op.getMask();
+
+    // Only handle the case where offsets are present (scattered load)
+    if (!offsets)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    int64_t chunkSize = 1;
+    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
+      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
+        chunkSize = intAttr.getInt();
+    }
+
+    // Unroll mask and offsets with correct shape
+    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
+    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
+    Type elemTy = valueTy.getElementType();
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsets;
+
+    if (chunkSize > 1) {
+      // For chunked loads, mask and offsets have one less dimension
+      targetMaskShape.pop_back();
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = chunkSize / blockedChunkSize;
+      chunkSize = blockedChunkSize;
+
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
+
+      SmallVector<Value> convertedMasksBase =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+      SmallVector<Value> convertedOffsetsBase =
+          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
+
+      for (auto maskVal : convertedMasksBase)
+        convertedMasks.append(numNewChunks, maskVal);
+
+      for (auto [baseOffset, offsetType] :
+           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
+                                                     i * blockedChunkSize);
+          Value incVec =
+              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
+          Value offsetVal =
+              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
+          convertedOffsets.push_back(offsetVal);
+        }
+      }
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
+      convertedOffsets =
+          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
+    SmallVector<Value> newOps;
+    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
+      auto newOp = xegpu::LoadGatherOp::create(
+          rewriter, loc, newValueTy, op.getSource(), o, m,
+          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+/// This pattern handles the unrolling of StoreScatterOp with offsets (scattered
+/// store).
+/// It unrolls the offsets and mask operands accordingly, and creates multiple
+/// StoreScatterOp with the unrolled operands.
+struct UnrollStoreScatterOpWithOffsets
+    : public UnrollPattern<xegpu::StoreScatterOp> {
+  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    Value offsets = op.getOffsets();
+    Value mask = op.getMask();
+
+    // Only handle the case where offsets are present (scattered store)
+    if (!offsets)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    int64_t chunkSize = 1;
+    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
+      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
+        chunkSize = intAttr.getInt();
+    }
+
+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
+    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
+
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsets;
+
+    if (chunkSize > 1) {
+      targetMaskShape.pop_back();
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = chunkSize / blockedChunkSize;
+      chunkSize = blockedChunkSize;
+
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
+
+      SmallVector<Value> convertedMasksBase =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+      SmallVector<Value> convertedOffsetsBase =
+          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
+
+      for (auto maskVal : convertedMasksBase)
+        convertedMasks.append(numNewChunks, maskVal);
+
+      for (auto [baseOffset, offsetType] :
+           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
+                                                     i * blockedChunkSize);
+          Value incVec =
+              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
+          Value offsetVal =
+              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
+          convertedOffsets.push_back(offsetVal);
+        }
+      }
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
+      convertedOffsets =
+          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+
+    for (auto [v, o, m] :
+         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
+      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
+                                    rewriter.getI64IntegerAttr(chunkSize),
+                                    op.getL1HintAttr(), op.getL2HintAttr(),
+                                    op.getL3HintAttr());
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
   using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
@@ -766,6 +955,7 @@ void mlir::xegpu::populateXeGPUUnrollPatterns(
       .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
            UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
            UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
-           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp>(
+           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
+           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
           patterns.getContext(), options);
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 6999da5d222fe..1392ded322d0b 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -210,6 +210,27 @@ gpu.module @test {
     gpu.return %ld : vector<32xf32>
   }
 
+//-----
+
+
+  // CHECK-LABEL: load_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-COUNT-2: xegpu.load  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  gpu.func @load_with_offsets(%src: memref<64xf32>) -> vector<32xf32> {
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+      ]> : vector<32xindex>
+
+      %c17 = arith.constant 17: index
+      %mask = vector.create_mask %c17: vector<32xi1>
+      %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+
+      gpu.return %ld : vector<32xf32>
+  }
+
 //-----
 
   // CHECK-LABEL: prefetch
@@ -254,6 +275,28 @@ gpu.module @test {
 
     gpu.return
   }
+  
+  //-----
+
+  // CHECK-LABEL: store_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-COUNT-2: xegpu.store  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, memref<64xf32>, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets(%src: memref<64xf32>) {
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+      ]> : vector<32xindex>
+
+      %c17 = arith.constant 17: index
+      %mask = vector.create_mask %c17: vector<32xi1>
+
+      %st_vec = arith.constant dense<1023.0>: vector<32xf32>
+      xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, memref<64xf32>, vector<32xindex>, vector<32xi1>
+
+      gpu.return
+  }
 
 //-----
   // CHECK-LABEL: create_tdesc_step_chunk
@@ -319,6 +362,29 @@ gpu.module @test {
     gpu.return %ld : vector<32x4xf32>
    }
 
+//-----
+  // CHECK-LABEL: load_with_offsets_chunk
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
+  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-COUNT-4: xegpu.load  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
+   gpu.func @load_with_offsets_chunk(%src: memref<64xf32>) -> vector<32x4xf32> {
+    %cst = arith.constant dense<[
+        0,   8,  16,  24,  32,  40,  48,  56,
+        64,  72,  80,  88,  96, 104, 112, 120,
+        128, 136, 144, 152, 160, 168, 176, 184,
+        192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
+    gpu.return %ld : vector<32x4xf32>
+   }
+
 //-----
   // CHECK-LABEL: store_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
@@ -342,6 +408,31 @@ gpu.module @test {
     gpu.return
   }
 
+//-----
+  // CHECK-LABEL: store_with_offsets_chunk
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32
+  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-COUNT-4: xegpu.store  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, memref<64xf32>, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets_chunk(%src: memref<64xf32>) {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.>: vector<32x4xf32>
+    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, memref<64xf32>, vector<32xindex>, vector<32xi1>
+    gpu.return
+  }
+
 //-----
   // CHECK-LABEL: prefetch_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index e1ba45c60ac36..c83faea2e622c 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -50,49 +50,67 @@ struct TestXeGPUUnrollingPatterns
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     xegpu::UnrollOptions options;
-    options.setNativeShapeFn(
-        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
-          if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
-                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
-                  xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
-                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
-            xegpu::TensorDescType tdescTy;
-            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
-              tdescTy = createNdOp.getType();
-            } else if (auto updateNdOp =
-                           dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
-              tdescTy = updateNdOp.getTensorDescType();
-            } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
-              tdescTy = prefetchNdOp.getTensorDescType();
-            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
-              tdescTy = loadNdOp.getTensorDescType();
-            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
-              tdescTy = storeNdOp.getTensorDescType();
-            } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
-              tdescTy = createOp.getType();
-            } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
-              tdescTy = updateOp.getTensorDescType();
-            } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
-              tdescTy = prefetchOp.getTensorDescType();
-            } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
-              tdescTy = loadOp.getTensorDescType();
-            } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
-              tdescTy = storeOp.getTensorDescType();
+    options.setNativeShapeFn([&](Operation *op)
+                                 -> std::optional<SmallVector<int64_t>> {
+      if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
+              xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+              xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+              xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+        xegpu::TensorDescType tdescTy;
+        if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+          tdescTy = createNdOp.getType();
+        } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+          tdescTy = updateNdOp.getTensorDescType();
+        } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
+          tdescTy = prefetchNdOp.getTensorDescType();
+        } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+          tdescTy = loadNdOp.getTensorDescType();
+        } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+          tdescTy = storeNdOp.getTensorDescType();
+        } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+          tdescTy = createOp.getType();
+        } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+          tdescTy = updateOp.getTensorDescType();
+        } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+          tdescTy = prefetchOp.getTensorDescType();
+        } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+          if (loadOp.getOffsets()) {
+            auto layout = xegpu::getDistributeLayoutAttr(loadOp.getResult());
+            if (layout && layout.isForSubgroup()) {
+              auto inst_data = layout.getEffectiveInstDataAsInt();
+              if (!inst_data.empty())
+                return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
             }
-
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              auto inst_data = layout.getInstData();
-              if (inst_data && layout.isForSubgroup())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
+            return std::nullopt;
+          }
+          tdescTy = loadOp.getTensorDescType();
+        } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+          if (storeOp.getOffsets()) {
+            auto layout = llvm::dyn_cast_or_null<xegpu::LayoutAttr>(
+                op->getAttr("layout"));
+            if (layout && layout.isForSubgroup()) {
+              auto inst_data = layout.getEffectiveInstDataAsInt();
+              if (!inst_data.empty())
+                return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
             }
+            return std::nullopt;
           }
-
-          if (isa<xegpu::DpasOp>(op))
-            return SmallVector<int64_t>{8, 16, 16};
-
-          return std::nullopt;
-        });
+          tdescTy = storeOp.getTensorDescType();
+        }
+
+        if (auto layout = tdescTy.getLayoutAttr()) {
+          auto inst_data = layout.getInstData();
+          if (inst_data && layout.isForSubgroup())
+            return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                        inst_...
[truncated]
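
For the chunked case (chunk_size > 1), the patterns above drop the trailing chunk dimension from the mask/offsets shape and rebase each unrolled offset slice by a multiple of the blocked chunk size (the arith.constant / vector.broadcast / arith.addi sequence in UnrollLoadGatherOpWithOffset). A minimal sketch of the ops generated before constant folding for the load_with_offsets_chunk test (chunk_size = 4 blocked to inst_data = [16, 2]), where %off_lo and %m_lo are hypothetical names for one 16-lane slice of the original offsets and mask:

    // Rebase the offsets for the second 2-element chunk (i * blockedChunkSize = 1 * 2).
    %inc    = arith.constant 2 : index
    %incv   = vector.broadcast %inc : index to vector<16xindex>
    %off_hi = arith.addi %off_lo, %incv : vector<16xindex>

    // The mask slice is reused for both chunks of the same 16 lanes.
    %ld_lo = xegpu.load %src[%off_lo], %m_lo <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}>
             : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
    %ld_hi = xegpu.load %src[%off_hi], %m_lo <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}>
             : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>

This matches the four vector<16x2xf32> loads checked in the test, where the offset constants fold to the base slices and the base slices plus 2.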

@llvmbot
Member

llvmbot commented Sep 18, 2025

@llvm/pr-subscribers-mlir-gpu


@nbpatel
Contributor Author

nbpatel commented Sep 23, 2025

Pinging for reviews.

Contributor

@silee2 silee2 left a comment


LGTM.

Contributor

@Jianhui-Li Jianhui-Li left a comment


LGTM

@nbpatel nbpatel merged commit d235d62 into llvm:main Sep 24, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
…th offsets (llvm#159453)

This PR adds unrolling/blocking patterns for load_gather and
store_scatter ops with offsets.