diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt index 0fe01824b8248..bbe8e4eb892dd 100644 --- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt @@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector) add_mlir_interface(X86VectorInterfaces) add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen) + +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..6f377e10fa8f8 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td) +mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls) +mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs) +add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen) diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h new file mode 100644 index 0000000000000..e1d8b8762e799 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h @@ -0,0 +1,31 @@ +//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H +#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===----------------------------------------------------------------------===// +// X86Vector Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace x86vector { +void registerTransformDialectExtension(DialectRegistry ®istry); + +} // namespace x86vector +} // namespace mlir + +#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td new file mode 100644 index 0000000000000..9db2b36a2a8aa --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td @@ -0,0 +1,37 @@ +//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractNanokernelLoweringPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_nanokernel_lowering",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector contraction operations can be lowered to
+    target-specific nanokernels.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
+
+  let assemblyFormat = [{
+    (`vector_size` `=` $vector_size^)? attr-dict
+  }];
+}
+
+#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..6410c12265f12 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
 
 #include "mlir/IR/Value.h"
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
 namespace mlir {
 
 class ImplicitLocOpBuilder;
@@ -79,6 +83,14 @@ struct MaskHelper {
   }
 };
 
+//===----------------------------------------------------------------------===//
+/// Collects a pattern that lowers a tiled batch or batch-reduce vector
+/// contraction into a sequence of nanokernels. The generated kernels are
+/// tailored to the target machine architecture and guided by the
+/// user-specified vector size.
+void populateVectorContractNanokernelLoweringPatterns(
+    RewritePatternSet &patterns, std::optional<unsigned> vectorSize = 8);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 /// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..f4c9f8a05acbc
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+  X86VectorTransformOps.cpp
+
+  DEPENDS
+  MLIRX86VectorTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRX86VectorDialect
+  MLIRX86VectorTransforms
+  )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000000000..e003e3ad7cd08
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,61 @@
+//===- X86VectorTransformOps.cpp - X86Vector transform ops ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  x86vector::populateVectorContractNanokernelLoweringPatterns(patterns,
+                                                              getVectorSize());
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          X86VectorTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      X86VectorTransformDialectExtension)
+
+  X86VectorTransformDialectExtension() {
+    declareGeneratedDialect<x86vector::X86VectorDialect>();
+    declareGeneratedDialect<LLVM::LLVMDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+void mlir::x86vector::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..da377763331f2 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRX86VectorTransforms
   AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
+  NanoKernels.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
new file mode 100644
index 0000000000000..4d0906a2ec057
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -0,0 +1,659 @@
+//===- NanoKernels.cpp - Lower matmul to nanokernels ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites of tiled batch and batch-reduce matmul
+// patterns as target-specific nanokernels for the FP32 type; BF16 support
+// is left as a TODO.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; + +// Enum to represent the type of matmul operation +enum class MatMulType { Batch, BatchReduce, Others }; + +static FailureOr> +getTiledMatmulLoopNest(vector::ContractionOp contractOp, + MatMulType matmulType) { + SmallVector list; + Operation *current = contractOp; + unsigned int dimCount = matmulType == MatMulType::BatchReduce ? 4 : 3; + + // It is register tiled loop structure on batch (or reduce) matmul + // (M->N->(reduce)->K). + for (unsigned int i = 0; i < dimCount; i++) { + Operation *parent = current->getParentOfType(); + if (!parent) + return failure(); + list.push_back(dyn_cast(parent)); + current = parent; + } + return list; +} + +static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching( + SmallVector loops, SmallVector subviews, + MatMulType matmulType) { + auto subviewOpLhsOffsets = subviews[0].getOffsets(); + auto subviewOpRhsOffsets = subviews[1].getOffsets(); + auto subviewOpAccOffsets = subviews[2].getOffsets(); + + if (matmulType == MatMulType::BatchReduce) { + Value ivK = loops[0].getInductionVar(); + if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1]) + return failure(); + + Value ivReduction = loops[1].getInductionVar(); + if (ivReduction != subviewOpLhsOffsets[0] || + ivReduction != subviewOpRhsOffsets[0]) + return failure(); + + Value ivN = loops[2].getInductionVar(); + if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2]) + return failure(); + + Value ivM = loops[3].getInductionVar(); + if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0]) + return failure(); + } + + if (matmulType == MatMulType::Batch) { + Value ivK = loops[0].getInductionVar(); + if (ivK != subviewOpLhsOffsets[1] || ivK != subviewOpRhsOffsets[0]) + return failure(); + + Value ivN = loops[1].getInductionVar(); + if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[1]) + return failure(); + + Value ivM = loops[2].getInductionVar(); + if (ivM != subviewOpLhsOffsets[0] || ivM != subviewOpAccOffsets[0]) + return failure(); + } + + return success(); +} + +static SmallVector +loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter, + Type elementType, unsigned int M, unsigned int N, + unsigned int vectorSize, Value subviewOpAcc) { + + SmallVector accumulators; + + // Initialize local variable on assumption that M tile is larger than N + unsigned int outerBound = M; + unsigned int innerBound = N; + + unsigned int outerStep = 1; + unsigned int innerStep = vectorSize; + + bool isNTileLarge = (N / vectorSize) > M; + if (isNTileLarge) { + outerBound = N; + innerBound = M; + + outerStep = vectorSize; + innerStep = 1; + } + + for (unsigned int i = 0; i < outerBound; i = i + outerStep) { + for (unsigned int j = 0; j < innerBound; j = j + innerStep) { + Value indexOp_A = 
arith::ConstantIndexOp::create(rewriter, loc, i); + Value indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, j); + + if (isNTileLarge) { + indexOp_A = indexOp_B; + indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, i); + } + + auto valueCRow = vector::LoadOp::create( + rewriter, loc, VectorType::get(vectorSize, elementType), subviewOpAcc, + ValueRange{indexOp_A, indexOp_B}); + accumulators.push_back(valueCRow); + } + } + + return accumulators; +} + +// This function takes matrices A, B, and C (represented as vectors) +// and generates equivalent target-specific nanokernels. +// It returns the final accumulator as output. +// Based on the M tile, N tile, and vector size, it generates optimized +// nanokernels under the condition that the reduction and K dimension +// of the input matrices are equal to 1. +// +// Input: Matrix A, Matrix B, Accmulator as M*(N/vector size) vectors, M tile +// size, N tile size, Vector size. +// +// Output: +// case i: M >= (N/vector size). For example, M=3; N=32; vector size = 16. +// load_B0 = load B[0-15] into vector<16xf32> +// load_B1 = load B[16-31] into vector<16xf32> +// bcst_A0 = load A[0] and broadcast it into vector<16xf32> +// o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0] +// o/p_Acc[1] = vector.fma bcst_A0, load_B1, i/p_Acc[1] +// bcst_A1 = load A[1] and broadcast it into vector<16xf32> +// o/p_Acc[2] = vector.fma bcst_A1, load_B0, i/p_Acc[2] +// o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3] +// bcst_A2 = load A[2] and broadcast it into vector<16xf32> +// o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4] +// o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5] +// +// case ii: M < (N/vector size). For example, M=2; N=48; vector size = 16. +// bcst_A0 = load A[0] and broadcast it into vector<16xf32> +// bcst_A1 = load A[1] and broadcast it into vector<16xf32> +// bcst_A2 = load A[2] and broadcast it into vector<16xf32> +// load_B0 = load B[0-15] into vector<16xf32> +// o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0] +// o/p_Acc[1] = vector.fma bcst_A1, load_B0, i/p_Acc[1] +// load_B1 = load B[16-31] into vector<16xf32> +// o/p_Acc[2] = vector.fma bcst_A0, load_B1, i/p_Acc[2] +// o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3] +// load_B2 = load B[32-47] into vector<16xf32> +// o/p_Acc[4] = vector.fma bcst_A0, load_B2, i/p_Acc[4] +// o/p_Acc[5] = vector.fma bcst_A1, load_B2, i/p_Acc[5] +// +// return o/p_Acc; +SmallVector +generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType, + unsigned int vectorSize, unsigned int vnni, unsigned int M, + unsigned int N, ValueRange acc, Value matA, Value matB, + MatMulType matmulType) { + + SmallVector accumulators; + SmallVector matLoad; + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + + // Start with assumption that M tile size is smaller and create the + // helper variables + unsigned int outerBound = M; + unsigned int outerStep = 1; + + unsigned int innerBound = N; + unsigned int innerStep = vectorSize; + + Value outerMatrix = matA; + Value innerMatrix = matB; + + unsigned int outerVectSize = vnni; + unsigned int innerVectSize = vectorSize; + + unsigned int fmaBound = M; + + // update helper variables if N tile size is smaller + bool isNTileLarge = (N / vectorSize) > M; + if (!isNTileLarge) { + outerBound = N; + innerBound = M; + + outerStep = vectorSize; + innerStep = 1; + + outerMatrix = matB; + innerMatrix = matA; + + outerVectSize = vectorSize; + innerVectSize = vnni; + + fmaBound = N / vectorSize; + } + + // Load all the element 
of A or B matrix + for (unsigned int i = 0; i < outerBound; i = i + outerStep) { + Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i); + Value valueRow; + + if (isNTileLarge) { + + // With the assumption as batch-reduce matmul initialize reduction, M, and + // K dimension. + SmallVector index = {c0, indexOp_i, c0}; + + // Remove reduction dimension if it is a batch matmul + if (matmulType == MatMulType::Batch) { + index.erase(index.begin()); + } + + // A Matrix load + broadcast + Value row = vector::LoadOp::create( + rewriter, loc, VectorType::get(outerVectSize, elementType), + outerMatrix, index); + valueRow = vector::BroadcastOp::create( + rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()), + row); + } else { + + // With the assumption as batch-reduce matmul initialize reduction, K, and + // N dimension. + SmallVector index = {c0, c0, indexOp_i}; + + // Remove reduction dimension if it is a batch matmul + if (matmulType == MatMulType::Batch) { + index.erase(index.begin()); + } + + // B Matrix load. + valueRow = vector::LoadOp::create( + rewriter, loc, VectorType::get(outerVectSize, elementType), + outerMatrix, index); + } + + matLoad.push_back(valueRow); + } + + // Load elements of A/B Matrix one at a time and compute FMA + for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) { + Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j); + Value valueRow; + + if (!isNTileLarge) { + SmallVector index = {c0, indexOp_j, c0}; + if (matmulType == MatMulType::Batch) { + index.erase(index.begin()); + } + + // A Matrix load + broadcast + Value row = vector::LoadOp::create( + rewriter, loc, VectorType::get(innerVectSize, elementType), + innerMatrix, ValueRange(index)); + valueRow = vector::BroadcastOp::create( + rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()), + row); + } else { + + SmallVector index = {c0, c0, indexOp_j}; + if (matmulType == MatMulType::Batch) { + index.erase(index.begin()); + } + + // B Matrix load + valueRow = vector::LoadOp::create( + rewriter, loc, VectorType::get(innerVectSize, elementType), + innerMatrix, index); + } + + // FMAs + for (unsigned int i = 0; i < fmaBound; i = i + 1) { + auto fmaOdd = + vector::FMAOp::create(rewriter, loc, matLoad[i], valueRow, acc[k]); + k++; + accumulators.push_back(fmaOdd); + } + } + + return accumulators; +} + +// Function to re-create K dimension loop with accumulator as IterArgs for +// lowering a batch-reduce vector contraction to a system specific nanokernels. 
+scf::ForOp createGEMMLoopsWithAccAsIterArgs( + RewriterBase &rewriter, Location loc, scf::ForOp kForOp, + vector::TransferReadOp vectorReadOpLhs, + vector::TransferReadOp vectorReadOpRhs, Value ivNewReductionForOp, + Type elementType, unsigned int vectorSize, unsigned int vnni, + unsigned int M, unsigned int N, ValueRange iterArgsNewReductionForOp, + MatMulType matmulType) { + auto newKForOp = scf::ForOp::create( + rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(), + kForOp.getStep(), iterArgsNewReductionForOp, + [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp, + Value ivNewKForOp, ValueRange iterArgsNewKForOp) { + IRMapping mapping; + mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(1), + ivNewReductionForOp); + mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(3), + ivNewKForOp); + auto lhsClone = rewriterNewKForOp.clone( + *vectorReadOpLhs.getBase().getDefiningOp(), mapping); + + IRMapping rhsMapping; + rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1), + ivNewReductionForOp); + rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(2), + ivNewKForOp); + auto rhsClone = rewriterNewKForOp.clone( + *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping); + + auto evenFMAs = generateNanokernels( + rewriter, kForOp.getLoc(), elementType, vectorSize, vnni, M, N, + iterArgsNewKForOp, lhsClone->getResult(0), rhsClone->getResult(0), + matmulType); + + scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs); + }); + + return newKForOp; +} + +// Function to re-create K dimension loop with accumulator as IterArgs for +// lowering a batch vector contraction to a system specific nanokernels. +scf::ForOp createGEMMLoopsWithAccAsIterArgs( + RewriterBase &rewriter, Location loc, scf::ForOp kForOp, + vector::TransferReadOp vectorReadOpLhs, + vector::TransferReadOp vectorReadOpRhs, Type elementType, + unsigned int vectorSize, unsigned int vnni, unsigned int M, unsigned int N, + ValueRange iterArgsNewReductionForOp, MatMulType matmulType) { + + auto newKForOp = scf::ForOp::create( + rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(), + kForOp.getStep(), iterArgsNewReductionForOp, + [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp, + Value ivNewKForOp, ValueRange iterArgsNewKForOp) { + IRMapping mapping; + mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(2), + ivNewKForOp); + auto lhsClone = rewriterNewKForOp.clone( + *vectorReadOpLhs.getBase().getDefiningOp(), mapping); + + IRMapping rhsMapping; + rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1), + ivNewKForOp); + auto rhsClone = rewriterNewKForOp.clone( + *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping); + + auto evenFMAs = + generateNanokernels(rewriter, loc, elementType, vectorSize, vnni, M, + N, iterArgsNewKForOp, lhsClone->getResult(0), + rhsClone->getResult(0), matmulType); + + scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs); + }); + + return newKForOp; +} + +Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc, + VectorType vecType, + SmallVector FMAs, Value accVec, + unsigned int vectorSize, unsigned int M, + unsigned int N) { + + auto strides = rewriter.getI64ArrayAttr({1}); + bool isNTileLarge = (N / vectorSize) > M; + if (isNTileLarge) { + for (unsigned int j = 0, k = 0; j < (N / vectorSize); j++) { + for (unsigned int i = 0; i < M; i++) { + unsigned int off = (j * vectorSize) + (i * N); + auto offsets = 
rewriter.getI64ArrayAttr({off}); + accVec = vector::InsertStridedSliceOp::create( + rewriter, loc, vecType, FMAs[k], accVec, offsets, strides); + k++; + } + } + + } else { + for (unsigned int i = 0, k = 0; i < M * N; i = i + vectorSize) { + auto offsets = rewriter.getI64ArrayAttr({i}); + accVec = vector::InsertStridedSliceOp::create( + rewriter, loc, vecType, FMAs[k], accVec, offsets, strides); + k++; + } + } + return accVec; +} + +// Rewriter pattern for vector.contract operation. +// Input: vector.contract with tiled dimensions (batch or batch-matmul) +// Matching Pattern: +// scf.for (0 to M) step m_tile { +// scf.for (0 to N) step n_tile { +// - Subview of Accumulator matrix - eg., acc : memref +// - %read = vector.transfer_read memref to +// vector %1 = scf.for (0 to reduce) +// iter_args_reduce=%read step reduce_tile { +// %2 scf.for (0 to K) iter_args_k = %iter_args_reduce step k_tile { +// - Subview of A and B matrix +// - Vector transfer read of A and B +// - %acc = Vector.contract %read_A %read_B %iter_args_k +// scf.yield %acc +// } +// scf.yield %2 +// } +// vector.transfer_write %2 into accmulator matrix +// } +// } +// +// +// Rewrite IR: +// scf.for (0 to M) step m_tile { +// scf.for (0 to N) step n_tile { +// - Subview of Accumulator matrix - eg., acc : memref +// - %a = (n_tile / vector_size) * m_tile; +// // load the accumulator matrix as vector +// - %0 = load acc[0][0-15] into vector<16xf32> +// - %1 = load acc[0][16-31] into vector<16xf32> +// - %2 = load acc[1][0-15] into vector<16xf32> +// . +// . +// . +// - %a = load acc[m_tile-1][*-n_tile-1] into vector<16xf32> +// %3 = scf.for (0 to reduce) iter_args_reduce=%0 to %a step reduce_tile { +// %4 scf.for (0 to K) iter_args_k = %iter_args_reduce step k_tile { +// - emit nano kernels (as shown in commnets above +// generateNanokernels function) +// scf.yield %acc[0] to %acc[a-1] +// } +// scf.yield %4: [0] to [a-1] +// } +// %5 = vector.insert %3: [0] to [a-1] into vector +// vector.transfer_write %5 into accmulator matrix +// } +// } +struct VectorContractNanokernelLowering + : public OpRewritePattern { + VectorContractNanokernelLowering(MLIRContext *context, + std::optional vecSize) + : OpRewritePattern(context), + userVectorSize(vecSize) {} + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + + auto loc = contractOp.getLoc(); + + unsigned int vectorSize = 8; + if (userVectorSize) + vectorSize = *userVectorSize; + + if (contractOp.getKind() != vector::CombiningKind::ADD) { + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind"); + } + + SmallVector contractIteratorTypes = + contractOp.getIteratorTypesArray(); + + unsigned int reductionCount = + std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(), + vector::IteratorType::reduction); + + MatMulType matmulType = MatMulType::Others; + + if (reductionCount == 1) + matmulType = MatMulType::Batch; + + if (reductionCount == 2) + matmulType = MatMulType::BatchReduce; + + if ((matmulType != MatMulType::BatchReduce) && + (matmulType != MatMulType::Batch)) + return rewriter.notifyMatchFailure( + contractOp, "Expects batch-reduce or batch matmuls"); + + // Get the M, N, K, and batch-reduce loops + auto loops = getTiledMatmulLoopNest(contractOp, matmulType); + if (failed(loops)) + return rewriter.notifyMatchFailure( + contractOp, "Invalid loop nest in contract pattern"); + + auto nestedLoops = *loops; + scf::ForOp kForOp = nestedLoops[0]; + scf::ForOp reductionForOp; + + if 
(contractOp.getAcc().getDefiningOp()) { + return rewriter.notifyMatchFailure( + contractOp, "The Accumulator matrix should be hoisted outside the K " + "or reduction loop"); + } + + vector::TransferReadOp vectorReadOpAcc; + + if (matmulType == MatMulType::BatchReduce) { + reductionForOp = nestedLoops[1]; + vectorReadOpAcc = reductionForOp.getInitArgs()[0] + .getDefiningOp(); + } + + if (matmulType == MatMulType::Batch) { + vectorReadOpAcc = + kForOp.getInitArgs()[0].getDefiningOp(); + } + + auto vectorReadOpLhs = + contractOp.getLhs().getDefiningOp(); + auto vectorReadOpRhs = + contractOp.getRhs().getDefiningOp(); + + if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs) + return failure(); + + auto subviewOpAcc = + vectorReadOpAcc.getOperand(0).getDefiningOp(); + auto subviewOpLhs = + vectorReadOpLhs.getOperand(0).getDefiningOp(); + auto subviewOpRhs = + vectorReadOpRhs.getOperand(0).getDefiningOp(); + + if (!subviewOpAcc || !subviewOpLhs || !subviewOpRhs) + return failure(); + + SmallVector subviews; + subviews.push_back(subviewOpLhs); + subviews.push_back(subviewOpRhs); + subviews.push_back(subviewOpAcc); + + // The M, N, K, and batch-reduce loop iv should match the iv's + // used in the subviews + auto checkLoops = + checkMatmulLoopAndSubviewOffsetsMatching(*loops, subviews, matmulType); + if (failed(checkLoops)) + return rewriter.notifyMatchFailure( + contractOp, "tiled loops doesn't match the iv in subviews"); + + auto elementType = + (cast(subviewOpLhs.getType())).getElementType(); + + // TODO: Support for BF16 Type + if (!elementType.isF32()) + return rewriter.notifyMatchFailure(contractOp, + "Only, FP32 type is supported"); + + auto lhsType = dyn_cast(vectorReadOpLhs.getType()); + auto rhsType = dyn_cast(vectorReadOpRhs.getType()); + + // Get M, N, and K dimension size + unsigned int M = lhsType.getDimSize(lhsType.getRank() - 2); + unsigned int N = rhsType.getDimSize(rhsType.getRank() - 1); + unsigned int K = lhsType.getDimSize(lhsType.getRank() - 1); + unsigned int vnni = 1; + + if (K != 1) + return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1"); + + if (matmulType == MatMulType::BatchReduce && + lhsType.getDimSize(lhsType.getRank() - 3) != 1) + return rewriter.notifyMatchFailure(contractOp, + "The reduction-dim should be 1"); + + if (matmulType == MatMulType::BatchReduce) + rewriter.setInsertionPoint(reductionForOp); + + if (matmulType == MatMulType::Batch) + rewriter.setInsertionPoint(kForOp); + + // Load MxN C sub matrix into acc vectors (e.g, ) + SmallVector accumulators = loadAccumulatorBeforeGEMM( + loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc); + + // Create the batch-reduce and K-loop with acc vectors as the loop + // iterargs (batch-reduce matmul) + nanokernel generation + scf::ForOp newLoop; + if (matmulType == MatMulType::BatchReduce) { + newLoop = scf::ForOp::create( + rewriter, reductionForOp.getLoc(), reductionForOp.getLowerBound(), + reductionForOp.getUpperBound(), reductionForOp.getStep(), + accumulators, + [&](OpBuilder &rewriterNewReductionForOp, + Location locNewReductionForOp, Value ivNewReductionForOp, + ValueRange iterArgsNewReductionForOp) { + scf::ForOp newKForOp = createGEMMLoopsWithAccAsIterArgs( + rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, + ivNewReductionForOp, elementType, vectorSize, vnni, M, N, + iterArgsNewReductionForOp, matmulType); + + scf::YieldOp::create(rewriterNewReductionForOp, + locNewReductionForOp, newKForOp.getResults()); + }); + } + + // Create only the K-loop (batch matmul) + 
nanokernel generation + if (matmulType == MatMulType::Batch) { + newLoop = createGEMMLoopsWithAccAsIterArgs( + rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, elementType, + vectorSize, vnni, M, N, accumulators, matmulType); + } + + // Combine all output accumulator vectors into a m_tilexn_tile C matrix + auto vecType = VectorType::get({M * N}, rewriter.getF32Type()); + auto zeroAttr = + DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0)); + Value accVec = arith::ConstantOp::create(rewriter, loc, vecType, zeroAttr); + + accVec = mergeAccumulatedVectorAsMatrix( + rewriter, loc, vecType, newLoop.getResults(), accVec, vectorSize, M, N); + + auto accTy = dyn_cast(contractOp.getAccType()); + auto reshapeAcc = vector::ShapeCastOp::create(rewriter, loc, accTy, accVec); + + // Replace all the use of vector.contract with results of nanokernels + if (matmulType == MatMulType::BatchReduce) + rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc); + + if (matmulType == MatMulType::Batch) + rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc); + + return success(); + } + std::optional userVectorSize; +}; + +void x86vector::populateVectorContractNanokernelLoweringPatterns( + RewritePatternSet &patterns, std::optional userVectorSize) { + patterns.add(patterns.getContext(), + userVectorSize); +} diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp index 3839172fd0b42..efcd09fc1b924 100644 --- a/mlir/lib/RegisterAllExtensions.cpp +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -56,6 +56,7 @@ #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" @@ -112,6 +113,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) { transform::registerSMTExtension(registry); transform::registerTuneExtension(registry); vector::registerTransformDialectExtension(registry); + x86vector::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry); diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir new file mode 100644 index 0000000000000..3514358633c0d --- /dev/null +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir @@ -0,0 +1,215 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +module { + func.func @fp32_batch_reduce_matmul_vector_size_16(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> { + %0 = ub.poison : f32 + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c96 = arith.constant 96 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + scf.for %arg3 = %c0 to %c4 step %c4 { + scf.for %arg4 = %c0 to %c96 step %c96 { + %subview = memref.subview %arg2[%arg3, %arg4] [4, 96] [1, 1] : memref<4x96xf32> to memref<4x96xf32, strided<[96, 1], offset: ?>> + %1 = 
vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x96xf32, strided<[96, 1], offset: ?>>, vector<4x96xf32> + %2 = scf.for %arg5 = %c0 to %c1 step %c1 iter_args(%arg6 = %1) -> (vector<4x96xf32>) { + %3 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x96xf32>) { + %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<1x4x32xf32> to memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>> + %subview_1 = memref.subview %arg1[%arg5, %arg7, %arg4] [1, 1, 96] [1, 1, 1] : memref<1x32x96xf32> to memref<1x1x96xf32, strided<[3072, 96, 1], offset: ?>> + %4 = vector.transfer_read %subview_0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>, vector<1x4x1xf32> + %5 = vector.transfer_read %subview_1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x1x96xf32, strided<[3072, 96, 1], offset: ?>>, vector<1x1x96xf32> + %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x96xf32> into vector<4x96xf32> + scf.yield %6 : vector<4x96xf32> + } + scf.yield %3 : vector<4x96xf32> + } + vector.transfer_write %2, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x96xf32>, memref<4x96xf32, strided<[96, 1], offset: ?>> + } + } + return %arg2 : memref<4x96xf32> + } +} + +// CHECK-LABEL: func.func @fp32_batch_reduce_matmul_vector_size_16( +// CHECK-COUNT-24: vector.fma{{.*}}vector<16xf32> +// CHECK-NOT: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16 + } : !transform.any_op + transform.yield + } +} + +// ----- + +module { + func.func @fp32_batch_matmul_vector_size_8(%arg0: memref<4x32xf32>, %arg1: memref<32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> { + %0 = ub.poison : f32 + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c96 = arith.constant 96 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + scf.for %arg3 = %c0 to %c4 step %c4 { + scf.for %arg4 = %c0 to %c96 step %c96 { + %subview = memref.subview %arg2[%arg3, %arg4] [4, 96] [1, 1] : memref<4x96xf32> to memref<4x96xf32, strided<[96, 1], offset: ?>> + %1 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x96xf32, strided<[96, 1], offset: ?>>, vector<4x96xf32> + + %3 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %1) -> (vector<4x96xf32>) { + %subview_0 = memref.subview %arg0[%arg3, %arg7] [4, 1] [1, 1] : memref<4x32xf32> to memref<4x1xf32, strided<[32, 1], offset: ?>> + %subview_1 = memref.subview %arg1[%arg7, %arg4] [1, 96] [1, 1] : memref<32x96xf32> to memref<1x96xf32, strided<[96, 1], offset: ?>> + %4 = vector.transfer_read %subview_0[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x1xf32, strided<[32, 1], offset: ?>>, vector<4x1xf32> + %5 = vector.transfer_read %subview_1[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x96xf32, strided<[96, 1], offset: ?>>, vector<1x96xf32> + %6 = vector.contract {indexing_maps = [affine_map<(d1, d2, d3) -> (d1, d3)>, affine_map<(d1, d2, d3) -> (d3, d2)>, affine_map<(d1, d2, d3) -> (d1, 
d2)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %4, %5, %arg8 : vector<4x1xf32>, vector<1x96xf32> into vector<4x96xf32> + scf.yield %6 : vector<4x96xf32> + } + + vector.transfer_write %3, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x96xf32>, memref<4x96xf32, strided<[96, 1], offset: ?>> + } + } + return %arg2 : memref<4x96xf32> + } +} + +// CHECK-LABEL: func.func @fp32_batch_matmul_vector_size_8( +// CHECK-COUNT-48: vector.fma{{.*}}vector<8xf32> +// CHECK-NOT: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 8 + } : !transform.any_op + transform.yield + } +} + +// ----- + +module { + func.func @negative_not_tiled(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> { + %c0 = arith.constant 0 : index + %0 = ub.poison : f32 + %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x4x32xf32>, vector<1x4x32xf32> + %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x96xf32>, vector<1x32x96xf32> + %3 = vector.transfer_read %arg2[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x96xf32>, vector<4x96xf32> + %4 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 : vector<1x4x32xf32>, vector<1x32x96xf32> into vector<4x96xf32> + vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x96xf32>, memref<4x96xf32> + return %arg2 : memref<4x96xf32> + } +} + +// CHECK-LABEL: func.func @negative_not_tiled( +// CHECK-NOT: vector.fma{{.*}}vector<8xf32> +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 8 + } : !transform.any_op + transform.yield + } +} + +// ----- + +module { + func.func @negative_tensor_type(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x32xf32>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = ub.poison : f32 + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %1 = scf.for %arg3 = %c0 to %c32 step %c4 iter_args(%arg4 = %arg2) -> (tensor<32x32xf32>) { + %2 = scf.for %arg5 = %c0 to %c32 step %c16 iter_args(%arg6 = %arg4) -> (tensor<32x32xf32>) { + %3 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %arg6) -> (tensor<32x32xf32>) { + %4 = scf.for %arg9 = %c0 to %c32 step %c1 iter_args(%arg10 = %arg8) -> (tensor<32x32xf32>) { + %extracted_slice = tensor.extract_slice %arg0[%arg7, %arg3, %arg9] [1, 4, 1] [1, 1, 1] : tensor<32x32x32xf32> to tensor<1x4x1xf32> + %extracted_slice_0 = tensor.extract_slice %arg1[%arg7, %arg9, %arg5] [1, 1, 16] [1, 1, 1] : 
tensor<32x32x32xf32> to tensor<1x1x16xf32> + %extracted_slice_1 = tensor.extract_slice %arg10[%arg3, %arg5] [4, 16] [1, 1] : tensor<32x32xf32> to tensor<4x16xf32> + %5 = vector.transfer_read %extracted_slice[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : tensor<1x4x1xf32>, vector<1x4x1xf32> + %6 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : tensor<1x1x16xf32>, vector<1x1x16xf32> + %7 = vector.transfer_read %extracted_slice_1[%c0, %c0], %0 {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32> + %8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %5, %6, %7 : vector<1x4x1xf32>, vector<1x1x16xf32> into vector<4x16xf32> + %9 = vector.transfer_write %8, %extracted_slice_1[%c0, %c0] {in_bounds = [true, true]} : vector<4x16xf32>, tensor<4x16xf32> + %inserted_slice = tensor.insert_slice %9 into %arg10[%arg3, %arg5] [4, 16] [1, 1] : tensor<4x16xf32> into tensor<32x32xf32> + scf.yield %inserted_slice : tensor<32x32xf32> + } + scf.yield %4 : tensor<32x32xf32> + } + scf.yield %3 : tensor<32x32xf32> + } + scf.yield %2 : tensor<32x32xf32> + } + return %1 : tensor<32x32xf32> + } +} + +// CHECK-LABEL: func.func @negative_tensor_type( +// CHECK-NOT: vector.fma{{.*}}vector<16xf32> +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16 + } : !transform.any_op + transform.yield + } +} + +// ----- + +module { + func.func @negative_accumulator_not_hoisted_outside_K_or_reduction_loop(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> { + %0 = ub.poison : f32 + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c96 = arith.constant 96 : index + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + scf.for %arg3 = %c0 to %c4 step %c4 { + scf.for %arg4 = %c0 to %c96 step %c32 { + %subview = memref.subview %arg2[%arg3, %arg4] [4, 32] [1, 1] : memref<4x96xf32> to memref<4x32xf32, strided<[96, 1], offset: ?>> + scf.for %arg5 = %c0 to %c1 step %c1 { + scf.for %arg6 = %c0 to %c32 step %c1 { + %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<1x4x32xf32> to memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>> + %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4] [1, 1, 32] [1, 1, 1] : memref<1x32x96xf32> to memref<1x1x32xf32, strided<[3072, 96, 1], offset: ?>> + %1 = vector.transfer_read %subview_0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>, vector<1x4x1xf32> + %2 = vector.transfer_read %subview_1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x1x32xf32, strided<[3072, 96, 1], offset: ?>>, vector<1x1x32xf32> + %3 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x32xf32, strided<[96, 1], offset: ?>>, vector<4x32xf32> + %4 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) 
-> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x32xf32> into vector<4x32xf32> + vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x32xf32>, memref<4x32xf32, strided<[96, 1], offset: ?>> + } + } + } + } + return %arg2 : memref<4x96xf32> + } +} + +// CHECK-LABEL: func.func @negative_accumulator_not_hoisted_outside_K_or_reduction_loop( +// CHECK-NOT: vector.fma{{.*}}vector<16xf32> +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16 + } : !transform.any_op + transform.yield + } +}
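
Note (not part of the patch): besides the transform-dialect path exercised by the tests above, the lowering is also reachable from C++ through the populateVectorContractNanokernelLoweringPatterns entry point added to mlir/Dialect/X86Vector/Transforms.h. The sketch below is illustrative only; the pass name is hypothetical, and it assumes a recent MLIR tree where the greedy driver is exposed as applyPatternsGreedily (older trees name it applyPatternsAndFoldGreedily). A vector size of 16 targets AVX-512 f32 lanes; 8 would suit AVX2.

// Hypothetical standalone pass; only the populate function comes from this patch.
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct NanokernelLoweringSketch
    : mlir::PassWrapper<NanokernelLoweringSketch, mlir::OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NanokernelLoweringSketch)

  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Schedule the same patterns the transform op populates; 16 lanes for
    // AVX-512 f32.
    mlir::x86vector::populateVectorContractNanokernelLoweringPatterns(
        patterns, /*vectorSize=*/16);
    if (failed(mlir::applyPatternsGreedily(getOperation(),
                                           std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace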