From 6cf90aa89cf82b446a2ec95ea1514d6de08f8a93 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Thu, 11 Sep 2025 04:14:03 +0000
Subject: [PATCH 1/6] Add unroll pattern for load gather

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 64 +++++++++++--
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 21 +++++
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 93 +++++++++++--------
 3 files changed, 131 insertions(+), 47 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 29c9fcdfebcdb..c2dcb7120d7af 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -537,6 +537,58 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
   }
 };
 
+struct UnrollLoadGatherOpWithOffset
+    : public UnrollPattern<xegpu::LoadGatherOp> {
+  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
+    Value offsets = op.getOffsets();
+    Value mask = op.getMask();
+
+    // Only handle the case where offsets are present (scattered load)
+    if (!offsets)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    // Unroll offsets
+    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
+    SmallVector<Type> convertedOffsetTypes =
+        getUnrolledTypes(offsetsTy, *targetShape);
+    SmallVector<Value> convertedOffsets =
+        pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+
+    // Unroll mask
+    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
+    SmallVector<Type> convertedMaskTypes =
+        getUnrolledTypes(maskTy, *targetShape);
+    SmallVector<Value> convertedMasks =
+        pack(mask, convertedMaskTypes, *targetShape, loc, rewriter);
+
+    Type elemTy = valueTy.getElementType();
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<Value> newOps;
+    auto chunkSizeAttr =
+        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
+      auto newOp = xegpu::LoadGatherOp::create(
+          rewriter, loc, newValueTy, op.getSource(), o, m, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      newOp.dump();
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
 struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
   using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
@@ -762,10 +814,10 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
-  patterns
-      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-           UnrollLoadNdOp, UnrollStoreNdOp, UnrollCreateDescOp,
-           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
-           UnrollUpdateOffsetOp, UnrollReductionOp, UnrollTransposeOp,
-           UnrollBroadcastOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp>(
-          patterns.getContext(), options);
+  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
+               UnrollLoadNdOp, UnrollStoreNdOp, UnrollCreateDescOp,
+               UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
+               UnrollUpdateOffsetOp, UnrollReductionOp, UnrollTransposeOp,
+               UnrollBroadcastOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
+               UnrollLoadGatherOpWithOffset>(patterns.getContext(), options);
 }
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 6999da5d222fe..20b8ddd873631 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -210,6 +210,27 @@ gpu.module @test {
     gpu.return %ld : vector<32xf32>
   }
 
+//-----
+
+
+  // CHECK-LABEL: load_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  gpu.func @load_with_offsets(%src: memref<64xf32>) -> vector<32xf32> {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+
+    gpu.return %ld : vector<32xf32>
+  }
+
 //-----
 
   // CHECK-LABEL: prefetch
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 200323c7a4e51..cdd79d769402a 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -50,49 +50,60 @@ struct TestXeGPUUnrollingPatterns
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     xegpu::UnrollOptions options;
-    options.setNativeShapeFn(
-        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
-          if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
-                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
-                  xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
-                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
-            xegpu::TensorDescType tdescTy;
-            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
-              tdescTy = createNdOp.getType();
-            } else if (auto updateNdOp =
-                           dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
-              tdescTy = updateNdOp.getTensorDescType();
-            } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
-              tdescTy = prefetchNdOp.getTensorDescType();
-            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
-              tdescTy = loadNdOp.getTensorDescType();
-            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
-              tdescTy = storeNdOp.getTensorDescType();
-            } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
-              tdescTy = createOp.getType();
-            } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
-              tdescTy = updateOp.getTensorDescType();
-            } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
-              tdescTy = prefetchOp.getTensorDescType();
-            } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
-              tdescTy = loadOp.getTensorDescType();
-            } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
-              tdescTy = storeOp.getTensorDescType();
-            }
-
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              auto inst_data = layout.getInstData();
-              if (inst_data && layout.isForSubgroup())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
+    options.setNativeShapeFn([&](Operation *op)
+                                 -> std::optional<SmallVector<int64_t>> {
+      if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
+              xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+              xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+              xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+        xegpu::TensorDescType tdescTy;
+        if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+          tdescTy = createNdOp.getType();
+        } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+          tdescTy = updateNdOp.getTensorDescType();
+        } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
+          tdescTy = prefetchNdOp.getTensorDescType();
+        } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+          tdescTy = loadNdOp.getTensorDescType();
+        } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+          tdescTy = storeNdOp.getTensorDescType();
+        } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+          tdescTy = createOp.getType();
+        } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+          tdescTy = updateOp.getTensorDescType();
+        } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+          tdescTy = prefetchOp.getTensorDescType();
+        } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+          if (loadOp.getOffsets()) {
+            xegpu::DistributeLayoutAttr layout =
+                xegpu::getDistributeLayoutAttr(loadOp.getResult());
+            if (layout && layout.isForSubgroup()) {
+              auto inst_data = layout.getInstDataAsInt();
+              if (!inst_data.empty())
+                return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
+            } else {
+              return std::nullopt;
             }
+          } else if (!loadOp.getOffsets()) {
+            tdescTy = loadOp.getTensorDescType();
           }
-
-          if (isa<xegpu::DpasOp>(op))
-            return SmallVector<int64_t>{8, 16, 16};
-
-          return std::nullopt;
-        });
+        } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+          tdescTy = storeOp.getTensorDescType();
+        }
+
+        if (auto layout = tdescTy.getLayoutAttr()) {
+          auto inst_data = layout.getInstData();
+          if (inst_data && layout.isForSubgroup())
+            return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                        inst_data.asArrayRef().end());
+        }
+      }
+
+      if (isa<xegpu::DpasOp>(op))
+        return SmallVector<int64_t>{8, 16, 16};
+
+      return std::nullopt;
+    });
     options.setUnrolledTypesFn(
         [&](ShapedType type, ArrayRef<int64_t> tileShape)
            -> SmallVector<Type> {

From 4481eac930e67b39ce4efe3bdd3a108410d101ad Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Thu, 11 Sep 2025 17:05:42 +0000
Subject: [PATCH 2/6] Add unroll pattern for store scatter with offsets

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 65 +++++++++++++++++--
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 22 +++++++
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 14 +++-
 3 files changed, 94 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c2dcb7120d7af..a01086743f517 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -589,6 +589,58 @@ struct UnrollLoadGatherOpWithOffset
   }
 };
 
+struct UnrollStoreScatterOpWithOffsets
+    : public UnrollPattern<xegpu::StoreScatterOp> {
+  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    Value offsets = op.getOffsets();
+    Value mask = op.getMask();
+
+    // Only handle the case where offsets are present (scattered store)
+    if (!offsets)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    // Unroll offsets
+    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
+    SmallVector<Type> convertedOffsetTypes =
+        getUnrolledTypes(offsetsTy, *targetShape);
+    SmallVector<Value> convertedOffsets =
+        pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+
+    // Unroll mask
+    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
+    SmallVector<Type> convertedMaskTypes =
+        getUnrolledTypes(maskTy, *targetShape);
+    SmallVector<Value> convertedMasks =
+        pack(mask, convertedMaskTypes, *targetShape, loc, rewriter);
+
+    // Unroll value
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+
+    auto chunkSizeAttr =
+        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+    for (auto [v, o, m] :
+         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
+      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
+                                    chunkSizeAttr, op.getL1HintAttr(),
+                                    op.getL2HintAttr(), op.getL3HintAttr());
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
   using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
@@ -814,10 +866,11 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
-  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-               UnrollLoadNdOp, UnrollStoreNdOp, UnrollCreateDescOp,
-               UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
-               UnrollUpdateOffsetOp, UnrollReductionOp, UnrollTransposeOp,
-               UnrollBroadcastOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
-               UnrollLoadGatherOpWithOffset>(patterns.getContext(), options);
+  patterns
+      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
+           UnrollLoadNdOp, UnrollStoreNdOp, UnrollCreateDescOp,
+           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
+           UnrollUpdateOffsetOp, UnrollReductionOp, UnrollTransposeOp,
+           UnrollBroadcastOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
+           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
+          patterns.getContext(), options);
 }
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 20b8ddd873631..33973b4b0f2e6 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -275,6 +275,28 @@ gpu.module @test {
     gpu.return
   }
 
+
+  //-----
+
+  // CHECK-LABEL: store_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, memref<64xf32>, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets(%src: memref<64xf32>) {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.0>: vector<32xf32>
+    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, memref<64xf32>, vector<32xindex>, vector<32xi1>
+
+    gpu.return
+  }
 //-----
 
   // CHECK-LABEL: create_tdesc_step_chunk
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index cdd79d769402a..1c241fa05c8d9 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -88,7 +88,19 @@ struct TestXeGPUUnrollingPatterns
             tdescTy = loadOp.getTensorDescType();
           }
         } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
-          tdescTy = storeOp.getTensorDescType();
+          if (storeOp.getOffsets()) {
+            auto layout = llvm::dyn_cast_or_null<xegpu::DistributeLayoutAttr>(
+                op->getAttr("layout"));
+            if (layout && layout.isForSubgroup()) {
+              auto inst_data = layout.getInstDataAsInt();
+              if (!inst_data.empty())
+                return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
+            } else {
+              return std::nullopt;
+            }
+          } else if (!storeOp.getOffsets()) {
+            tdescTy = storeOp.getTensorDescType();
+          }
         }
 
         if (auto layout = tdescTy.getLayoutAttr()) {

From 97a02872a418c3cec01516a6e569295ebdccfeca Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Tue, 16 Sep 2025 05:32:19 +0000
Subject: [PATCH 3/6] Handle chunk size

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 151 ++++++++++++++----
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  |  48 ++++++
 2 files changed, 166 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a01086743f517..d7585fa5df8b3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -537,6 +537,10 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
   }
 };
 
+/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
+/// load).
+/// It unrolls the offsets and mask operands accordingly, and creates multiple
+/// LoadGatherOps with the unrolled operands.
 struct UnrollLoadGatherOpWithOffset
     : public UnrollPattern<xegpu::LoadGatherOp> {
   using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                 PatternRewriter &rewriter) const override {
@@ -555,31 +559,70 @@ struct UnrollLoadGatherOpWithOffset
     if (!targetShape)
       return failure();
 
-    // Unroll offsets
-    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
-    SmallVector<Type> convertedOffsetTypes =
-        getUnrolledTypes(offsetsTy, *targetShape);
-    SmallVector<Value> convertedOffsets =
-        pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    int64_t chunkSize = 1;
+    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
+      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
+        chunkSize = intAttr.getInt();
+    }
 
-    // Unroll mask
+    // Unroll mask and offsets with correct shape
     VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(mask, convertedMaskTypes, *targetShape, loc, rewriter);
-
+    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
     Type elemTy = valueTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
 
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsets;
+
+    if (chunkSize > 1) {
+      // For chunked loads, mask and offsets have one less dimension
+      targetMaskShape.pop_back();
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = chunkSize / blockedChunkSize;
+      chunkSize = blockedChunkSize;
+
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
+
+      SmallVector<Value> convertedMasksBase =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+      SmallVector<Value> convertedOffsetsBase =
+          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
+
+      for (auto maskVal : convertedMasksBase)
+        convertedMasks.append(numNewChunks, maskVal);
+
+      for (auto [baseOffset, offsetType] :
+           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
+                                                     i * blockedChunkSize);
+          Value incVec =
+              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
+          Value offsetVal =
+              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
+          convertedOffsets.push_back(offsetVal);
+        }
+      }
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
+      convertedOffsets =
+          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
     SmallVector<Value> newOps;
-    auto chunkSizeAttr =
-        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
     for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
       auto newOp = xegpu::LoadGatherOp::create(
-          rewriter, loc, newValueTy, op.getSource(), o, m, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
-      newOp.dump();
+          rewriter, loc, newValueTy, op.getSource(), o, m,
+          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
       newOps.push_back(newOp);
     }
 
@@ -589,6 +632,10 @@ struct UnrollLoadGatherOpWithOffset
   }
 };
 
+/// This pattern handles the unrolling of StoreScatterOp with offsets
+/// (scattered store).
+/// It unrolls the offsets and mask operands accordingly, and creates multiple
+/// StoreScatterOps with the unrolled operands.
 struct UnrollStoreScatterOpWithOffsets
     : public UnrollPattern<xegpu::StoreScatterOp> {
   using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                 PatternRewriter &rewriter) const override {
@@ -607,33 +654,71 @@ struct UnrollStoreScatterOpWithOffsets
     if (!targetShape)
       return failure();
 
-    // Unroll offsets
-    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
-    SmallVector<Type> convertedOffsetTypes =
-        getUnrolledTypes(offsetsTy, *targetShape);
-    SmallVector<Value> convertedOffsets =
-        pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+    int64_t chunkSize = 1;
+    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
+      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
+        chunkSize = intAttr.getInt();
+    }
 
-    // Unroll mask
+    SmallVector<int64_t> targetMaskShape(*targetShape);
     VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(mask, convertedMaskTypes, *targetShape, loc, rewriter);
+    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
+
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsets;
+
+    if (chunkSize > 1) {
+      targetMaskShape.pop_back();
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = chunkSize / blockedChunkSize;
+      chunkSize = blockedChunkSize;
+
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
+
+      SmallVector<Value> convertedMasksBase =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+      SmallVector<Value> convertedOffsetsBase =
+          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
+
+      for (auto maskVal : convertedMasksBase)
+        convertedMasks.append(numNewChunks, maskVal);
+
+      for (auto [baseOffset, offsetType] :
+           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
+                                                     i * blockedChunkSize);
+          Value incVec =
+              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
+          Value offsetVal =
+              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
+          convertedOffsets.push_back(offsetVal);
+        }
+      }
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
+      convertedOffsets =
+          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
 
-    // Unroll value
     SmallVector<Type> convertedValTypes =
         getUnrolledTypes(valueTy, *targetShape);
     SmallVector<Value> convertedValues =
         pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
 
-    auto chunkSizeAttr =
-        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
     for (auto [v, o, m] :
          llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
       xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
-                                    chunkSizeAttr, op.getL1HintAttr(),
-                                    op.getL2HintAttr(), op.getL3HintAttr());
+                                    rewriter.getI64IntegerAttr(chunkSize),
+                                    op.getL1HintAttr(), op.getL2HintAttr(),
+                                    op.getL3HintAttr());
     }
 
     rewriter.eraseOp(op);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 33973b4b0f2e6..1392ded322d0b 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -362,6 +362,29 @@ gpu.module @test {
     gpu.return %ld : vector<32x4xf32>
   }
 
+//-----
+  // CHECK-LABEL: load_with_offsets_chunk
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
+  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-COUNT-4: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
+  gpu.func @load_with_offsets_chunk(%src: memref<64xf32>) -> vector<32x4xf32> {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
+    gpu.return %ld : vector<32x4xf32>
+  }
+
 //-----
   // CHECK-LABEL: store_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
@@ -385,6 +408,31 @@ gpu.module @test {
     gpu.return
   }
 
+//-----
+  // CHECK-LABEL: store_with_offsets_chunk
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32>
+  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-COUNT-4: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, memref<64xf32>, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets_chunk(%src: memref<64xf32>) {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.>: vector<32x4xf32>
+    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, memref<64xf32>, vector<32xindex>, vector<32xi1>
+    gpu.return
+  }
+
 //-----
   // CHECK-LABEL: prefetch_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64

From a7fb584599edd8bec942b2e17ce8d56fd35764c9 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Wed, 17 Sep 2025 21:23:44 +0000
Subject: [PATCH 4/6] fix merge

---
 mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 945dd33ebe53a..c7d012424f08b 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -78,7 +78,7 @@ struct TestXeGPUUnrollingPatterns
             xegpu::DistributeLayoutAttr layout =
                 xegpu::getDistributeLayoutAttr(loadOp.getResult());
             if (layout && layout.isForSubgroup()) {
-              auto inst_data = layout.getInstDataAsInt();
+              auto inst_data = layout.getEffectiveInstDataAsInt();
               if (!inst_data.empty())
                 return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
             } else {
@@ -92,7 +92,7 @@ struct TestXeGPUUnrollingPatterns
             auto layout = llvm::dyn_cast_or_null<xegpu::DistributeLayoutAttr>(
                 op->getAttr("layout"));
             if (layout && layout.isForSubgroup()) {
-              auto inst_data = layout.getInstDataAsInt();
+              auto inst_data = layout.getEffectiveInstDataAsInt();
               if (!inst_data.empty())
                 return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
             } else {

From c240c78a26e5e82585936643489ba32ebdba7fc0 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Thu, 18 Sep 2025 22:38:49 +0000
Subject: [PATCH 5/6] Clean up

---
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index c7d012424f08b..c83faea2e622c 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -75,18 +75,15 @@ struct TestXeGPUUnrollingPatterns
           tdescTy = prefetchOp.getTensorDescType();
         } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
           if (loadOp.getOffsets()) {
-            xegpu::DistributeLayoutAttr layout =
-                xegpu::getDistributeLayoutAttr(loadOp.getResult());
+            auto layout = xegpu::getDistributeLayoutAttr(loadOp.getResult());
             if (layout && layout.isForSubgroup()) {
               auto inst_data = layout.getEffectiveInstDataAsInt();
               if (!inst_data.empty())
                 return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
-            } else {
-              return std::nullopt;
             }
+            return std::nullopt;
           }
-          } else if (!loadOp.getOffsets()) {
-            tdescTy = loadOp.getTensorDescType();
-          }
+          tdescTy = loadOp.getTensorDescType();
         } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
           if (storeOp.getOffsets()) {
             auto layout = llvm::dyn_cast_or_null<xegpu::DistributeLayoutAttr>(
                 op->getAttr("layout"));
            if (layout && layout.isForSubgroup()) {
               auto inst_data = layout.getEffectiveInstDataAsInt();
               if (!inst_data.empty())
                 return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
-            } else {
-              return std::nullopt;
             }
+            return std::nullopt;
           }
-          } else if (!storeOp.getOffsets()) {
-            tdescTy = storeOp.getTensorDescType();
-          }
+          tdescTy = storeOp.getTensorDescType();
         }

From c41b8b362a877fa7e53cdf9eb4c20637d0d97d70 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Tue, 23 Sep 2025 22:50:04 +0000
Subject: [PATCH 6/6] Fix test

---
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 32 +++++++++----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 1392ded322d0b..dbc52b8a98894 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -214,9 +214,9 @@ gpu.module @test {
 
 
   // CHECK-LABEL: load_with_offsets
-  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
-  // CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
-  gpu.func @load_with_offsets(%src: memref<64xf32>) -> vector<32xf32> {
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  gpu.func @load_with_offsets(%src: ui64) -> vector<32xf32> {
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
     128, 136, 144, 152, 160, 168, 176, 184,
     192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
 
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
-    %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+    %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
 
     gpu.return %ld : vector<32xf32>
   }
@@ -279,9 +279,9 @@ gpu.module @test {
 
 //-----
   // CHECK-LABEL: store_with_offsets
-  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
-  // CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, memref<64xf32>, vector<16xindex>, vector<16xi1>
-  gpu.func @store_with_offsets(%src: memref<64xf32>) {
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets(%src: ui64) {
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
     128, 136, 144, 152, 160, 168, 176, 184,
     192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
 
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
 
     %st_vec = arith.constant dense<1023.0>: vector<32xf32>
-    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, memref<64xf32>, vector<32xindex>, vector<32xi1>
+    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>
 
     gpu.return
   }
@@ -364,14 +364,14 @@ gpu.module @test {
 
 //-----
   // CHECK-LABEL: load_with_offsets_chunk
-  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
   // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
   // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
   // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
   // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
-  // CHECK-COUNT-4: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<64xf32>, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
-  gpu.func @load_with_offsets_chunk(%src: memref<64xf32>) -> vector<32x4xf32> {
+  // CHECK-COUNT-4: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
+  gpu.func @load_with_offsets_chunk(%src: ui64) -> vector<32x4xf32> {
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
     128, 136, 144, 152, 160, 168, 176, 184,
     192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
@@ -381,7 +381,7 @@ gpu.module @test {
 
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
-    %ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : memref<64xf32>, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
+    %ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
     gpu.return %ld : vector<32x4xf32>
   }
 
@@ -410,14 +410,14 @@ gpu.module @test {
 
 //-----
   // CHECK-LABEL: store_with_offsets_chunk
-  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32>
   // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
   // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
   // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
   // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
-  // CHECK-COUNT-4: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, memref<64xf32>, vector<16xindex>, vector<16xi1>
-  gpu.func @store_with_offsets_chunk(%src: memref<64xf32>) {
+  // CHECK-COUNT-4: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets_chunk(%src: ui64) {
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
     128, 136, 144, 152, 160, 168, 176, 184,
     192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
 
@@ -429,7 +429,7 @@ gpu.module @test {
     %mask = vector.create_mask %c17: vector<32xi1>
 
     %st_vec = arith.constant dense<1023.>: vector<32x4xf32>
-    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, memref<64xf32>, vector<32xindex>, vector<32xi1>
+    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, ui64, vector<32xindex>, vector<32xi1>
     gpu.return
   }
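
Reviewer note on the chunked decomposition exercised by load_with_offsets_chunk and store_with_offsets_chunk: with chunk_size = 4 and inst_data = [16, 2], the innermost inst_data dimension gives blockedChunkSize = 2, so numNewChunks = 4 / 2 = 2. Each 16-lane offset slice is re-issued numNewChunks times with an increment of i * blockedChunkSize, while the corresponding mask slice is simply replicated. Below is a hand-written, schematic sketch of what one offset slice expands to under those assumptions (value names %off0, %m0, %inc are illustrative only; the pass materializes the slicing through its pack/unpack helpers, so the exact slicing ops may differ):

    // Slice the first 16 lanes of offsets and mask out of the 32-lane originals.
    %off0  = vector.extract_strided_slice %cst {offsets = [0], sizes = [16], strides = [1]} : vector<32xindex> to vector<16xindex>
    %m0    = vector.extract_strided_slice %mask {offsets = [0], sizes = [16], strides = [1]} : vector<32xi1> to vector<16xi1>
    // Second chunk piece starts blockedChunkSize = 2 elements past the base offsets.
    %c2    = arith.constant 2 : index
    %inc   = vector.broadcast %c2 : index to vector<16xindex>
    %off0b = arith.addi %off0, %inc : vector<16xindex>
    // Two chunk_size = 2 loads replace one half of the original chunk_size = 4 load.
    %ld0   = xegpu.load %src[%off0], %m0 <{chunk_size = 2 : i64}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
    %ld0b  = xegpu.load %src[%off0b], %m0 <{chunk_size = 2 : i64}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>

The second 16-lane slice of the original vector<32xindex> expands the same way (its +2 variant is the [130, 138, ...] constant in the CHECK lines), which is what the CHECK-COUNT-4 patterns above count.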