diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 852c322cc6467..9f5585a701438 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -183,11 +183,15 @@ static void adjustStridesForPermutation(AffineMap permMap,
 // Computes memory strides and a memref offset for vector transfer operations,
 // handling both static and dynamic memrefs while applying permutation
 // transformations for XeGPU lowering.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
 static std::pair<SmallVector<Value>, Value>
-computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
+computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
-  AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
@@ -197,9 +201,14 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
       return {{}, offsetVal};
-    // Wrap static strides as MLIR values
-    for (int64_t s : intStrides)
-      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
+      return ShapedType::isDynamic(strideVal);
+    });
+
+    if (!hasDynamicStrides)
+      for (int64_t s : intStrides)
+        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+
     if (!ShapedType::isDynamic(offset))
       offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
   }
@@ -232,8 +241,14 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
     if (!offsetVal)
       offsetVal = meta.getOffset();
   }
-  // Adjust strides according to the permutation map (e.g., for transpose)
-  adjustStridesForPermutation(permMap, strides);
+
+  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
+                                vector::TransferWriteOp>::value) {
+    AffineMap permMap = xferOp.getPermutationMap();
+    // Adjust strides according to the permutation map (e.g., for transpose)
+    adjustStridesForPermutation(permMap, strides);
+  }
+
   return {strides, offsetVal};
 }
 
@@ -339,9 +354,51 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   return localOffsets;
 }
 
+// Compute the element-wise offsets for vector.gather or vector.scatter ops.
+//
+// This function linearizes the base offsets of the gather/scatter operation
+// and combines them with the per-element indices to produce a final vector of
+// memory offsets.
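+//
+// Illustrative example (hypothetical values, not tied to any one test): for a
+// contiguous memref<8x16x32xf32> the strides are [512, 32, 1], so base
+// offsets [%o0, %o1, %o2] and per-element %indices yield
+//   %base   = %o0 * 512 + %o1 * 32 + %o2 * 1
+//   %off[j] = %base + %indices[j] * 1
+// Multiplications by the unit inner stride are expected to fold away during
+// pattern application, which is why the 3-D tests below check for only two
+// muli/addi pairs before the broadcast.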
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
+static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
+                            ArrayRef<Value> strides, Value baseOffset) {
+  Location loc = gatScatOp.getLoc();
+  SmallVector<Value> offsets = gatScatOp.getOffsets();
+  for (size_t i = 0; i < offsets.size(); ++i) {
+    Value offsetContrib =
+        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
+    baseOffset =
+        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+  }
+  Value indices = gatScatOp.getIndices();
+  VectorType vecType = cast<VectorType>(indices.getType());
+
+  Value strideVector =
+      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
+          .getResult();
+  Value stridedIndices =
+      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
+
+  Value baseVector =
+      vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
+          baseOffset)
+          .getResult();
+  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
+      .getResult();
+}
+
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
 // Convert memref to i64 base pointer
-static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
-                              PatternRewriter &rewriter) {
+static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
   auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                       rewriter, loc, xferOp.getBase())
@@ -539,6 +596,71 @@ struct TransferWriteLowering
   }
 };
 
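+// Lowers vector.gather to a flat xegpu.load. Schematically (types and
+// attributes abbreviated; see gather-to-xegpu.mlir below for the exact
+// expected output):
+//
+//   %res = vector.gather %src[%o0, %o1][%idx], %mask, %pass_thru
+// becomes roughly
+//   %offsets = <linearized %o0, %o1 and %idx>       // computeOffsets
+//   %ptr     = <aligned base pointer of %src, i64>  // memrefToIndexPtr
+//   %vec     = xegpu.load %ptr[%offsets], %mask
+//   %res     = arith.select %mask, %vec, %pass_thru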
+struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
+    if (!srcTy)
+      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");
+
+    Location loc = gatherOp.getLoc();
+    VectorType vectorType = gatherOp.getVectorType();
+
+    auto meta = computeMemrefMeta(gatherOp, rewriter);
+    if (meta.first.empty())
+      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
+
+    Value localOffsets =
+        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
+    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
+
+    auto xeGatherOp = xegpu::LoadGatherOp::create(
+        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
+        /*chunk_size=*/IntegerAttr{},
+        /*l1_hint=*/xegpu::CachePolicyAttr{},
+        /*l2_hint=*/xegpu::CachePolicyAttr{},
+        /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+    auto selectOp =
+        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
+                                xeGatherOp.getResult(), gatherOp.getPassThru());
+    rewriter.replaceOp(gatherOp, selectOp.getResult());
+    return success();
+  }
+};
+
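+// ScatterLowering mirrors GatherLowering: the offsets and the flat base
+// pointer are computed the same way, but it emits xegpu.store instead, and no
+// pass-thru select is needed since a store produces no result.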
+struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
+  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
+    if (!srcTy)
+      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");
+
+    Location loc = scatterOp.getLoc();
+    auto meta = computeMemrefMeta(scatterOp, rewriter);
+    if (meta.first.empty())
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "Failed to compute strides");
+
+    Value localOffsets =
+        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
+    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
+
+    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
+                                  flatMemref, localOffsets, scatterOp.getMask(),
+                                  /*chunk_size=*/IntegerAttr{},
+                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+    rewriter.eraseOp(scatterOp);
+    return success();
+  }
+};
+
 struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
 
@@ -654,6 +776,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering, ContractionLowering>(patterns.getContext());
+  patterns
+      .add<TransferReadLowering, TransferWriteLowering, GatherLowering,
+           ScatterLowering, LoadLowering, StoreLowering, ContractionLowering>(
+          patterns.getContext());
 }
diff --git a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
new file mode 100644
index 0000000000000..2a319869a7b06
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
@@ -0,0 +1,251 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+gpu.module @xevm_module {
+gpu.func @load_1D_vector(%source: memref<8x16x32xf32>,
+    %off1: index, %off2: index, %off3: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
+    %pass_thru : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @load_1D_vector(
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>
+// CHECK-SAME:  %[[MASK:.+]]: vector<8xi1>
+// CHECK-SAME:  %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
+// CHECK: gpu.return %[[RES]] : vector<8xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_memref(%source: memref<8x32xf32>,
+    %off1: index, %off2: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2][%indices], %mask,
+    %pass_thru : memref<8x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @load_2D_memref(
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x32xf32>,
+// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>
+// CHECK-SAME:  %[[MASK:.+]]: vector<8xi1>
+// CHECK-SAME:  %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK-COUNT-1: arith.muli {{.*}} : index
+// CHECK-COUNT-1: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
+// CHECK: gpu.return %[[RES]] : vector<8xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+    %pass_thru : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL: @load_2D_vector(
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME:  %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK: gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+    %pass_thru : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL: @load_dynamic_source(
+// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME:  %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK: memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK: gpu.return %[[RES]] : vector<8x16xf32>
+}
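+
+// Note: with a fully dynamic memref<?x?x?xf32> the strides are unknown at
+// compile time, so the lowering recovers them with
+// memref.extract_strided_metadata (checked above). The next test keeps only
+// the outermost dim dynamic, which leaves every stride static, so no metadata
+// extraction is expected.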
: vector<8x16xf32> +} + +// ----- +gpu.module @xevm_module { +gpu.func @load_dynamic_source2(%source: memref, + %off0: index, %off1: index, %off2: index, + %indices: vector<8x16xindex>, %mask: vector<8x16xi1>, + %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> { + %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask, + %pass_thru : memref, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32> + gpu.return %0 : vector<8x16xf32> +} +// CHECK-LABEL: @load_dynamic_source2( +// CHECK-SAME: %[[SRC:.+]]: memref, +// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index, +// CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex> +// CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1> +// CHECK-SAME: %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> { +// CHECK-NOT: memref.extract_strided_metadata %[[SRC]] +// CHECK-COUNT2: arith.muli {{.*}} : index +// CHECK-COUNT2: arith.addi {{.*}} : index +// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex> +// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex> +// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref -> index +// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> +// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32> +// CHECK: gpu.return %[[RES]] : vector<8x16xf32> +} + +// ----- +gpu.module @xevm_module { +gpu.func @no_load_tensor(%source: tensor<32x64xf32>, + %off: index, %indices: vector<8x16xindex>, + %mask: vector<8x16xi1>, %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> { + %0 = vector.gather %source[%off, %off][%indices], %mask, + %pass_thru : tensor<32x64xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32> + gpu.return %0 : vector<8x16xf32> +} +// CHECK-LABEL: @no_load_tensor( +// CHECK: vector.gather +} + +// ----- +gpu.module @xevm_module { +gpu.func @gather_from_subview(%source: memref<4096x4096xf16>, + %memref_off: index, %off1: index, %off2: index, + %indices: vector<8xindex>, + %mask: vector<8xi1>, + %pass_thru: vector<8xf16>) -> vector<8xf16> { + %subview = memref.subview %source[%memref_off, %memref_off] [256, 256] [1, 1] + : memref<4096x4096xf16> + to memref<256x256xf16, strided<[4096, 1], offset: ?>> + %0 = vector.gather %subview[%off1, %off2][%indices], %mask, %pass_thru + : memref<256x256xf16, strided<[4096, 1], offset: ?>>, + vector<8xindex>, vector<8xi1>, vector<8xf16> + into vector<8xf16> + gpu.return %0 : vector<8xf16> +} +// CHECK-LABEL: @gather_from_subview( +// CHECK-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// CHECK-SAME: %[[MEMREF_OFF:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, +// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>, +// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>, +// CHECK-SAME: %[[PASS:.+]]: vector<8xf16>) -> vector<8xf16> { +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[MEMREF_OFF]], %[[MEMREF_OFF]]] [256, 256] [1, 1] +// CHECK: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref, index, index, index, index, index +// CHECK: arith.muli {{.*}}%[[OFF1]]{{.*}} : index +// CHECK: arith.addi %[[OFFSET]]{{.*}} : index +// CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}}%[[OFF2]]{{.*}} : index +// CHECK: 
+// CHECK: %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE_IDX]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN]]{{\]}}, %[[MASK]]
+// CHECK-SAME: : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS]] : vector<8xi1>, vector<8xf16>
+// CHECK: gpu.return %[[RES]] : vector<8xf16>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_1D(
+    %source: memref<32xf32, strided<[?], offset: ?>>,
+    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off][%indices], %mask, %pass_thru
+      : memref<32xf32, strided<[?], offset: ?>>,
+        vector<8xindex>, vector<8xi1>, vector<8xf32>
+        into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @non_unit_inner_stride_1D(
+// CHECK-SAME:  %[[SRC:.+]]: memref<32xf32, strided<[?], offset: ?>>,
+// CHECK-SAME:  %[[OFF1:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>,
+// CHECK-SAME:  %[[MASK:.+]]: vector<8xi1>, %[[PASS:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK: %[[BB:.+]], %[[M_OFF:.+]], %[[SZ:.+]], %[[STRIDE:.+]] = memref.extract_strided_metadata %[[SRC]]
+// CHECK: arith.muli %[[OFF1]], %[[STRIDE]] : index
+// CHECK: arith.addi {{.*}} : index
+// CHECK: %[[STRD_VEC:.+]] = vector.broadcast %[[STRIDE]] : index to vector<8xindex>
+// CHECK: %[[STRD_INDICES:.+]] = arith.muli %[[STRD_VEC]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32xf32, strided<[?], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: %[[V:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
+// CHECK: gpu.return %[[RES]] : vector<8xf32>
+}
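+
+// Unlike the unit-stride cases above, a non-unit (here dynamic) innermost
+// stride must also scale the per-element indices: note the arith.muli of the
+// broadcast stride with %indices before the final addi.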
+
+// -----
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_3D(
+    %source: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask, %pass_thru
+      : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+        vector<8xindex>, vector<8xi1>, vector<8xf32>
+        into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @non_unit_inner_stride_3D(
+// CHECK-SAME:  %[[SRC:.+]]: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+// CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME:  %[[PASS:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK: %[[BB:.+]], %[[M_OFF:.+]], %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// CHECK: arith.muli %[[OFF0]], %[[STRIDES]]#0 : index
+// CHECK: arith.addi {{.*}} : index
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[STRD_INDICES:.+]] = arith.muli {{.*}}%[[INDICES]]{{.*}} : vector<8xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: %[[V:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
+// CHECK: gpu.return %[[RES]] : vector<8xf32>
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
new file mode 100644
index 0000000000000..ffd3f170c0fad
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
@@ -0,0 +1,206 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+gpu.module @xevm_module {
+gpu.func @store_1D_vector(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
+    %off1: index, %off2: index, %off3: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec
+      : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_1D_vector(
+// CHECK-SAME:  %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_memref(%vec: vector<8xf32>, %source: memref<8x32xf32>,
+    %off1: index, %off2: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2][%indices], %mask, %vec
+      : memref<8x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_2D_memref(
+// CHECK-SAME:  %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x32xf32>,
+// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK-COUNT-1: arith.muli {{.*}} : index
+// CHECK-COUNT-1: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x32xf32> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_vector(%vec: vector<8x16xf32>, %source: memref<8x16x32xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+      : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_2D_vector(
+// CHECK-SAME:  %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_dynamic_source(%vec: vector<8x16xf32>, %source: memref<?x?x?xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+      : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_dynamic_source(
+// CHECK-SAME:  %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK: memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_dynamic_source2(%vec: vector<8x16xf32>, %source: memref<?x8x16xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+      : memref<?x8x16xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL: @store_dynamic_source2(
+// CHECK-SAME:  %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<?x8x16xf32>,
+// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK-NOT: memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x8x16xf32> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_1D(
+    %vec: vector<8xf32>, %source: memref<32xf32, strided<[?], offset: ?>>,
+    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off][%indices], %mask, %vec
+      : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL: @non_unit_inner_stride_1D(
+// CHECK-SAME:  %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<32xf32, strided<[?], offset: ?>>,
+// CHECK-SAME:  %[[OFF1:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK: %[[BB:.+]], %[[M_OFF:.+]], %[[SZ:.+]], %[[STRIDE:.+]] = memref.extract_strided_metadata %[[SRC]]
+// CHECK: arith.muli %[[OFF1]], %[[STRIDE]] : index
+// CHECK: arith.addi {{.*}} : index
+// CHECK: %[[STRD_VEC:.+]] = vector.broadcast %[[STRIDE]] : index to vector<8xindex>
+// CHECK: %[[STRD_INDICES:.+]] = arith.muli %[[STRD_VEC]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32xf32, strided<[?], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_3D(
+    %vec: vector<8xf32>,
+    %source: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+      : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+        vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL: @non_unit_inner_stride_3D(
+// CHECK-SAME:  %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+// CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK: %[[BB:.+]], %[[M_OFF:.+]], %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// CHECK: arith.muli %[[OFF0]], %[[STRIDES]]#0 : index
+// CHECK: arith.addi {{.*}} : index
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK: %[[STRD_INDICES:.+]] = arith.muli {{.*}}%[[INDICES]]{{.*}} : vector<8xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @scatter_into_subview(%vals: vector<8xf16>,
+    %source: memref<4096x4096xf16>,
+    %memref_off: index, %off1: index, %off2: index,
+    %indices: vector<8xindex>,
+    %mask: vector<8xi1>) {
+  %subview = memref.subview %source[%memref_off, %memref_off] [256, 256] [1, 1]
+      : memref<4096x4096xf16>
+        to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  vector.scatter %subview[%off1, %off2][%indices], %mask, %vals
+      : memref<256x256xf16, strided<[4096, 1], offset: ?>>,
+        vector<8xindex>, vector<8xi1>, vector<8xf16>
+  gpu.return
+}
+// CHECK-LABEL: @scatter_into_subview(
+// CHECK-SAME:  %[[VALS:.+]]: vector<8xf16>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<4096x4096xf16>,
+// CHECK-SAME:  %[[MEMREF_OFF:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[MEMREF_OFF]], %[[MEMREF_OFF]]] [256, 256] [1, 1]
+// CHECK: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// CHECK: arith.muli {{.*}}%[[OFF1]]{{.*}} : index
+// CHECK: arith.addi %[[OFFSET]]{{.*}} : index
+// CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}}%[[OFF2]]{{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast %[[BASE_OFF]] : index to vector<8xindex>
+// CHECK: %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE_IDX]] : index to i64
+// CHECK: xegpu.store %[[VALS]], %[[BASE_I64]]{{\[}}%[[LIN]]{{\]}}, %[[MASK]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1>
+// CHECK: gpu.return
+}