diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp index a4496f3620b97..4d0c82085f8a8 100644 --- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp +++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp @@ -82,22 +82,78 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter, auto elemTy = Ty.getElementType(); auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy); - auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy, - opA->getResult(0)); - auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy, - opB->getResult(0)); + Value srcBuff; + SmallVector indexVals; - static constexpr int64_t maskLo[] = { + llvm::TypeSwitch(opA).Case( + [&](auto readOp) { + srcBuff = readOp.getOperand(0); + + auto indices = readOp.getIndices(); + indexVals.reserve(indices.size()); + + llvm::transform( + indices, std::back_inserter(indexVals), [&](OpFoldResult ofr) { + return mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr); + }); + }); + + auto vec1 = vector::LoadOp::create(rewriter, loc, flatTy, srcBuff, indexVals); + + unsigned int offset = 1; + if (elemTy.isSignlessInteger(8)) + offset = 2; + + Value cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset); + auto nextIndx = + arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), cOffset, + indexVals[indexVals.size() - 2]); + indexVals[indexVals.size() - 2] = nextIndx; + + auto vec2 = vector::LoadOp::create(rewriter, loc, flatTy, srcBuff, indexVals); + + static constexpr int64_t maskLo_bf16[] = { 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59}; - static constexpr int64_t maskHi[] = { + static constexpr int64_t maskHi_bf16[] = { 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63}; - auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA, - castB, maskLo); - auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA, - castB, maskHi); + static constexpr int64_t maskLo_int8_avx2[] = { + 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, + 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59}; + static constexpr int64_t maskHi_int8_avx2[] = { + 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55, + 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63}; + + static constexpr int64_t maskLo_int8_avx10[] = { + 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, + 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, + 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, + 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123}; + static constexpr int64_t maskHi_int8_avx10[] = { + 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, + 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111, + 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, + 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127}; + + mlir::DenseI64ArrayAttr maskLo = rewriter.getDenseI64ArrayAttr(maskLo_bf16); + mlir::DenseI64ArrayAttr maskHi = rewriter.getDenseI64ArrayAttr(maskHi_bf16); + + if (elemTy.isSignlessInteger(8)) { + maskLo = rewriter.getDenseI64ArrayAttr(maskLo_int8_avx10); + maskHi = rewriter.getDenseI64ArrayAttr(maskHi_int8_avx10); + + if (nonUnitDimAcc == 32) { + maskLo = rewriter.getDenseI64ArrayAttr(maskLo_int8_avx2); + maskHi = rewriter.getDenseI64ArrayAttr(maskHi_int8_avx2); + } + } + + auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1, + vec2, maskLo); + auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1, + vec2, maskHi); auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo); auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi); @@ -159,9 +215,6 @@ struct VectorContractToPackedTypeDotProduct isInVnniLayout(contractOp.getOperation(), contractOp.getIndexingMapsArray(), blockingFactor); - if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni) - return failure(); - VectorType accTy = dyn_cast(contractOp.getAccType()); if (!accTy) return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type."); @@ -217,7 +270,7 @@ struct VectorContractToPackedTypeDotProduct if (!isVnni && (extraFlatDim != blockingFactor)) return rewriter.notifyMatchFailure( - contractOp, "The K or reduction dim for flat layout should be 2."); + contractOp, "The K or reduction dim for flat layout should be 2/4."); if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) || (lhsTy.getElementType().isSignlessInteger(8) && diff --git a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp index f4279a3eb507a..aea6bf6adcd4a 100644 --- a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp +++ b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp @@ -116,15 +116,20 @@ struct ShuffleMasks { llvm::ArrayRef maskHi; }; -inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) { +inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc, bool isInt8Avx2) { // We only support these two layouts for now. assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) && "Unsupported nonUnitDimAcc value"); + // Do interleaving between two <8xf32> targeting AVX2. static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11}; static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15}; - // Shuffle two <16xf32> as below targeting AVX512. + // Do interleaving between two <8xi32> targeting AVX2. + static constexpr int64_t maskLo8_avx2_int8[] = {0, 1, 2, 3, 8, 9, 10, 11}; + static constexpr int64_t maskHi8_avx2_int8[] = {4, 5, 6, 7, 12, 13, 14, 15}; + + // Shuffle two <16xf32/i32> as below targeting AVX512. static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; static constexpr int64_t maskHi16[] = {8, 9, 10, 11, 24, 25, 26, 27, @@ -133,6 +138,9 @@ inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) { if (nonUnitDimAcc == 16) return {maskLo16, maskHi16}; + if (isInt8Avx2) + return {maskLo8_avx2_int8, maskHi8_avx2_int8}; + return {maskLo8, maskHi8}; } @@ -255,7 +263,8 @@ LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, auto castB = vector::ShapeCastOp::create(rewriter, loc, flatTy, opB->getResult(0)); - auto masks = getShuffleMasks(nonUnitDimAcc); + auto masks = getShuffleMasks( + nonUnitDimAcc, (elemTy.isSignlessInteger(32) && nonUnitDimAcc == 8)); auto shuffleLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA, castB, masks.maskLo); @@ -313,7 +322,8 @@ LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, auto castB = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB); // TODO: derive shuffle masks instead of hard-coding - auto masks = getShuffleMasks(nonUnitDimAcc); + auto masks = getShuffleMasks( + nonUnitDimAcc, (elemTy.isSignlessInteger(32) && nonUnitDimAcc == 8)); auto shuffledLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA, castB, masks.maskLo); diff --git a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir index 0953ee042a24d..f861d357739a3 100644 --- a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir +++ b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir @@ -412,6 +412,144 @@ module attributes {transform.with_named_sequence} { // ----- +!vecA = vector<1x4xi8> +!vecB = vector<4x16xi8> +!vecC = vector<1x16xi32> +!memrefA = memref<4x4xi8> +!memrefB = memref<4x64xi8> +!memrefC = memref<4x64xi32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @matmul_i8_avx10dp_flat_layout( + %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC +{ + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %0 = ub.poison : i8 + %32 = ub.poison : i32 + %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} : + !memrefA, !vecA + %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : + !memrefC, !vecC + %3 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} : + !memrefC, !vecC + %4 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} : + !memrefB, !vecB + %5 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} : + !memrefB, !vecB + + %6 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %1, %4, %2 + : !vecA, !vecB into !vecC + + %7 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %1, %5, %3 + : !vecA, !vecB into !vecC + + vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC + vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC + + return %arg2 : !memrefC +} + +// CHECK-LABEL: @matmul_i8_avx10dp_flat_layout +// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32> +// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32> +// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xi8>, vector<64xi8> +// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8> +// CHECK: x86.avx10.dot.i8 +// CHECK: x86.avx10.dot.i8 +// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32> +// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x4xi8> +!vecB = vector<4x8xi8> +!vecC = vector<1x8xi32> +!memrefA = memref<4x4xi8> +!memrefB = memref<4x64xi8> +!memrefC = memref<4x64xi32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @matmul_i8_avx2dp_flat_layout( + %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC +{ + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %0 = ub.poison : i8 + %32 = ub.poison : i32 + %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} : + !memrefA, !vecA + %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : + !memrefC, !vecC + %3 = vector.transfer_read %arg2[%c0, %c8], %32 {in_bounds = [true, true]} : + !memrefC, !vecC + %4 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} : + !memrefB, !vecB + %5 = vector.transfer_read %arg1[%c0, %c8], %0 {in_bounds = [true, true]} : + !memrefB, !vecB + + %6 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %1, %4, %2 + : !vecA, !vecB into !vecC + + %7 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %1, %5, %3 + : !vecA, !vecB into !vecC + + vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC + vector.transfer_write %7, %arg2[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC + + return %arg2 : !memrefC +} + +// CHECK-LABEL: @matmul_i8_avx2dp_flat_layout +// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 8, 9, 10, 11] : vector<8xi32>, vector<8xi32> +// CHECK-NEXT: vector.shuffle{{.*}}[4, 5, 6, 7, 12, 13, 14, 15] : vector<8xi32>, vector<8xi32> +// CHECK: vector.shuffle{{.*}}[0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59] : vector<32xi8>, vector<32xi8> +// CHECK-NEXT: vector.shuffle{{.*}}[4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63] : vector<32xi8>, vector<32xi8> +// CHECK: x86.avx.dot.i8 +// CHECK: x86.avx.dot.i8 +// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 8, 9, 10, 11] : vector<8xi32>, vector<8xi32> +// CHECK-NEXT: vector.shuffle{{.*}}[4, 5, 6, 7, 12, 13, 14, 15] : vector<8xi32>, vector<8xi32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + !vecA = vector<1x2xbf16> !vecB = vector<2x16xbf16> !vecC = vector<1x16xf32> @@ -640,6 +778,102 @@ func.func @matmul_bf16dp_flat_layout_B_shuffled( // CHECK: x86.avx512.dot // CHECK-NOT: vector.contract +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x4xi8> +!vecB = vector<1x4x16xi8> +!vecC = vector<1x16xi32> +!memrefA = memref<1x2x4xi8, strided<[16384, 256, 1], offset: ?>> +!memrefB = memref<1x4x32xi8, strided<[32768, 128, 1], offset: ?>> +!memrefC = memref<2x32xi32, strided<[128, 1], offset: ?>> + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +func.func @brgemm_int8_flat_avx10(%arg0: memref<16x64x256xi8>, %arg1: memref<16x256x128xi8>, %arg2: memref<64x128xi32>) -> memref<64x128xi32> { + %0 = ub.poison : i32 + %1 = ub.poison : i8 + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + scf.for %arg3 = %c0 to %c64 step %c2 { + scf.for %arg4 = %c0 to %c128 step %c32 { + %subview = memref.subview %arg2[%arg3, %arg4] [2, 32] [1, 1] : memref<64x128xi32> to !memrefC + %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} + : !memrefC, !vecC + %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} + : !memrefC, !vecC + %4 = vector.transfer_read %subview[%c1, %c0], %0 {in_bounds = [true, true]} + : !memrefC, !vecC + %5 = vector.transfer_read %subview[%c1, %c16], %0 {in_bounds = [true, true]} + : !memrefC, !vecC + %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) { + %7:4 = scf.for %arg10 = %c0 to %c256 step %c4 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!vecC, !vecC, !vecC, !vecC) { + %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10] [1, 2, 4] [1, 1, 1] : memref<16x64x256xi8> to !memrefA + %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4] [1, 4, 32] [1, 1, 1] : memref<16x256x128xi8> to !memrefB + %8 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} + : !memrefA, !vecA + %9 = vector.transfer_read %subview_0[%c0, %c1, %c0], %1 {in_bounds = [true, true, true]} + : !memrefA, !vecA + %10 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} + : !memrefB, !vecB + %11 = vector.transfer_read %subview_1[%c0, %c0, %c16], %1 {in_bounds = [true, true, true]} + : !memrefB, !vecB + %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = + ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} + %8, %10, %arg11 {unroll_shape = array} : !vecA, !vecB into !vecC + %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = + ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} + %8, %11, %arg12 {unroll_shape = array} : !vecA, !vecB into !vecC + %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = + ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} + %9, %10, %arg13 {unroll_shape = array} : !vecA, !vecB into !vecC + %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = + ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} + %9, %11, %arg14 {unroll_shape = array} : !vecA, !vecB into !vecC + scf.yield %12, %13, %14, %15 : !vecC, !vecC, !vecC, !vecC + } + scf.yield %7#0, %7#1, %7#2, %7#3 : !vecC, !vecC, !vecC, !vecC + } + vector.transfer_write %6#3, %subview[%c1, %c16] {in_bounds = [true, true]} : !vecC, !memrefC + vector.transfer_write %6#2, %subview[%c1, %c0] {in_bounds = [true, true]} : !vecC, !memrefC + vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC + vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC + } + } + %alloc = memref.alloc() : memref<64x128xi32> + memref.copy %arg2, %alloc : memref<64x128xi32> to memref<64x128xi32> + return %alloc : memref<64x128xi32> +} + +// CHECK-LABEL: @brgemm_int8_flat_avx10 +// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32> +// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32> +// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xi8>, vector<64xi8> +// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8> +// CHECK: x86.avx10.dot.i8 +// CHECK: x86.avx10.dot.i8 +// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32> +// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32> + + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op @@ -1548,3 +1782,66 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +!vecA = vector<1x4xi8> +!vecB = vector<4x8xi8> +!vecC = vector<1x8xi32> +!memrefA = memref<4x4xi8> +!memrefB = memref<4x64xi8> +!memrefC = memref<4x64xi32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @negative_i8_avx2dp_flat_layout_offset_diff_16( + %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC +{ + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %0 = ub.poison : i8 + %32 = ub.poison : i32 + %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} : + !memrefA, !vecA + %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : + !memrefC, !vecC + %3 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} : + !memrefC, !vecC + %4 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} : + !memrefB, !vecB + %5 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} : + !memrefB, !vecB + + %6 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %1, %4, %2 + : !vecA, !vecB into !vecC + + %7 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %1, %5, %3 + : !vecA, !vecB into !vecC + + vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC + vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC + + return %arg2 : !memrefC +} + +// CHECK-LABEL: @negative_i8_avx2dp_flat_layout_offset_diff_16 +// CHECK-NOT: x86.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +}