-
Notifications
You must be signed in to change notification settings - Fork 11.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
2d Arm Neon sdot op, and lowering to the intrinsic.
This adds Sdot2d op, which is similar to the usual Neon intrinsic except that it takes 2d vector operands, reflecting the structure of the arithmetic that it's performing: 4 separate 4-dimensional dot products, whence the vector<4x4xi8> shape. This also adds a new pass, arm-neon-2d-to-intr, lowering this new 2d op to the 1d intrinsic. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D102504
- Loading branch information
1 parent
4f01122
commit 20daeda
Showing
10 changed files
with
245 additions
and
0 deletions.
There are no files selected for viewing
30 changes: 30 additions & 0 deletions
30
mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
//===- ArmNeon2dToIntr.h - convert Arm Neon 2d ops to intrinsics ----------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_ | ||
#define MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_ | ||
|
||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace mlir {
// Forward declarations: this header only names these types in function
// signatures, so complete definitions are not required here.
class FuncOp;
class RewritePatternSet;
template <typename T>
class OperationPass;

/// Populates `patterns` with the rewrite patterns that lower Arm NEON 2D ops
/// to intrinsics. These are the patterns applied by the pass returned from
/// createConvertArmNeon2dToIntrPass.
void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns);

/// Creates a pass to lower Arm NEON 2D ops to intrinsics, i.e.
/// equivalent ops operating on flattened 1D vectors and mapping more
/// directly to the corresponding Arm NEON instruction.
std::unique_ptr<OperationPass<FuncOp>> createConvertArmNeon2dToIntrPass();

} // namespace mlir
|
||
#endif // MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
//===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" | ||
#include "../PassDetail.h" | ||
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" | ||
#include "mlir/Dialect/Vector/VectorOps.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassRegistry.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::arm_neon; | ||
|
||
namespace { | ||
|
||
/// Rewrites an Sdot2dOp into an SdotOp whose `b` and `c` operands are the
/// original 2-D operands reshaped to 1-D with vector::ShapeCastOp.
class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Convert to 1-dimensional vector type to match the requirements of
  /// arm_neon.intr.sdot.
  ///
  /// Always succeeds: the op is unconditionally replaced.
  LogicalResult matchAndRewrite(Sdot2dOp op,
                                PatternRewriter &rewriter) const override {
    // `b` supplies both the element type and the row count of the flattened
    // type; the same flattened type is used for `c`.
    auto bType = op.b().getType().cast<VectorType>();
    // Shape extents are int64_t; keep the product in int64_t instead of
    // implicitly narrowing it to int.
    const int64_t flatLength = bType.getShape()[0] * Sdot2dOp::kReductionSize;
    auto flatType = VectorType::get({flatLength}, bType.getElementType());

    Location loc = op.getLoc();
    Value b1d = rewriter.create<vector::ShapeCastOp>(loc, flatType, op.b());
    Value c1d = rewriter.create<vector::ShapeCastOp>(loc, flatType, op.c());
    // The accumulator `a` and the result type are forwarded unchanged.
    rewriter.replaceOpWithNewOp<SdotOp>(op, op.res().getType(), op.a(), b1d,
                                        c1d);
    return success();
  }
};
|
||
class ConvertArmNeon2dToIntr | ||
: public ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> { | ||
void runOnOperation() override { | ||
auto func = getOperation(); | ||
auto *context = &getContext(); | ||
|
||
RewritePatternSet patterns(context); | ||
populateConvertArmNeon2dToIntrPatterns(patterns); | ||
|
||
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace mlir { | ||
|
||
void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) { | ||
patterns.add<Sdot2dLoweringPattern>(patterns.getContext()); | ||
} | ||
|
||
std::unique_ptr<OperationPass<FuncOp>> createConvertArmNeon2dToIntrPass() { | ||
return std::make_unique<ConvertArmNeon2dToIntr>(); | ||
} | ||
|
||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Conversion library containing the pass and patterns that lower Arm NEON 2D
# ops to their 1D intrinsic counterparts.
add_mlir_conversion_library(MLIRArmNeon2dToIntr
  ArmNeon2dToIntr.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmNeon2dToIntr

  DEPENDS
  MLIRConversionPassIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRArmNeon
  MLIRPass
  MLIRTransforms
  MLIRIR
  # ArmNeon2dToIntr.cpp creates vector::ShapeCastOp, so the Vector dialect
  # library is linked explicitly rather than relied upon transitively.
  MLIRVector
  )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | ||
|
||
// ----- | ||
|
||
// Negative test: the accumulator `%a` (and the result) is a 2-D vector, and the diagnostic below is expected on the op. | ||
func @a_is_2d(%a : vector<2x2xi32>, %b : vector<4x4xi8>) -> vector<2x2xi32> { | ||
// expected-error@+1 {{operand `a` should be 1-dimensional}} | ||
%0 = arm_neon.2d.sdot %a, %b, %b : vector<4x4xi8>, vector<4x4xi8> to vector<2x2xi32> | ||
return %0 : vector<2x2xi32> | ||
} | ||
|
||
// ----- | ||
|
||
// Negative test: `%b` is a 3-D vector (1x4x4), and the diagnostic below is expected on the op. | ||
func @b_is_3d(%a : vector<4xi32>, %b : vector<1x4x4xi8>) -> vector<4xi32> { | ||
// expected-error@+1 {{operand `b` should be 2-dimensional}} | ||
%0 = arm_neon.2d.sdot %a, %b, %b : vector<1x4x4xi8>, vector<1x4x4xi8> to vector<4xi32> | ||
return %0 : vector<4xi32> | ||
} | ||
|
||
// ----- | ||
|
||
// Negative test: `%b` is 4x2, i.e. it has 2 columns, and the diagnostic below is expected on the op. | ||
func @b_has_2_columns(%a : vector<4xi32>, %b : vector<4x2xi8>) -> vector<4xi32> { | ||
// expected-error@+1 {{operand `b` should have 4 columns}} | ||
%0 = arm_neon.2d.sdot %a, %b, %b : vector<4x2xi8>, vector<4x2xi8> to vector<4xi32> | ||
return %0 : vector<4xi32> | ||
} | ||
|
||
// ----- | ||
|
||
// Negative test: `%b` has 2 rows while `%a` has 4 elements, and the diagnostic below is expected on the op. | ||
func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi8>) -> vector<4xi32> { | ||
// expected-error@+1 {{operand `b` should have as many rows as the size of operand `a`}} | ||
%0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32> | ||
return %0 : vector<4xi32> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// RUN: mlir-opt -arm-neon-2d-to-intr %s | FileCheck %s | ||
|
||
// A 4x4 i8 arm_neon.2d.sdot should become arm_neon.intr.sdot on vector<16xi8> operands, immediately followed by the return. | ||
// CHECK-LABEL: arm_neon_sdot2d_4x4_i8i8 | ||
func @arm_neon_sdot2d_4x4_i8i8(%a: vector<4xi32>, %b: vector<4x4xi8>, %c: vector<4x4xi8>) -> vector<4xi32> { | ||
// CHECK: arm_neon.intr.sdot %{{.*}}, %{{.*}}, %{{.*}} : vector<16xi8>, vector<16xi8> to vector<4xi32> | ||
// CHECK-NEXT: return %{{.*}} : vector<4xi32> | ||
%0 = arm_neon.2d.sdot %a, %b, %c : vector<4x4xi8>, vector<4x4xi8> to vector<4xi32> | ||
return %0 : vector<4xi32> | ||
} | ||
|
||
// A 2x4 i8 arm_neon.2d.sdot should become arm_neon.intr.sdot on vector<8xi8> operands, immediately followed by the return. | ||
// CHECK-LABEL: arm_neon_sdot2d_2x4_i8i8 | ||
func @arm_neon_sdot2d_2x4_i8i8(%a: vector<2xi32>, %b: vector<2x4xi8>, %c: vector<2x4xi8>) -> vector<2xi32> { | ||
// CHECK: arm_neon.intr.sdot %{{.*}}, %{{.*}}, %{{.*}} : vector<8xi8>, vector<8xi8> to vector<2xi32> | ||
// CHECK-NEXT: return %{{.*}} : vector<2xi32> | ||
%0 = arm_neon.2d.sdot %a, %b, %c : vector<2x4xi8>, vector<2x4xi8> to vector<2xi32> | ||
return %0 : vector<2xi32> | ||
} |