-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][x86vector] Lower vector.contract to FMA or packed type dot-product #168074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
This transform pass is intended to use at the last stage of vectorization after unrolling of |
|
cc: @rengolin @shahidact please have a look. |
|
cc: @rolfmorel |
rengolin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have a short description somewhere of what shape the input is expected to be?
Also, potentially, what upstream transforms may produce it from Linalg.
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
Outdated
Show resolved
Hide resolved
I'm happy to hear it helped you in prototyping. Let's just focus on getting things running. And, of course, thanks for iterating on this vectorization work, looks really promising. |
mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
Show resolved
Hide resolved
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Arun Thangamani (arun-thmn) ChangesA The lowering works on condition with The lowering pattern: Patch is 41.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168074.diff 14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
index 0fe01824b8248..bbe8e4eb892dd 100644
--- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
add_mlir_interface(X86VectorInterfaces)
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..6f377e10fa8f8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
+mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
new file mode 100644
index 0000000000000..e1d8b8762e799
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// X86Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace x86vector {
+void registerTransformDialectExtension(DialectRegistry ®istry);
+
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..4009a140bb097
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -0,0 +1,42 @@
+//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.x86vector.vector_contract_to_fma",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contract operation can be lowered to a FMA.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contract operation can be lowered to a BF16/Int8 dot-product.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+
+#endif // X86VECTOR_TRANSFORM_OPS
+
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..943d7182d1960 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
#include "mlir/IR/Value.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
namespace mlir {
class ImplicitLocOpBuilder;
@@ -79,6 +83,13 @@ struct MaskHelper {
}
};
+//===----------------------------------------------------------------------===//
+
+void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
+
+void populateVectorContractToPackedTypeDotProductPatterns(
+ RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..f4c9f8a05acbc
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+ X86VectorTransformOps.cpp
+
+ DEPENDS
+ MLIRX86VectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRSideEffectInterfaces
+ MLIRTransformDialect
+ MLIRTransformDialectUtils
+ MLIRX86VectorDialect
+ MLIRX86VectorTransforms
+ )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000000000..68d577326a308
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,65 @@
+//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops
+//--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ x86vector::populateVectorContractToFMAPatterns(patterns);
+}
+
+void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ X86VectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ X86VectorTransformDialectExtension)
+
+ X86VectorTransformDialectExtension() {
+ declareGeneratedDialect<x86vector::X86VectorDialect>();
+ declareGeneratedDialect<LLVM::LLVMDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+void mlir::x86vector::registerTransformDialectExtension(
+ DialectRegistry ®istry) {
+ registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..3d2288049e49e 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,8 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
AVXTranspose.cpp
LegalizeForLLVMExport.cpp
+ VectorContractToFMA.cpp
+ VectorContractToPackedTypeDotProduct.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
new file mode 100644
index 0000000000000..764ec46681094
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -0,0 +1,99 @@
+//===- VectorContractToFMA.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Implements outer product contraction as a sequence of broadcast and
+// FMA operations.
+//
+// For example - for F32 type:
+// ```
+// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <16xf32>
+// vector.fma vector<16xf32>
+// ```
+struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD) {
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ }
+
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isF32())
+ return rewriter.notifyMatchFailure(contractOp,
+ "Only F32 lowering is supported.");
+ if (llvm::any_of(lhsTy.getShape(), [](int64_t dim) { return dim != 1; }))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Expects one for all dimensions of LHS");
+
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ llvm::SmallVector<int64_t> dimsRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
+ [](int64_t dim) { return dim != 1; });
+ if (dimsRhs.size() != 1)
+ return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ assert(accTy && "Invalid accumulator");
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> dimsAcc;
+ llvm::copy_if(accShape, std::back_inserter(dimsAcc),
+ [](int64_t dim) { return dim != 1; });
+ if (dimsAcc.size() != 1)
+ return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+
+ // Lowers vector.contract into a broadcast+FMA sequence.
+ auto loc = contractOp.getLoc();
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto castAcc = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+ auto broadcastLhs = vector::BroadcastOp::create(
+ rewriter, loc, castRhs.getResult().getType(), castLhs);
+ auto fma =
+ vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
+ auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
+
+ rewriter.replaceOp(contractOp, castFma);
+
+ return success();
+ }
+};
+
+void x86vector::populateVectorContractToFMAPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorContractToFMA>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
new file mode 100644
index 0000000000000..1dabbddbebb7e
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -0,0 +1,148 @@
+//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Implements packed type outer product contraction as a sequence
+// of broadcast and packed dot-product operations.
+//
+// For example - for F32 type:
+// ```
+// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <32xbf16>
+// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
+// ```
+struct VectorContractToPackedTypeDotProduct
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD) {
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ }
+
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isBF16() &&
+ !lhsTy.getElementType().isSignlessInteger(8))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only BF16/Int8 lowering is supported.");
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+ if (lhsTy.getElementType().isBF16() && lhsShape.back() != 2)
+ return rewriter.notifyMatchFailure(
+ contractOp, "The LHS vnni dim should be 2 for BF16.");
+
+ if (lhsTy.getElementType().isSignlessInteger(8) && lhsShape.back() != 4)
+ return rewriter.notifyMatchFailure(
+ contractOp, "The LHS vnni dim should be 4 for Int8.");
+ llvm::SmallVector<int64_t> dimsLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
+ [](int64_t dim) { return dim != 1; });
+ if (dimsLhs.size() != 1)
+ return rewriter.notifyMatchFailure(contractOp, "Irregular LHS shape");
+
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ if (lhsTy.getElementType().isBF16() && rhsShape.back() != 2)
+ return rewriter.notifyMatchFailure(
+ contractOp, "The RHS vnni dim should be 2 for BF16.");
+ if (lhsTy.getElementType().isSignlessInteger(8) && rhsShape.back() != 4)
+ return rewriter.notifyMatchFailure(
+ contractOp, "The RHS vnni dim should be 4 for Int8.");
+ llvm::SmallVector<int64_t> dimsRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
+ [](int64_t dim) { return dim != 1; });
+ if (dimsRhs.size() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ assert(accTy && "Invalid accumulator");
+ if (!accTy.getElementType().isF32() &&
+ !accTy.getElementType().isSignlessInteger(32))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only F32/Int32 accumulation is supported.");
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> dimsAcc;
+ llvm::copy_if(accShape, std::back_inserter(dimsAcc),
+ [](int64_t dim) { return dim != 1; });
+ if (dimsAcc.size() != 1)
+ return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+
+ auto loc = contractOp.getLoc();
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(dimsRhs.front() * dimsRhs.back(),
+ rhsTy.getElementType()),
+ contractOp.getRhs());
+
+ auto castAcc = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto bitcastLhs = vector::BitCastOp::create(
+ rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+ castLhs);
+ auto broadcastLhs = vector::BroadcastOp::create(
+ rewriter, loc,
+ VectorType::get({dimsRhs.front()}, rewriter.getIntegerType(32)),
+ bitcastLhs);
+ auto bitcastLhsPkType = vector::BitCastOp::create(
+ rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
+
+ Value dp;
+
+ if (lhsTy.getElementType().isBF16()) {
+ dp = x86vector::DotBF16Op::create(
+ rewriter, loc,
+ VectorType::get(dimsRhs.front(), rewriter.getF32Type()), castAcc,
+ bitcastLhsPkType, castRhs);
+ }
+
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(dimsRhs.front(), rewriter.getIntegerType(32)),
+ castAcc, bitcastLhsPkType, castRhs);
+ }
+
+ if (dp) {
+ auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
+ rewriter.replaceOp(contractOp, castDp);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
+void x86vector::populateVectorContractToPackedTypeDotProductPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index c857c38df717c..4312100a0c0b0 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
...
[truncated]
|
rengolin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking good to me already. If @adam-smnk is happy, I'm happy too.
mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
Outdated
Show resolved
Hide resolved
🐧 Linux x64 Test Results
|
mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
Show resolved
Hide resolved
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
Show resolved
Hide resolved
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
Outdated
Show resolved
Hide resolved
adam-smnk
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Functionality looks good.
Only minor housekeeping comments.
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
Show resolved
Hide resolved
adam-smnk
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good 👍
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/204/builds/28772 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/203/builds/29960 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/205/builds/28751 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/22218 Here is the relevant piece of the build log for the reference |
Adds required dependency for `inferContractionDims`. Fixes #168074
This is a second attempt to fix the bazel build (after the first in llvm#169294, which was accidentally merged before CI passed). In the first attempt, not all bazel dependencies had been added; this PR should add them all and make CI pass. Signed-off-by: Ingo Müller <ingomueller@google.com>
This is a second attempt to fix the bazel build (after the first in llvm#169294, which was accidentally merged before CI passed). In the first attempt, not all bazel dependencies had been added; this PR should add them all and make CI pass. Signed-off-by: Ingo Müller <ingomueller@google.com>
…duct (llvm#168074) A `transform` pass to lower `vector.contract` to (a) `vector.fma` for `F32`, (b) `x86vector.avx512.dot` for `BF16`, (c) `x86vector.avx.dot.i8` for `Int8` packed types. The lowering works on condition with `m`, `batch`, `k` dims to be `one` and `vnni` dim should be `2` for `bf16`; `4` for `int8`. **The lowering pattern**: `batch_reduce.matmul` (input) -> register-tiling(M, N) -> Vectorization (to `vector.contract`) -> `unroll` vector.contract (`unit` dims) -> `hoisting` transformation (move `C` loads/store outside batch/k loop) -> apply `licm`, `canonicalization`, and `bufferize`.
Adds required dependency for `inferContractionDims`. Fixes llvm#168074
…69294) This PR fixes the bazel build that went out of sync with the changes introduced in llvm#168074. Signed-off-by: Ingo Müller <ingomueller@google.com>
…llvm#169316) This is a second attempt to fix the bazel build (after the first in llvm#169294, which was accidentally merged before CI passed). In the first attempt, not all bazel dependencies had been added; this PR should add them all and make CI pass. Signed-off-by: Ingo Müller <ingomueller@google.com>
A
transformpass to lowervector.contractto (a)vector.fmaforF32, (b)x86vector.avx512.dotforBF16, (c)x86vector.avx.dot.i8forInt8packed types.The lowering works on condition with
m,batch,kdims to beoneandvnnidim should be2forbf16;4forint8.The lowering pattern:
batch_reduce.matmul(input) -> register-tiling(M, N) -> Vectorization (tovector.contract) ->unrollvector.contract (unitdims) ->hoistingtransformation (moveCloads/store outside batch/k loop) -> applylicm,canonicalization, andbufferize.