diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 29c9fcdfebcdb..d7585fa5df8b3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -537,6 +537,195 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
   }
 };
 
+/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
+/// load).
+/// It unrolls the offsets and mask operands accordingly, and creates multiple
+/// LoadGatherOp with the unrolled operands.
+struct UnrollLoadGatherOpWithOffset
+    : public UnrollPattern<xegpu::LoadGatherOp> {
+  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
+    Value offsets = op.getOffsets();
+    Value mask = op.getMask();
+
+    // Only handle the case where offsets are present (scattered load)
+    if (!offsets)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    int64_t chunkSize = 1;
+    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
+      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
+        chunkSize = intAttr.getInt();
+    }
+
+    // Unroll mask and offsets with correct shape
+    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
+    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
+    Type elemTy = valueTy.getElementType();
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsets;
+
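+    // Worked example for the chunked path below (cf. the
+    // load_with_offsets_chunk test): a vector<32x4xf32> load with
+    // chunk_size = 4 and target shape [16, 2] gives blockedChunkSize = 2 and
+    // numNewChunks = 2, so each unrolled offsets tile is reused twice,
+    // incremented by i * blockedChunkSize, and the new ops carry
+    // chunk_size = 2.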
+    if (chunkSize > 1) {
+      // For chunked loads, mask and offsets have one less dimension
+      targetMaskShape.pop_back();
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = chunkSize / blockedChunkSize;
+      chunkSize = blockedChunkSize;
+
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
+
+      SmallVector<Value> convertedMasksBase =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+      SmallVector<Value> convertedOffsetsBase =
+          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
+
+      for (auto maskVal : convertedMasksBase)
+        convertedMasks.append(numNewChunks, maskVal);
+
+      for (auto [baseOffset, offsetType] :
+           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
+                                                     i * blockedChunkSize);
+          Value incVec =
+              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
+          Value offsetVal =
+              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
+          convertedOffsets.push_back(offsetVal);
+        }
+      }
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
+      convertedOffsets =
+          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
+    SmallVector<Value> newOps;
+    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
+      auto newOp = xegpu::LoadGatherOp::create(
+          rewriter, loc, newValueTy, op.getSource(), o, m,
+          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+/// This pattern handles the unrolling of StoreScatterOp with offsets
+/// (scattered store).
+/// It unrolls the offsets and mask operands accordingly, and creates multiple
+/// StoreScatterOp with the unrolled operands.
+struct UnrollStoreScatterOpWithOffsets
+    : public UnrollPattern<xegpu::StoreScatterOp> {
+  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    Value offsets = op.getOffsets();
+    Value mask = op.getMask();
+
+    // Only handle the case where offsets are present (scattered store)
+    if (!offsets)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    int64_t chunkSize = 1;
+    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
+      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
+        chunkSize = intAttr.getInt();
+    }
+
+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
+    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
+
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsets;
+
+    if (chunkSize > 1) {
+      targetMaskShape.pop_back();
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = chunkSize / blockedChunkSize;
+      chunkSize = blockedChunkSize;
+
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
+
+      SmallVector<Value> convertedMasksBase =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+      SmallVector<Value> convertedOffsetsBase =
+          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
+
+      for (auto maskVal : convertedMasksBase)
+        convertedMasks.append(numNewChunks, maskVal);
+
+      for (auto [baseOffset, offsetType] :
+           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
+                                                     i * blockedChunkSize);
+          Value incVec =
+              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
+          Value offsetVal =
+              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
+          convertedOffsets.push_back(offsetVal);
+        }
+      }
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks =
+          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
+
+      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
+      convertedOffsets =
+          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
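+    // Unlike the mask and offsets, the stored value keeps the full target
+    // shape (including the blocked chunk dimension), so it is unrolled with
+    // *targetShape directly.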
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+
+    for (auto [v, o, m] :
+         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
+      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
+                                    rewriter.getI64IntegerAttr(chunkSize),
+                                    op.getL1HintAttr(), op.getL2HintAttr(),
+                                    op.getL3HintAttr());
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
   using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
@@ -766,6 +955,7 @@ void mlir::xegpu::populateXeGPUUnrollPatterns(
       .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
            UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
            UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
-           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp>(
+           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
+           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
           patterns.getContext(), options);
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 6999da5d222fe..dbc52b8a98894 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -210,6 +210,27 @@ gpu.module @test {
     gpu.return %ld : vector<32xf32>
   }
 
+//-----
+
+
+  // CHECK-LABEL: load_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  gpu.func @load_with_offsets(%src: ui64) -> vector<32xf32> {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+     64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+
+    gpu.return %ld : vector<32xf32>
+  }
+
 //-----
 
   // CHECK-LABEL: prefetch
@@ -254,6 +275,28 @@ gpu.module @test {
     gpu.return
   }
+
+  //-----
+
+  // CHECK-LABEL: store_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets(%src: ui64) {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+     64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.0>: vector<32xf32>
+    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>
+
+    gpu.return
+  }
 //-----
 
   // CHECK-LABEL: create_tdesc_step_chunk
@@ -319,6 +362,29 @@ gpu.module @test {
     gpu.return %ld : vector<32x4xf32>
   }
 
+//-----
+  // CHECK-LABEL: load_with_offsets_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
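+  // The four vector<16xindex> constants checked below are the unrolled
+  // offset tiles: each half of the original 32 offsets appears twice, once
+  // unchanged and once shifted by the blocked chunk size of 2.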
+  // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
+  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-COUNT-4: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
+  gpu.func @load_with_offsets_chunk(%src: ui64) -> vector<32x4xf32> {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+     64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
+    gpu.return %ld : vector<32x4xf32>
+  }
+
 //-----
   // CHECK-LABEL: store_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
@@ -342,6 +408,31 @@ gpu.module @test {
     gpu.return
   }
 
+//-----
+  // CHECK-LABEL: store_with_offsets_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32>
+  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-COUNT-4: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets_chunk(%src: ui64) {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+     64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.>: vector<32x4xf32>
+    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, ui64, vector<32xindex>, vector<32xi1>
+    gpu.return
+  }
+
 //-----
   // CHECK-LABEL: prefetch_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index e1ba45c60ac36..c83faea2e622c 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -50,49 +50,67 @@ struct TestXeGPUUnrollingPatterns
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     xegpu::UnrollOptions options;
-    options.setNativeShapeFn(
-        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
-          if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
-                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
-                  xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
-                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
-            xegpu::TensorDescType tdescTy;
-            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
-              tdescTy = createNdOp.getType();
-            } else if (auto updateNdOp =
-                           dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
-              tdescTy = updateNdOp.getTensorDescType();
-            } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
-              tdescTy = prefetchNdOp.getTensorDescType();
-            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
-              tdescTy = loadNdOp.getTensorDescType();
-            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
-              tdescTy = storeNdOp.getTensorDescType();
-            } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
-              tdescTy = createOp.getType();
-            } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
-              tdescTy = updateOp.getTensorDescType();
-            } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
-              tdescTy = prefetchOp.getTensorDescType();
-            } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
-              tdescTy = loadOp.getTensorDescType();
-            } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
-              tdescTy = storeOp.getTensorDescType();
+    options.setNativeShapeFn([&](Operation *op)
+                                 -> std::optional<SmallVector<int64_t>> {
+      if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
+              xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+              xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+              xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+        xegpu::TensorDescType tdescTy;
+        if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+          tdescTy = createNdOp.getType();
+        } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+          tdescTy = updateNdOp.getTensorDescType();
+        } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
+          tdescTy = prefetchNdOp.getTensorDescType();
+        } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+          tdescTy = loadNdOp.getTensorDescType();
+        } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+          tdescTy = storeNdOp.getTensorDescType();
+        } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+          tdescTy = createOp.getType();
+        } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+          tdescTy = updateOp.getTensorDescType();
+        } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+          tdescTy = prefetchOp.getTensorDescType();
+        } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+          if (loadOp.getOffsets()) {
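+            // The offset form of the gather/scatter ops has no tensor_desc
+            // to query, so the native shape comes from the layout attribute:
+            // the result's distribute layout here, and the "layout" attr in
+            // the store branch below.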
+            auto layout = xegpu::getDistributeLayoutAttr(loadOp.getResult());
+            if (layout && layout.isForSubgroup()) {
+              auto inst_data = layout.getEffectiveInstDataAsInt();
+              if (!inst_data.empty())
+                return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
             }
-
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              auto inst_data = layout.getInstData();
-              if (inst_data && layout.isForSubgroup())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
+            return std::nullopt;
+          }
+          tdescTy = loadOp.getTensorDescType();
+        } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+          if (storeOp.getOffsets()) {
+            auto layout = llvm::dyn_cast_or_null<xegpu::LayoutAttr>(
+                op->getAttr("layout"));
+            if (layout && layout.isForSubgroup()) {
+              auto inst_data = layout.getEffectiveInstDataAsInt();
+              if (!inst_data.empty())
+                return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
             }
+            return std::nullopt;
           }
-
-          if (isa<xegpu::DpasOp>(op))
-            return SmallVector<int64_t>{8, 16, 16};
-
-          return std::nullopt;
-        });
+          tdescTy = storeOp.getTensorDescType();
+        }
+
+        if (auto layout = tdescTy.getLayoutAttr()) {
+          auto inst_data = layout.getInstData();
+          if (inst_data && layout.isForSubgroup())
+            return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                        inst_data.asArrayRef().end());
+        }
+      }
+
+      if (isa<xegpu::DpasOp>(op))
+        return SmallVector<int64_t>{8, 16, 16};
+
+      return std::nullopt;
+    });
     options.setUnrolledTypesFn(
         [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {