From 2e8e6d71319532629efa707880ad14329fd86020 Mon Sep 17 00:00:00 2001
From: dchigarev
Date: Mon, 6 Oct 2025 14:37:21 +0000
Subject: [PATCH] [MLIR][XeGPU][VectorToXeGPU] Lower
 vector.load/store/transfer_read/transfer_write to new offsets syntax
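
Emit tensor descriptors without baked-in offsets and pass the offsets
directly to xegpu.load_nd / xegpu.store_nd instead. The new syntax is
used whenever the vector rank matches the source memref rank; otherwise
the lowering falls back to encoding the offsets in create_nd_tdesc as
before.

A sketch of the change in the emitted IR, mirroring the updated
out-of-bounds load test below (exact attributes elided):

  // before
  %desc = xegpu.create_nd_tdesc %src[%off, %off]
      : memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
  %vec = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

  // after
  %desc = xegpu.create_nd_tdesc %src
      : memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
  %vec = xegpu.load_nd %desc[%off, %off]
      : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>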
Signed-off-by: dchigarev
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 218 ++++++++++++------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  12 +-
 .../VectorToXeGPU/load-to-xegpu.mlir          |   4 +-
 .../VectorToXeGPU/store-to-xegpu.mlir         |   4 +-
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |   8 +-
 .../transfer-write-to-xegpu.mlir              |   4 +-
 6 files changed, 171 insertions(+), 79 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..41526a7e34971 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,6 +97,64 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
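+// Helper that materializes mixed static/dynamic shape and stride lists for
+// `src`: static dims become integer attributes, dynamic dims are recovered
+// with memref.dim, and dynamic strides are accumulated via arith.muli.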
+static void computeMixedShapesStrides(PatternRewriter &rewriter, Location loc,
+                                      SmallVector<OpFoldResult> &mixedShapes,
+                                      SmallVector<OpFoldResult> &mixedStrides,
+                                      SmallVector<int64_t> &strides,
+                                      TypedValue<MemRefType> src) {
+  auto srcTy = src.getType();
+  // In case of any dynamic shapes, source's shape and strides have to be
+  // explicitly provided.
+  SmallVector<Value> sourceDims;
+  unsigned srcRank = srcTy.getRank();
+  for (unsigned i = 0; i < srcRank; ++i)
+    sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
+
+  for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
+    if (shape == ShapedType::kDynamic)
+      mixedShapes.push_back(sourceDims[idx]);
+    else
+      mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
+  }
+
+  // Compute strides in reverse order.
+  Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
+  // Last stride is guaranteed to be static and unit.
+  mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
+  for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
+    accStride =
+        arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
+    if (strides[i] == ShapedType::kDynamic)
+      mixedStrides.push_back(accStride);
+    else
+      mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
+  }
+  std::reverse(mixedStrides.begin(), mixedStrides.end());
+}
+
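+// Overload for the new offset syntax: the descriptor spans the whole source
+// memref; offsets are supplied later, on the load/store op itself.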
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+                                                Location loc,
+                                                xegpu::TensorDescType descType,
+                                                TypedValue<MemRefType> src) {
+  MemRefType srcTy = src.getType();
+  auto [strides, offset] = srcTy.getStridesAndOffset();
+
+  xegpu::CreateNdDescOp ndDesc;
+  if (srcTy.hasStaticShape())
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
+  else {
+    SmallVector<OpFoldResult> mixedShapes;
+    SmallVector<OpFoldResult> mixedStrides;
+    computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides,
+                              strides, src);
+
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+                                           mixedShapes, mixedStrides);
+  }
+
+  return ndDesc;
+}
+
 static xegpu::CreateNdDescOp
 createNdDescriptor(PatternRewriter &rewriter, Location loc,
                    xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -109,45 +167,22 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                            getAsOpFoldResult(offsets));
   } else {
-    // In case of any dynamic shapes, source's shape and strides have to be
-    // explicitly provided.
-    SmallVector<Value> sourceDims;
-    unsigned srcRank = srcTy.getRank();
-    for (unsigned i = 0; i < srcRank; ++i)
-      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
-    SmallVector<int64_t> constOffsets;
-    SmallVector<Value> dynOffsets;
+    SmallVector<OpFoldResult> mixedOffsets;
     for (Value offset : offsets) {
       std::optional<int64_t> staticVal = getConstantIntValue(offset);
-      if (!staticVal)
-        dynOffsets.push_back(offset);
-      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
-    }
-
-    SmallVector<Value> dynShapes;
-    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
-      if (shape == ShapedType::kDynamic)
-        dynShapes.push_back(sourceDims[idx]);
+      if (staticVal)
+        mixedOffsets.push_back(rewriter.getI64IntegerAttr(staticVal.value()));
+      else
+        mixedOffsets.push_back(offset);
     }
 
-    // Compute strides in reverse order.
-    SmallVector<Value> dynStrides;
-    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-    // Last stride is guaranteed to be static and unit.
-    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-      accStride =
-          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-      if (strides[i] == ShapedType::kDynamic)
-        dynStrides.push_back(accStride);
-    }
-    std::reverse(dynStrides.begin(), dynStrides.end());
+    SmallVector<OpFoldResult> mixedShapes;
+    SmallVector<OpFoldResult> mixedStrides;
+    computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides,
+                              strides, src);
 
     ndDesc = xegpu::CreateNdDescOp::create(
-        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
-        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
-        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
-        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+        rewriter, loc, descType, src, mixedOffsets, mixedShapes, mixedStrides);
   }
 
   return ndDesc;
@@ -523,21 +558,35 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         descShape, elementType, /*array_length=*/1,
         /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
-                           readOp.getIndices());
-
     DenseI64ArrayAttr transposeAttr =
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                          /*packed=*/nullptr, transposeAttr,
-                                          /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::LoadNdOp loadOp;
+
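+    // Vector rank matches the source rank: create a descriptor over the
+    // whole memref and pass the read indices directly to load_nd.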
+    if (vecTy.getRank() == readOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType,
+          dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+      loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                       getAsOpFoldResult(readOp.getIndices()),
+                                       /*packed=*/nullptr, transposeAttr,
+                                       /*l1_hint=*/hint,
+                                       /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc =
+          createNdDescriptor(rewriter, loc, descType,
+                             dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
+                             readOp.getIndices());
+
+      loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                       /*packed=*/nullptr, transposeAttr,
+                                       /*l1_hint=*/hint,
+                                       /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
 
     rewriter.replaceOp(readOp, loadOp);
     return success();
@@ -579,17 +628,30 @@ struct TransferWriteLowering
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
-                           writeOp.getIndices());
-
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeOp =
-        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::StoreNdOp storeOp;
+    if (vecTy.getRank() == writeOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType,
+          dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+
+      storeOp =
+          xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+                                   getAsOpFoldResult(writeOp.getIndices()),
+                                   /*l1_hint=*/hint,
+                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType,
+          dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
+          writeOp.getIndices());
+
+      storeOp =
+          xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+                                   /*l1_hint=*/hint,
+                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
 
     rewriter.replaceOp(writeOp, storeOp);
     return success();
@@ -674,19 +736,32 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
 
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
 
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
-    auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::LoadNdOp loadNdOp;
+
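+    // Same rank-based dispatch as in the transfer_read lowering above:
+    // when the ranks match, the load indices go on load_nd itself.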
+    if (vecTy.getRank() == loadOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc =
+          createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
+      loadNdOp = xegpu::LoadNdOp::create(
+          rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
+          /*packed=*/nullptr, /*transpose=*/nullptr,
+          /*l1_hint=*/hint,
+          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
+      loadNdOp =
+          xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                  /*packed=*/nullptr, /*transpose=*/nullptr,
+                                  /*l1_hint=*/hint,
+                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
 
     rewriter.replaceOp(loadOp, loadNdOp);
     return success();
@@ -711,15 +786,28 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::StoreNdOp storeNdOp;
+    if (vecTy.getRank() == storeOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc =
+          createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+
+      storeNdOp =
+          xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                   getAsOpFoldResult(storeOp.getIndices()),
+                                   /*l1_hint=*/hint,
+                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
+
+      storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                           /*l1_hint=*/hint,
+                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
+
     rewriter.replaceOp(storeOp, storeNdOp);
     return success();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788d0b9b4..01a0a63d95bef 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -215,8 +215,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for the fully static case,
+    // otherwise the printer would fail trying to print an empty array-attr).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
@@ -277,8 +279,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for the fully static case,
+    // otherwise the printer would fail trying to print an empty array-attr).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 9908205f07c92..c7c0485768b99 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -72,9 +72,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<7x15xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
 // CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]
 
 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 2c498dcc2a071..19240abe1e75c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -74,9 +74,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
 // CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index c4ca79af1bd9a..72bdab0a4db3a 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -83,9 +83,9 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
 // LOAD-ND-LABEL: @load_zero_pad_out_of_bounds(
 // LOAD-ND-SAME:  %[[SRC:.+]]: memref<32x64xf32>,
 // LOAD-ND-SAME:  %[[OFFSET:.+]]: index
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND: return %[[VEC]]
 
 // LOAD-GATHER-LABEL: @load_zero_pad_out_of_bounds(
@@ -109,9 +109,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
 // LOAD-ND-SAME:  %[[SRC:.+]]: memref<32x64xf32>,
 // LOAD-ND-SAME:  %[[OFFSET1:.+]]: index,
 // LOAD-ND-SAME:  %[[OFFSET2:.+]]: index
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]] <{transpose = array<i64: 1, 0>}>
 // LOAD-ND-SAME: -> vector<8x16xf32>
 // LOAD-ND: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index fcfc9414da4f6..ca3bbc11a5180 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -126,9 +126,9 @@ gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME: %[[SRC]]
 // STORE-ND-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL: @store_out_of_bounds(
 // STORE-SCATTER: vector.transfer_write