[MLIR][XeGPU] Add unroll pattern for load_gather and store_scatter with offsets #159453
Conversation
@llvm/pr-subscribers-mlir
Author: Nishant Patel (nbpatel)
Changes: This PR adds unrolling/blocking patterns for load_gather and store_scatter ops with offsets.
Patch is 20.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159453.diff
3 Files Affected:
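For orientation before the patch itself: a rough before/after sketch of the new pattern for the simplest case (chunk_size = 1). The shapes and attributes follow the load_with_offsets test added below; the slice names (%offsets, %off0, %off1, %m0, %m1) are illustrative placeholders, not values produced verbatim by the pass.

  // Before: one gathered load over 32 lanes, annotated with inst_data = [16].
  %ld = xegpu.load %src[%offsets], %mask
          {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>,
           l1_hint = #xegpu.cache_hint<cached>}
        : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>

  // After: two 16-lane loads on the extracted offset/mask slices; the two
  // vector<16xf32> results are recombined into the original vector<32xf32>.
  %ld0 = xegpu.load %src[%off0], %m0 <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
         : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
  %ld1 = xegpu.load %src[%off1], %m1 <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
         : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>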
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 29c9fcdfebcdb..d7585fa5df8b3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -537,6 +537,195 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
}
};
+/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
+/// load).
+/// It unrolls the offsets and mask operands accordingly, and creates multiple
+/// LoadGatherOp with the unrolled operands.
+struct UnrollLoadGatherOpWithOffset
+ : public UnrollPattern<xegpu::LoadGatherOp> {
+ using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
+ Value offsets = op.getOffsets();
+ Value mask = op.getMask();
+
+ // Only handle the case where offsets are present (scattered load)
+ if (!offsets)
+ return failure();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<int64_t> targetMaskShape(*targetShape);
+ int64_t chunkSize = 1;
+ if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
+ chunkSize = intAttr.getInt();
+ }
+
+ // Unroll mask and offsets with correct shape
+ VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
+ VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
+ Type elemTy = valueTy.getElementType();
+ VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+ SmallVector<Type> convertedMaskTypes;
+ SmallVector<Value> convertedMasks;
+ SmallVector<Type> convertedOffsetTypes;
+ SmallVector<Value> convertedOffsets;
+
+ if (chunkSize > 1) {
+ // For chunked loads, mask and offsets have one less dimension
+ targetMaskShape.pop_back();
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = chunkSize / blockedChunkSize;
+ chunkSize = blockedChunkSize;
+
+ convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+ convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
+
+ SmallVector<Value> convertedMasksBase =
+ pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+ SmallVector<Value> convertedOffsetsBase =
+ pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
+
+ for (auto maskVal : convertedMasksBase)
+ convertedMasks.append(numNewChunks, maskVal);
+
+ for (auto [baseOffset, offsetType] :
+ llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
+ Value inc = arith::ConstantIndexOp::create(rewriter, loc,
+ i * blockedChunkSize);
+ Value incVec =
+ vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
+ Value offsetVal =
+ arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
+ convertedOffsets.push_back(offsetVal);
+ }
+ }
+ } else {
+ convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+ convertedMasks =
+ pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+
+ convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
+ convertedOffsets =
+ pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+ }
+
+ SmallVector<Value> newOps;
+ for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
+ auto newOp = xegpu::LoadGatherOp::create(
+ rewriter, loc, newValueTy, op.getSource(), o, m,
+ rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
+/// This pattern handles the unrolling of StoreScatterOp with offsets (scattered
+/// store).
+/// It unrolls the offsets and mask operands accordingly, and creates multiple
+/// StoreScatterOp with the unrolled operands.
+struct UnrollStoreScatterOpWithOffsets
+ : public UnrollPattern<xegpu::StoreScatterOp> {
+ using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+ Value offsets = op.getOffsets();
+ Value mask = op.getMask();
+
+ // Only handle the case where offsets are present (scattered store)
+ if (!offsets)
+ return failure();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ int64_t chunkSize = 1;
+ if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
+ chunkSize = intAttr.getInt();
+ }
+
+ SmallVector<int64_t> targetMaskShape(*targetShape);
+ VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
+ VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
+
+ SmallVector<Type> convertedMaskTypes;
+ SmallVector<Value> convertedMasks;
+ SmallVector<Type> convertedOffsetTypes;
+ SmallVector<Value> convertedOffsets;
+
+ if (chunkSize > 1) {
+ targetMaskShape.pop_back();
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = chunkSize / blockedChunkSize;
+ chunkSize = blockedChunkSize;
+
+ convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+ convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
+
+ SmallVector<Value> convertedMasksBase =
+ pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+ SmallVector<Value> convertedOffsetsBase =
+ pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
+
+ for (auto maskVal : convertedMasksBase)
+ convertedMasks.append(numNewChunks, maskVal);
+
+ for (auto [baseOffset, offsetType] :
+ llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
+ Value inc = arith::ConstantIndexOp::create(rewriter, loc,
+ i * blockedChunkSize);
+ Value incVec =
+ vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
+ Value offsetVal =
+ arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
+ convertedOffsets.push_back(offsetVal);
+ }
+ }
+ } else {
+ convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+ convertedMasks =
+ pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+
+ convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
+ convertedOffsets =
+ pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+ }
+
+ SmallVector<Type> convertedValTypes =
+ getUnrolledTypes(valueTy, *targetShape);
+ SmallVector<Value> convertedValues =
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+
+ for (auto [v, o, m] :
+ llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
+ xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
+ rewriter.getI64IntegerAttr(chunkSize),
+ op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ }
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
@@ -766,6 +955,7 @@ void mlir::xegpu::populateXeGPUUnrollPatterns(
.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
- UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp>(
+ UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
+ UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
patterns.getContext(), options);
}
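Aside (not part of the patch): a sketch of how the chunked branch above expands offsets. With chunk_size = 4 and inst_data = [16, 2], the target shape is 16x2, so blockedChunkSize = 2 and numNewChunks = 4 / 2 = 2; each unrolled 16-lane offset slice is reused twice, once unchanged and once shifted by one blocked chunk. Here %off stands for one such slice and is an assumed name:

  // i = 0 reuses %off unchanged (increment 0 * blockedChunkSize = 0).
  // i = 1 adds a broadcast increment of 1 * blockedChunkSize = 2:
  %c2   = arith.constant 2 : index
  %inc  = vector.broadcast %c2 : index to vector<16xindex>
  %off1 = arith.addi %off, %inc : vector<16xindex>
  // Hence the dense<[0, 8, ...]> and dense<[2, 10, ...]> offset constants
  // checked in the load_with_offsets_chunk test below.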
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 6999da5d222fe..1392ded322d0b 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -210,6 +210,27 @@ gpu.module @test {
gpu.return %ld : vector<32xf32>
}
+//-----
+
+
+ // CHECK-LABEL: load_with_offsets
+ // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+ // CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+ gpu.func @load_with_offsets(%src: memref<64xf32>) -> vector<32xf32> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+ %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+
+ gpu.return %ld : vector<32xf32>
+ }
+
//-----
// CHECK-LABEL: prefetch
@@ -254,6 +275,28 @@ gpu.module @test {
gpu.return
}
+
+ //-----
+
+ // CHECK-LABEL: store_with_offsets
+ // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+ // CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, memref<64xf32>, vector<16xindex>, vector<16xi1>
+ gpu.func @store_with_offsets(%src: memref<64xf32>) {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %st_vec = arith.constant dense<1023.0>: vector<32xf32>
+ xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, memref<64xf32>, vector<32xindex>, vector<32xi1>
+
+ gpu.return
+ }
//-----
// CHECK-LABEL: create_tdesc_step_chunk
@@ -319,6 +362,29 @@ gpu.module @test {
gpu.return %ld : vector<32x4xf32>
}
+//-----
+ // CHECK-LABEL: load_with_offsets_chunk
+ // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+ // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
+ // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+ // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+ // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+ // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+ // CHECK-COUNT-4: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
+ gpu.func @load_with_offsets_chunk(%src: memref<64xf32>) -> vector<32x4xf32> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+ %ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
+ gpu.return %ld : vector<32x4xf32>
+ }
+
//-----
// CHECK-LABEL: store_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
@@ -342,6 +408,31 @@ gpu.module @test {
gpu.return
}
+//-----
+ // CHECK-LABEL: store_with_offsets_chunk
+ // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+ // CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32
+ // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+ // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+ // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+ // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+ // CHECK-COUNT-4: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, memref<64xf32>, vector<16xindex>, vector<16xi1>
+ gpu.func @store_with_offsets_chunk(%src: memref<64xf32>) {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %st_vec = arith.constant dense<1023.>: vector<32x4xf32>
+ xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, memref<64xf32>, vector<32xindex>, vector<32xi1>
+ gpu.return
+ }
+
//-----
// CHECK-LABEL: prefetch_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index e1ba45c60ac36..c83faea2e622c 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -50,49 +50,67 @@ struct TestXeGPUUnrollingPatterns
void runOnOperation() override {
MLIRContext *ctx = &getContext();
xegpu::UnrollOptions options;
- options.setNativeShapeFn(
- [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
- if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
- xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
- xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
- xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
- xegpu::TensorDescType tdescTy;
- if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
- tdescTy = createNdOp.getType();
- } else if (auto updateNdOp =
- dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
- tdescTy = updateNdOp.getTensorDescType();
- } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
- tdescTy = prefetchNdOp.getTensorDescType();
- } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
- tdescTy = loadNdOp.getTensorDescType();
- } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
- tdescTy = storeNdOp.getTensorDescType();
- } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
- tdescTy = createOp.getType();
- } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
- tdescTy = updateOp.getTensorDescType();
- } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
- tdescTy = prefetchOp.getTensorDescType();
- } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
- tdescTy = loadOp.getTensorDescType();
- } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
- tdescTy = storeOp.getTensorDescType();
+ options.setNativeShapeFn([&](Operation *op)
+ -> std::optional<SmallVector<int64_t>> {
+ if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
+ xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+ xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+ xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+ xegpu::TensorDescType tdescTy;
+ if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+ tdescTy = createNdOp.getType();
+ } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+ tdescTy = updateNdOp.getTensorDescType();
+ } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
+ tdescTy = prefetchNdOp.getTensorDescType();
+ } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+ tdescTy = loadNdOp.getTensorDescType();
+ } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+ tdescTy = storeNdOp.getTensorDescType();
+ } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+ tdescTy = createOp.getType();
+ } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+ tdescTy = updateOp.getTensorDescType();
+ } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+ tdescTy = prefetchOp.getTensorDescType();
+ } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+ if (loadOp.getOffsets()) {
+ auto layout = xegpu::getDistributeLayoutAttr(loadOp.getResult());
+ if (layout && layout.isForSubgroup()) {
+ auto inst_data = layout.getEffectiveInstDataAsInt();
+ if (!inst_data.empty())
+ return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
}
-
- if (auto layout = tdescTy.getLayoutAttr()) {
- auto inst_data = layout.getInstData();
- if (inst_data && layout.isForSubgroup())
- return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
- inst_data.asArrayRef().end());
+ return std::nullopt;
+ }
+ tdescTy = loadOp.getTensorDescType();
+ } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+ if (storeOp.getOffsets()) {
+ auto layout = llvm::dyn_cast_or_null<xegpu::LayoutAttr>(
+ op->getAttr("layout"));
+ if (layout && layout.isForSubgroup()) {
+ auto inst_data = layout.getEffectiveInstDataAsInt();
+ if (!inst_data.empty())
+ return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
}
+ return std::nullopt;
}
-
- if (isa<xegpu::DpasOp>(op))
- return SmallVector<int64_t>{8, 16, 16};
-
- return std::nullopt;
- });
+ tdescTy = storeOp.getTensorDescType();
+ }
+
+ if (auto layout = tdescTy.getLayoutAttr()) {
+ auto inst_data = layout.getInstData();
+ if (inst_data && layout.isForSubgroup())
+ return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+ inst_...
[truncated]
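A note on the TestXeGPUTransforms.cpp change above: the offset-based gather/scatter forms do not go through a tensor_desc whose layout could supply the native shape, so the callback reads inst_data from the op's layout attributes instead. The two spellings below are taken from the tests in this patch:

  // Gathered load: layout carried on the result via layout_result_0.
  %ld = xegpu.load %src[%cst], %mask
          {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>,
           l1_hint = #xegpu.cache_hint<cached>}
        : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
  // Scattered store: layout carried on the op as a plain "layout" attribute.
  xegpu.store %st_vec, %src[%cst], %mask
          {chunk_size = 4, layout = #xegpu.layout<inst_data = [16, 2]>,
           l1_hint = #xegpu.cache_hint<cached>}
        : vector<32x4xf32>, memref<64xf32>, vector<32xindex>, vector<32xi1>
  // In both cases, inst_data = [16, 2] becomes the unroll target shape.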
@llvm/pr-subscribers-mlir-gpu
Author: Nishant Patel (nbpatel)
Changes: This PR adds unrolling/blocking patterns for load_gather and store_scatter ops with offsets.
pinging for reviews
LGTM.
LGTM
[MLIR][XeGPU] Add unroll pattern for load_gather and store_scatter with offsets (llvm#159453)
This PR adds unrolling/blocking patterns for load_gather and store_scatter ops with offsets.