diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp index e3037186569b8..b47eede2a9156 100644 --- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp +++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp @@ -67,7 +67,15 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter, mlir::vector::ContractionOp contractB, int64_t nonUnitDimAcc, mlir::VectorType Ty) { - mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA; + + bool opABeforeopB = opA->isBeforeInBlock(opB); + + if (opABeforeopB) + rewriter.moveOpAfter(opB, opA); + else + rewriter.moveOpAfter(opA, opB); + + mlir::Operation *insertAfter = opABeforeopB ? opB : opA; rewriter.setInsertionPointAfter(insertAfter); mlir::Location loc = insertAfter->getLoc(); @@ -326,14 +334,6 @@ struct VectorContractToPackedTypeDotProduct return rewriter.notifyMatchFailure( contractOp, "Could not find a valid contract pair"); - if (contractOp->getBlock() == - nonUnitDimReadOpPairContract->getBlock() && - contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract)) - return rewriter.notifyMatchFailure( - contractOp, - "The load/read operation of pair contract operation is " - "after the contractOp"); - VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims ? 
contractOp.getRhsType() : contractOp.getLhsType(); diff --git a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp index 805d9c5c00b63..3893d8d288f32 100644 --- a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp +++ b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp @@ -341,6 +341,9 @@ bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue) { + if (contractOp == pairContOp) + return false; + if (rhsHasMultipleNonUnitDims && !(contractOp.getLhs() == pairContOp.getLhs())) return false; @@ -393,21 +396,25 @@ bool validatePairVectorContract(vector::ContractionOp contractOp, if (srcBuff != srcBuffPairContOp) return false; + bool oneConstantOffset = false; for (size_t i = 0; i < indexVals.size(); i++) { + + if (indexVals[i] == indexValsPairContOp[i]) + continue; + auto v0 = getConstantIntValue(indexVals[i]); auto v1 = getConstantIntValue(indexValsPairContOp[i]); if (!v0 || !v1) return false; - if (*v1 == *v0) - continue; - if ((*v1 - *v0) != nonUnitDimValue) return false; + + oneConstantOffset = true; } - return true; + return oneConstantOffset; } } // namespace x86 diff --git a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir index eabf15c0af303..0953ee042a24d 100644 --- a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir +++ b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir @@ -412,6 +412,76 @@ module attributes {transform.with_named_sequence} { // ----- +!vecA = vector<1x2xbf16> +!vecB = vector<2x16xbf16> +!vecC = vector<1x16xf32> +!memrefA = memref<4x2xbf16> +!memrefB = memref<2x32xbf16> +!memrefC = memref<2x32xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @matmul_bf16dp_flat_layout_offset_args_and_read_after_vc( + %arg0: 
!memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC +{ + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %0 = ub.poison : bf16 + %32 = ub.poison : f32 + %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} : + !memrefA, !vecA + %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : + !memrefC, !vecC + %3 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} : + !memrefC, !vecC + %4 = vector.transfer_read %arg1[%arg3, %c0], %0 {in_bounds = [true, true]} : + !memrefB, !vecB + + %5 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %1, %4, %2 + : !vecA, !vecB into !vecC + + %6 = vector.transfer_read %arg1[%arg3, %c16], %0 {in_bounds = [true, true]} : + !memrefB, !vecB + + %7 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %1, %6, %3 + : !vecA, !vecB into !vecC + + vector.transfer_write %5, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC + vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC + + return %arg2 : !memrefC +} + +// CHECK-LABEL: @matmul_bf16dp_flat_layout_offset_args_and_read_after_vc +// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32> +// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16> +// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16> +// CHECK: 
x86.avx512.dot +// CHECK: x86.avx512.dot +// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32> +// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + !vecA = vector<1x1x2xbf16> !vecB = vector<1x2x16xbf16> !vecC = vector<1x16xf32>