diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt index 0fe01824b8248..bbe8e4eb892dd 100644 --- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt @@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector) add_mlir_interface(X86VectorInterfaces) add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen) + +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..6f377e10fa8f8 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td) +mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls) +mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs) +add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen) diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h new file mode 100644 index 0000000000000..e1d8b8762e799 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h @@ -0,0 +1,31 @@ +//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H +#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===----------------------------------------------------------------------===// +// X86Vector Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace x86vector { +void registerTransformDialectExtension(DialectRegistry ®istry); + +} // namespace x86vector +} // namespace mlir + +#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td new file mode 100644 index 0000000000000..3c5294ff14fc7 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td @@ -0,0 +1,43 @@ +//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef X86VECTOR_TRANSFORM_OPS +#define X86VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/IR/RegionKindInterface.td" + +def ApplyVectorContractToFMAPatternsOp : Op]> { + let description = [{ + Collect patterns to lower a F32 type vector.contract operation to a FMA. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op]> { + let description = [{ + Collect patterns to lower a BF16/Int8 type vector.contract operation + to a BF16/Int8 dot-product. + }]; + + let assemblyFormat = "attr-dict"; +} + + +#endif // X86VECTOR_TRANSFORM_OPS + diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h index d54111ca41e69..fc46dff63c2b7 100644 --- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h +++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h @@ -79,6 +79,18 @@ struct MaskHelper { } }; +//===----------------------------------------------------------------------===// + +// A set of patterns for specialized lowering of vector contraction +// operation to vector fused multiply and add (FMA) operation. +void populateVectorContractToFMAPatterns(RewritePatternSet &patterns); + +// A set of patterns for lowering 32-bit packed vector contraction operations +// to their corresponding packed-type dot-product operations, ultimately +// targeting the relevant x86 LLVM intrinsics (e.g., BF16 and Int8). +void populateVectorContractToPackedTypeDotProductPatterns( + RewritePatternSet &patterns); + //===----------------------------------------------------------------------===// /// Helpers extracted from: /// - clang/lib/Headers/avxintrin.h diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..f4c9f8a05acbc --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRX86VectorTransformOps + X86VectorTransformOps.cpp + + DEPENDS + MLIRX86VectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRSideEffectInterfaces + MLIRTransformDialect + MLIRTransformDialectUtils + MLIRX86VectorDialect + MLIRX86VectorTransforms + ) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp new file mode 100644 index 0000000000000..95db208207672 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp @@ -0,0 +1,64 @@ +//===- X86VectorTransformOps.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/RegionKindInterface.h" + +using namespace mlir; +using namespace mlir::x86vector; +using namespace mlir::transform; + +void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + x86vector::populateVectorContractToFMAPatterns(patterns); +} + +void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class X86VectorTransformDialectExtension + : public transform::TransformDialectExtension< + X86VectorTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + X86VectorTransformDialectExtension) + + X86VectorTransformDialectExtension() { + declareGeneratedDialect(); + declareGeneratedDialect(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" + +void mlir::x86vector::registerTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt index c51266afe9e8f..3d2288049e49e 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt @@ -1,6 +1,8 @@ add_mlir_dialect_library(MLIRX86VectorTransforms AVXTranspose.cpp LegalizeForLLVMExport.cpp + VectorContractToFMA.cpp + VectorContractToPackedTypeDotProduct.cpp LINK_LIBS PUBLIC MLIRArithDialect diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp new file mode 100644 index 0000000000000..f3af5ca167a35 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp @@ -0,0 +1,143 @@ +//===- VectorContractToFMA.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; + +namespace { + +// Implements outer product contraction as a sequence of broadcast and +// FMA operations. +// +// For example - for F32 type: +// ``` +// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32> +// ``` +// to +// ``` +// vector.broadcast %lhs to <16xf32> +// vector.fma vector<16xf32> +// ``` +struct VectorContractToFMA : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + + if (contractOp.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind."); + + VectorType lhsTy = contractOp.getLhsType(); + if (!lhsTy.getElementType().isF32()) + return rewriter.notifyMatchFailure(contractOp, + "Only F32 lowering is supported."); + + ArrayRef lhsShape = lhsTy.getShape(); + llvm::SmallVector nonUnitDimLhs; + llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs), + [](int64_t dim) { return dim != 1; }); + + VectorType rhsTy = contractOp.getRhsType(); + ArrayRef rhsShape = rhsTy.getShape(); + llvm::SmallVector nonUnitDimRhs; + llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs), + [](int64_t dim) { return dim != 1; }); + + if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0) + return rewriter.notifyMatchFailure( + contractOp, "Excepts unit dimensions for either LHS or RHS shape."); + + if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1) + return rewriter.notifyMatchFailure( + contractOp, + "Excepts a one non-unit A/B dimension for either LHS or RHS shape."); + + VectorType accTy = dyn_cast(contractOp.getAccType()); + if (!accTy) + return rewriter.notifyMatchFailure(contractOp, + "Accmulator is not a vector type"); + + if (!accTy.getElementType().isF32()) + return rewriter.notifyMatchFailure(contractOp, + "Accmulator should be F32 type."); + + ArrayRef accShape = accTy.getShape(); + llvm::SmallVector nonUnitDimAcc; + llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc), + [](int64_t dim) { return dim != 1; }); + if (nonUnitDimAcc.size() != 1) + return rewriter.notifyMatchFailure( + contractOp, "A or B dimension should be non-unit."); + + // Lowers vector.contract into a broadcast+FMA sequence. + auto loc = contractOp.getLoc(); + auto castAcc = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()), + contractOp.getAcc()); + + vector::FMAOp fma; + + // Broadcast the unit-dimension LHS or RHS to match the vector length of the + // corresponding non-unit dimension on the other operand. For example, + // if LHS has type vector<1x1xf32> and RHS has type vector<1x16xf32>, we + // broadcast the LHS to vector<1x16xf32>. In the opposite case (non-unit + // dimension on the LHS), we broadcast the RHS instead. + if (nonUnitDimRhs.size() > 0) { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, VectorType::get(1, lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()), + contractOp.getRhs()); + auto broadcastLhs = vector::BroadcastOp::create( + rewriter, loc, castRhs.getResult().getType(), castLhs); + fma = + vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc); + } else { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, VectorType::get(1, rhsTy.getElementType()), + contractOp.getRhs()); + auto broadcastRhs = vector::BroadcastOp::create( + rewriter, loc, castLhs.getResult().getType(), castRhs); + fma = + vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc); + } + + auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma); + rewriter.replaceOp(contractOp, castFma); + + return success(); + } +}; + +} // namespace + +void x86vector::populateVectorContractToFMAPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp new file mode 100644 index 0000000000000..1e64811db910b --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp @@ -0,0 +1,301 @@ +//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; + +namespace { + +static FailureOr> +inferIteratorsFromOutMap(AffineMap map) { + if (!map.isProjectedPermutation()) + return failure(); + SmallVector iterators( + map.getNumDims(), mlir::utils::IteratorType::reduction); + for (auto expr : map.getResults()) + if (auto dim = dyn_cast(expr)) + iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel; + return iterators; +} + +// Returns true if the operation is in VNNI layout. +// Optionally, the check can be constrained to a specific VNNI blocking factor. +static bool isInVnniLayout(Operation *op, ArrayRef indexingMaps, + std::optional blockingFactor) { + // Narrow down type operations - VNNI only applies to contractions. + FailureOr dims = + linalg::inferContractionDims(indexingMaps); + if (failed(dims)) + return false; + + auto matA = op->getOperand(0); + auto matB = op->getOperand(1); + auto typeA = dyn_cast(matA.getType()); + auto typeB = dyn_cast(matB.getType()); + unsigned rankA = typeA.getRank(); + unsigned rankB = typeB.getRank(); + // VNNI format requires at least 1 parallel and 2 reduction dimensions. + if (rankA < 3 || rankB < 3) + return false; + + // At least two reduction dimensions are expected: + // one for the VNNI factor and one for the K dimension + if (dims->k.size() < 2) + return false; + + // Validate affine maps - VNNI computation should be defined by the two + // innermost reduction iterators. + // The input matrix dimensions layout must match the following: + // - matrix A - [...][K/vnniFactor][vnniFactor] + // - matrix B - [...][K/vnniFactor][N][vnniFactor] + auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]); + if (failed(maybeIters)) + return false; + SmallVector iteratorTypes = *maybeIters; + AffineMap mapA = indexingMaps[0]; + AffineMap mapB = indexingMaps[1]; + + auto vnniDimA = dyn_cast(mapA.getResult(rankA - 1)); + auto vnniDimB = dyn_cast(mapB.getResult(rankB - 1)); + if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB || + iteratorTypes[vnniDimA.getPosition()] != + mlir::utils::IteratorType::reduction) + return false; + auto redDimA = dyn_cast(mapA.getResult(rankA - 2)); + auto redDimB = dyn_cast(mapB.getResult(rankB - 3)); + if (!redDimA || !redDimB || redDimA != redDimB || + iteratorTypes[redDimA.getPosition()] != + mlir::utils::IteratorType::reduction) + return false; + auto parallelDimB = dyn_cast(mapB.getResult(rankB - 2)); + if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] != + mlir::utils::IteratorType::parallel) + return false; + + // VNNI factor must be: + // - the innermost inputs' dimension + // - statically known + // - multiple of 2 or equal to the specified factor + auto vnniDimSize = typeB.getShape().back(); + if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 || + vnniDimSize % 2 != 0) + return false; + if (typeA.getShape().back() != vnniDimSize) + return false; + if (blockingFactor && vnniDimSize != *blockingFactor) + return false; + + // The split reduction dimension size should also match. + if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3]) + return false; + + return true; +} + +// Implements packed type outer product contraction as a sequence +// of broadcast and packed dot-product operations. +// +// For example - for F32 type: +// ``` +// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32> +// ``` +// to +// ``` +// vector.broadcast %lhs to <32xbf16> +// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32> +// ``` +struct VectorContractToPackedTypeDotProduct + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + + if (contractOp.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind."); + + VectorType lhsTy = contractOp.getLhsType(); + if (!lhsTy.getElementType().isBF16() && + !lhsTy.getElementType().isSignlessInteger(8)) + return rewriter.notifyMatchFailure( + contractOp, "Only BF16/Int8 lowering is supported."); + + unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4; + if (!isInVnniLayout(contractOp.getOperation(), + contractOp.getIndexingMapsArray(), blockingFactor)) + return rewriter.notifyMatchFailure(contractOp, + "Input matrices not in VNNI format."); + + ArrayRef lhsShape = lhsTy.getShape(); + llvm::SmallVector nonUnitDimLhs; + llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs), + [](int64_t dim) { return dim != 1; }); + + VectorType rhsTy = contractOp.getRhsType(); + ArrayRef rhsShape = rhsTy.getShape(); + llvm::SmallVector nonUnitDimRhs; + llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs), + [](int64_t dim) { return dim != 1; }); + + if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0) + return rewriter.notifyMatchFailure(contractOp, + "Excepts unit dimensions for either " + "LHS or RHS shape other than VNNI."); + + if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1) + return rewriter.notifyMatchFailure( + contractOp, + "Excepts a one non-unit A/B dimension for either LHS or RHS shape."); + + VectorType accTy = dyn_cast(contractOp.getAccType()); + if (!accTy) + return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type."); + + if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) || + (lhsTy.getElementType().isSignlessInteger(8) && + !accTy.getElementType().isSignlessInteger(32))) + return rewriter.notifyMatchFailure(contractOp, + "Only F32 for BF16 or Int32 for Int8 " + "accumulation type is supported."); + + ArrayRef accShape = accTy.getShape(); + llvm::SmallVector nonUnitDimAcc; + llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc), + [](int64_t dim) { return dim != 1; }); + if (nonUnitDimAcc.size() != 1) + return rewriter.notifyMatchFailure( + contractOp, "A or B should be a non-unit dim in acc."); + + // Non-unit dimensions should match the vector length of BF16 or Int8 + // dot-product. + unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front() + : nonUnitDimRhs.front(); + if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 && + nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim) + return rewriter.notifyMatchFailure( + contractOp, "BF16 dot-product operation expects non-unit (LHR or " + "RHS) dim and acc dim of size 4/8/16."); + + if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 && + nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim) + return rewriter.notifyMatchFailure( + contractOp, "Int8 dot-product operation expects non-unit (LHR or " + "RHS) dim and acc dim of size 4/8."); + + auto loc = contractOp.getLoc(); + auto castAcc = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()), + contractOp.getAcc()); + + Value dp; + + // Broadcast the unit-dimension LHS or RHS to match the vector length of the + // corresponding non-unit dimension on the other operand. For example, + // if LHS has type vector<1x1x2xbf16> and RHS has type vector<1x16x2xbf16>, + // we broadcast the LHS to vector<16x2xbf16>. In the opposite case (non-unit + // dimension on the LHS), we broadcast the RHS instead. + if ((nonUnitDimRhs.size() - 1) > 0) { + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(), + rhsTy.getElementType()), + contractOp.getRhs()); + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()), + contractOp.getLhs()); + auto bitcastLhs = vector::BitCastOp::create( + rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)), + castLhs); + auto broadcastLhs = vector::BroadcastOp::create( + rewriter, loc, + VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)), + bitcastLhs); + auto bitcastLhsPkType = vector::BitCastOp::create( + rewriter, loc, castRhs.getResult().getType(), broadcastLhs); + + if (lhsTy.getElementType().isBF16()) { + dp = x86vector::DotBF16Op::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()), + castAcc, bitcastLhsPkType, castRhs); + } + + if (lhsTy.getElementType().isSignlessInteger(8)) { + dp = x86vector::DotInt8Op::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)), + castAcc, bitcastLhsPkType, castRhs); + } + } else { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(), + lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()), + contractOp.getRhs()); + auto bitcastRhs = vector::BitCastOp::create( + rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)), + castRhs); + auto broadcastRhs = vector::BroadcastOp::create( + rewriter, loc, + VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)), + bitcastRhs); + auto bitcastRhsPkType = vector::BitCastOp::create( + rewriter, loc, castLhs.getResult().getType(), broadcastRhs); + + if (lhsTy.getElementType().isBF16()) { + dp = x86vector::DotBF16Op::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()), + castAcc, castLhs, bitcastRhsPkType); + } + + if (lhsTy.getElementType().isSignlessInteger(8)) { + dp = x86vector::DotInt8Op::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)), + castAcc, castLhs, bitcastRhsPkType); + } + } + + if (!dp) + return failure(); + + auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp); + rewriter.replaceOp(contractOp, castDp); + return success(); + } +}; + +} // namespace + +void x86vector::populateVectorContractToPackedTypeDotProductPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp index c857c38df717c..4312100a0c0b0 100644 --- a/mlir/lib/RegisterAllExtensions.cpp +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -56,6 +56,7 @@ #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" @@ -113,6 +114,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) { transform::registerSMTExtension(registry); transform::registerTuneExtension(registry); vector::registerTransformDialectExtension(registry); + x86vector::registerTransformDialectExtension(registry); xegpu::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry); diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir new file mode 100644 index 0000000000000..e506b166d43ff --- /dev/null +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir @@ -0,0 +1,344 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +!vecA = vector<1x1xf32> +!vecB = vector<1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @matmul_outer_product_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_fma +// CHECK: vector.broadcast{{.*}}vector<1xf32> to vector<64xf32> +// CHECK: vector.fma{{.*}}vector<64xf32> +// CHECK: vector.shape_cast{{.*}}vector<64xf32> to vector<1x64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<64x1xf32> +!vecB = vector<1x1xf32> +!vecC = vector<64x1xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @matmul_outer_product_to_fma_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_fma_bcst_B +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_to_fma +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x64x1xf32> +!vecB = vector<1x1x1xf32> +!vecC = vector<1x64x1xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_to_fma_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_to_fma_bcst_B +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_fma +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x64x1xf32> +!vecB = vector<1x1x1xf32> +!vecC = vector<64x1xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_fma_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_fma_bcst_B +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1xf32> +!vecB = vector<3x1x64xf32> +!vecC = vector<3x1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_non_unit_batch_dim( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// Batch dimension should've been simplified earlier. + +// CHECK-LABEL: @negative_non_unit_batch_dim +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1xf32> +!vecB = vector<3x1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @negative_non_unit_batch_reduce_dim( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// Batch-reduce dimension should've been simplified earlier. + +// CHECK-LABEL: @negative_non_unit_batch_reduce_dim +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1xf32> +!vecB = vector<1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @negative_invalid_kind( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_invalid_kind +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x1x64xi32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_accumulator_type( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_accumulator_type +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir new file mode 100644 index 0000000000000..65676cbae772c --- /dev/null +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir @@ -0,0 +1,681 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x16x1x2xbf16> +!vecB = vector<1x1x1x2xbf16> +!vecC = vector<16x1xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_bf16dp_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_bf16dp_bcst_B +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + + +// CHECK-LABEL: @batch_matmul_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x8x1x4xi8> +!vecB = vector<1x1x1x4xi8> +!vecC = vector<1x8x1xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_int8dp_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + + +// CHECK-LABEL: @batch_matmul_int8dp_bcst_B +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x2xbf16> +!vecB = vector<1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<16x1x2xbf16> +!vecB = vector<1x1x2xbf16> +!vecC = vector<16x1xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_bf16dp_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_bf16dp_bcst_B +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x4xi8> +!vecB = vector<1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x2xbf16> +!vecB = vector<1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_invalid_vc_kind( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_invalid_vc_kind +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xbf16> +!vecB = vector<1x1x16x4xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_false_vnni_bf16( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_false_vnni_bf16 +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xi8> +!vecB = vector<1x1x8x2xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_false_vnni_int8( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_false_vnni_int8 +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1x2xbf16> +!vecB = vector<3x1x16x2xbf16> +!vecC = vector<3x1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_batch_dimension( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_batch_dimension +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<2x1x1x4xi8> +!vecB = vector<2x1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_brgemm_dimension( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_brgemm_dimension +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x1x16xbf16> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_float_acc_type( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_float_acc_type +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x1x8xi8> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_int_acc_type( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_int_acc_type +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xbf16> +!vecB = vector<1x1x16x4xbf16> +!vecC = vector<1x1x16xbf16> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_wrong_vnni_blocking_factor_bf16( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16 +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xbf16> +!vecB = vector<1x1x32xbf16> +!vecC = vector<1x32xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @negative_brgemm_not_vnni( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_brgemm_not_vnni +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x16x4xi8> +!vecC = vector<1x1x16xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_wrong_vector_shape_int8( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_wrong_vector_shape_int8 +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x32x2xbf16> +!vecC = vector<1x1x32xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_wrong_vector_shape_bf16( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_wrong_vector_shape_bf16 +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +}