-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmNeon] Implements LowerVectorToArmNeon Pattern for SMMLA #81895
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
b1c60ab
to
e7efc97
Compare
eed833d
to
8ce5a14
Compare
b2b6eea
to
2341c57
Compare
@llvm/pr-subscribers-mlir-neon @llvm/pr-subscribers-mlir-core Author: Kojo Acquah (KoolJBlack) ChangesThis patch adds a the This pattern inspects Full diff: https://github.com/llvm/llvm-project/pull/81895.diff 13 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
new file mode 100644
index 00000000000000..41dbc2633d52c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -0,0 +1,21 @@
+//===- Transforms.h - ArmNeon Dialect Transformation Entrypoints -*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+#define MLIR_DIALECT_ARMNEON_TRANSFORMS_H
+
+namespace mlir {
+
+namespace arm_neon {
+void populateLowerVectorToArmNeonPatterns(RewritePatternSet &patterns);
+} // namespace arm_neon
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMNEON_TRANSFORMS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 5fbb50f62395ec..a0fce139f27466 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -34,6 +34,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRVectorToLLVM
MLIRArmNeonDialect
+ MLIRArmNeonTransforms
MLIRArmSMEDialect
MLIRArmSMETransforms
MLIRArmSVEDialect
diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
index 060b6df1b334ad..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -1,13 +1,2 @@
-add_mlir_dialect_library(MLIRArmNeonDialect
- IR/ArmNeonDialect.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
-
- DEPENDS
- MLIRArmNeonIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRSideEffectInterfaces
- )
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..b04919a3a31858
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmNeonDialect
+ ArmNeonDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
+
+ DEPENDS
+ MLIRArmNeonIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSideEffectInterfaces
+ )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..dcd806e981479d
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRArmNeonTransforms
+ LowerVectorToArmNeon.cpp
+
+ DEPENDS
+ MLIRArmNeonIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArmNeonDialect
+ MLIRFuncDialect
+ MLIRVectorDialect
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ )
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
new file mode 100644
index 00000000000000..f3ac36d08fc82b
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp
@@ -0,0 +1,153 @@
+//===- LowerVectorToArmNeon.cpp - Lower 'arm_neon.intr.smmla' ops
+//-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// arm_neon.intr.smmla
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "arm-neon-lower-vector"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+
+// Return the shaped type with new element type.
+static Type matchContainerType(Type element, Type container) {
+ if (auto shapedTy = dyn_cast<ShapedType>(container))
+ return shapedTy.clone(element);
+
+ return element;
+}
+
+// Lowering from vector::contractOp directly to the arm neon
+// intrinsic.
+class LowerVectorToArmNeonPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
+ Value res = op.getAcc();
+
+ // Check index maps represent M N K and aren't transposed.
+ auto indexingMaps = op.getIndexingMapsArray();
+ if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
+ return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
+ affineMap.getNumResults() != 2;
+ })) {
+ return failure();
+ }
+
+ // Check iterator types for contract
+ auto iteratorTypes = op.getIteratorTypesArray();
+ if (iteratorTypes.size() != 3 ||
+ iteratorTypes[0] != vector::IteratorType::parallel ||
+ iteratorTypes[1] != vector::IteratorType::parallel ||
+ iteratorTypes[2] != vector::IteratorType::reduction) {
+ return failure();
+ }
+
+ // Check the tile size by mapping the dimensions of the contract
+ // -- Tile size: [2, 2, 8]
+ // Infer tile sizes from operands. Check required tile size
+ // Note: RHS is not transposed
+ mlir::VectorType lhsType = op.getLhsType();
+ mlir::VectorType rhsType = op.getRhsType();
+ auto dimM = lhsType.getDimSize(0);
+ auto dimN = rhsType.getDimSize(0);
+ auto dimK = lhsType.getDimSize(1);
+ if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+ return failure();
+ }
+
+ // Check two extsi inputs Rhs Lhs
+ arith::ExtSIOp origLhsExtOp;
+ arith::ExtSIOp origRhsExtOp;
+ if (!(origLhsExtOp =
+ dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp())) ||
+ !(origRhsExtOp =
+ dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp()))) {
+ return failure();
+ }
+
+ arith::ExtSIOp extsiLhs;
+ arith::ExtSIOp extsiRhs;
+ // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
+ // following neon instruction. Check inputs for extsi are <=i8
+ if (auto lhsExtInType =
+ origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
+ if (lhsExtInType.getElementTypeBitWidth() <= 8) {
+ // Target lhs type with i8. This is likely redundant
+ Type targetLhsExtTy =
+ matchContainerType(rewriter.getI8Type(), lhsExtInType);
+ extsiLhs = rewriter.create<arith::ExtSIOp>(loc, targetLhsExtTy,
+ origLhsExtOp.getIn());
+ }
+ }
+ if (auto rhsExtInType =
+ origRhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
+ if (rhsExtInType.getElementTypeBitWidth() <= 8) {
+ // Target rhs type with i8
+ Type targetRhsExtTy =
+ matchContainerType(rewriter.getI8Type(), rhsExtInType);
+ extsiRhs = rewriter.create<arith::ExtSIOp>(loc, targetRhsExtTy,
+ origRhsExtOp.getIn());
+ }
+ }
+
+ if (!extsiLhs || !extsiRhs) {
+ return failure();
+ }
+
+ // Collapse to 1D vectors required by smmla intrinsic
+ auto collapsedInputType = VectorType::get(
+ {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
+ auto collapsedOutputType =
+ VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
+ auto collapsedLhs = rewriter.create<vector::ShapeCastOp>(
+ extsiLhs.getLoc(), collapsedInputType, extsiLhs);
+ auto collapsedRhs = rewriter.create<vector::ShapeCastOp>(
+ extsiRhs.getLoc(), collapsedInputType, extsiRhs);
+ auto collapsedRes = rewriter.create<vector::ShapeCastOp>(
+ res.getLoc(), collapsedOutputType, res);
+
+ // Replace the contract with a neon op
+ auto smmlaOp = rewriter.create<arm_neon::SmmlaOp>(
+ op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
+ collapsedRhs);
+
+ // Reshape output back to 2D
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
+ smmlaOp);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::arm_neon::populateLowerVectorToArmNeonPatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.add<LowerVectorToArmNeonPattern>(context, /*benefit=*/1);
+}
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
new file mode 100644
index 00000000000000..fe7259a9919ccf
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -test-lower-vector-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: test_lower_vector_arm_neon_mixed_types
+// CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
+// CHECK-DAG: %[[D0:.*]] = arith.extsi %[[A1]] : vector<2x8xi4> to vector<2x8xi8>
+// CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[D0]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[D3:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[D4:.*]] = arm_neon.intr.smmla %[[D3]], %[[D1]], %[[D2]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[D5:.*]] = vector.shape_cast %[[D4]] : vector<4xi32> to vector<2x2xi32>
+func.func @test_lower_vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+ return %res : vector<2x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
new file mode 100644
index 00000000000000..21548ca57701f9
--- /dev/null
+++ b/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRArmNeonTestPasses
+ TestLowerToArmNeon.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRArmNeonDialect
+ MLIRArmNeonTransforms
+ MLIRIR
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
new file mode 100644
index 00000000000000..6398b69cc82816
--- /dev/null
+++ b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
@@ -0,0 +1,63 @@
+//===- TestLowerToArmNeon.cpp - Test lowering to ArmNeon as a sink pass -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing the lowering to ArmNeon as a
+// generally usable sink pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define PASS_NAME "test-lower-vector-to-arm-neon"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+struct TestLowerToArmNeon
+ : public PassWrapper<TestLowerToArmNeon, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToArmNeon)
+
+ StringRef getArgument() const final { return PASS_NAME; }
+ StringRef getDescription() const final {
+ return "Tests lower vector to arm Neon.";
+ }
+ TestLowerToArmNeon() = default;
+ TestLowerToArmNeon(const TestLowerToArmNeon &pass) = default;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<arm_neon::ArmNeonDialect>();
+ }
+
+ void runOnOperation() override;
+};
+
+} // namespace
+
+void TestLowerToArmNeon::runOnOperation() {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ populateLowerVectorToArmNeonPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+}
+
+namespace mlir {
+namespace test {
+
+void registerTestLowerToArmNeon() { PassRegistration<TestLowerToArmNeon>(); }
+
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index e20cd4473a3580..29fb4441a24fd2 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,5 +1,6 @@
add_subdirectory(Affine)
add_subdirectory(Arith)
+add_subdirectory(ArmNeon)
add_subdirectory(ArmSME)
add_subdirectory(Bufferization)
add_subdirectory(ControlFlow)
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 701fc461b3b4e9..d4faa087bc6a0c 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -17,6 +17,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestFuncToLLVM
MLIRAffineTransformsTestPasses
MLIRArithTestPasses
+ MLIRArmNeonTestPasses
MLIRArmSMETestPasses
MLIRBufferizationTestPasses
MLIRControlFlowTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0ba1a3a534e35c..c099bbd97eaecf 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -110,6 +110,7 @@ void registerTestLoopFusion();
void registerTestCFGLoopInfoPass();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
+void registerTestLowerToArmNeon();
void registerTestLowerToArmSME();
void registerTestLowerToLLVM();
void registerTestMakeIsolatedFromAbovePass();
@@ -236,6 +237,7 @@ void registerTestPasses() {
mlir::test::registerTestCFGLoopInfoPass();
mlir::test::registerTestLoopMappingPass();
mlir::test::registerTestLoopUnrollingPass();
+ mlir::test::registerTestLowerToArmNeon();
mlir::test::registerTestLowerToArmSME();
mlir::test::registerTestLowerToLLVM();
mlir::test::registerTestMakeIsolatedFromAbovePass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7f33f165992213..59dc1100ea8df5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1929,6 +1929,27 @@ cc_library(
],
)
+cc_library(
+ name = "ArmNeonTransforms",
+ srcs = ["lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp"],
+ hdrs = ["include/mlir/Dialect/ArmNeon/Transforms.h"],
+ includes = ["include"],
+ deps = [
+ ":ArithDialect",
+ ":ArmNeonIncGen",
+ ":ArmNeonDialect",
+ ":FuncDialect",
+ ":IR",
+ ":LLVMDialect",
+ ":SideEffectInterfaces",
+ ":Support",
+ ":VectorDialect",
+ ":Transforms",
+ "//llvm:Core",
+ "//llvm:Support",
+ ],
+)
+
gentbl_cc_library(
name = "ArmNeonConversionIncGen",
tbl_outs = [
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Look great overall! A few minor comments for now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I would suggest that you use createOrFold
for those extensions. It looks great otherwise!
It looks like the
I'm not sure how this relates to this PR. Its only the windows runner. Perhaps an unrelated issue? @dcaballe |
726d179
to
77a0947
Compare
This patch adds a the
LowerVectorToArmNeonPattern
patterns to the ArmNeon.This pattern inspects
vector.contract
ops that can be 1-1 mapped to anarm.neon.smmla
intrinsic. The contract ops must be separated into tiles who's inputs must fit that of a single smmla op (2x8xi32
inputs and2x2xi32
output). Thevector.contract
inputs must be sign extended from narrow types (<=i8) to be converted. If all conditions are met, an smmla op is inserted with additionalvector.shape_casts
to handle linearizing the input and output dimension.