Skip to content

Commit

Permalink
2d Arm Neon sdot op, and lowering to the intrinsic.
Browse files Browse the repository at this point in the history
This adds Sdot2d op, which is similar to the usual Neon
intrinsic except that it takes 2d vector operands, reflecting the
structure of the arithmetic that it's performing: 4 separate
4-dimensional dot products, whence the vector<4x4xi8> shape.

This also adds a new pass, arm-neon-2d-to-intr, lowering
this new 2d op to the 1d intrinsic.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D102504
  • Loading branch information
bjacob authored and asaadaldien committed Jun 10, 2021
1 parent 4f01122 commit 20daeda
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 0 deletions.
30 changes: 30 additions & 0 deletions mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===- ArmNeon2dToIntr.h - convert Arm Neon 2d ops to intrinsics ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_
#define MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_

#include "mlir/Pass/Pass.h"

namespace mlir {
class FuncOp;
template <typename T>
class OperationPass;

/// Populates patterns for the lowering of Arm NEON 2D ops to intrinsics.
/// See createConvertArmNeon2dToIntrPass.
void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns);

/// Creates a pass to lower Arm NEON 2D ops to intrinsics, i.e.
/// equivalent ops operating on flattened 1D vectors and mapping more
/// directly to the corresponding Arm NEON instruction.
std::unique_ptr<OperationPass<FuncOp>> createConvertArmNeon2dToIntrPass();

} // namespace mlir

#endif // MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_CONVERSION_PASSES_H

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -607,4 +607,15 @@ def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv", "ModuleOp"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}

//===----------------------------------------------------------------------===//
// ArmNeon2dToIntr
//===----------------------------------------------------------------------===//

def ConvertArmNeon2dToIntr : Pass<"arm-neon-2d-to-intr", "FuncOp"> {
let summary = "Convert Arm NEON structured ops to intrinsics";
let constructor = "mlir::createConvertArmNeon2dToIntrPass()";
let dependentDialects = ["arm_neon::ArmNeonDialect", "vector::VectorDialect"];
}


#endif // MLIR_CONVERSION_PASSES
55 changes: 55 additions & 0 deletions mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// ArmNeon dialect definition
Expand Down Expand Up @@ -117,4 +118,58 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
"$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
}

class ArmNeon_2dOp<string mnemonic, list<OpTrait> traits = []>
: Op</*dialect=*/ArmNeon_Dialect,
/*opName=*/"2d." # mnemonic,
/*traits=*/traits>;

def Sdot2dOp : ArmNeon_2dOp<"sdot", [
NoSideEffect,
AllTypesMatch<["b", "c"]>,
AllTypesMatch<["a", "res"]>,
PredOpTrait<
"operand `a` should be 1-dimensional",
CPred<"a().getType().cast<VectorType>().getShape().size() == 1">
>,
PredOpTrait<
"operand `b` should be 2-dimensional",
CPred<"b().getType().cast<VectorType>().getShape().size() == 2">
>,
PredOpTrait<
"operand `b` should have 4 columns",
CPred<"b().getType().cast<VectorType>().getShape()[1] == 4">
>,
PredOpTrait<
"operand `b` should have as many rows as the size of operand `a`",
CPred<"b().getType().cast<VectorType>().getShape()[0] == a().getType().cast<VectorType>().getShape()[0]">
>,
]
> {
let summary = "sdot op";
let description = [{
The two input vectors `b` and `c` have a 2D shape, consisting of either 2
or 4 rows, each row having length 4. This operation computes the pair-wise
dot-products of the rows of `b` and `c` and accumulates them with the
corresponding entry of `a`:

```
res[i] := a[i] + dot_product(b[i, ...], c[i, ...])
```

}];
// Supports either:
// (vector<2xi32>, vector<2x4xi8>, vector<2x4xi8>) -> vector<2xi32>
// (vector<4xi32>, vector<4x4xi8>, vector<4x4xi8>) -> vector<4xi32>
// TODO: how do we express 2D shape requirements here?
let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a,
VectorOfLengthAndType<[16, 8], [I8]>:$b,
VectorOfLengthAndType<[16, 8], [I8]>:$c);
let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
let assemblyFormat =
"$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
let extraClassDeclaration = [{
static constexpr int kReductionSize = 4;
}];
}

#endif // ARMNEON_OPS
75 changes: 75 additions & 0 deletions mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
#include "../PassDetail.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::arm_neon;

namespace {

class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
public:
using OpRewritePattern::OpRewritePattern;

/// Convert to 1-dimensional vector type to match the requirements of
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
Type elemType = op.b().getType().cast<VectorType>().getElementType();
int length = op.b().getType().cast<VectorType>().getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
Value b2d = op.b();
Value c2d = op.c();
Location loc = op.getLoc();
Value b1d =
rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d);
Value c1d =
rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d);
Value newOp =
rewriter.create<SdotOp>(loc, op.res().getType(), op.a(), b1d, c1d);
rewriter.replaceOp(op, {newOp});
return success();
}
};

class ConvertArmNeon2dToIntr
: public ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> {
void runOnOperation() override {
auto func = getOperation();
auto *context = &getContext();

RewritePatternSet patterns(context);
populateConvertArmNeon2dToIntrPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
return signalPassFailure();
}
};

} // namespace

namespace mlir {

void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) {
patterns.add<Sdot2dLoweringPattern>(patterns.getContext());
}

std::unique_ptr<OperationPass<FuncOp>> createConvertArmNeon2dToIntrPass() {
return std::make_unique<ConvertArmNeon2dToIntr>();
}

} // namespace mlir
18 changes: 18 additions & 0 deletions mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
add_mlir_conversion_library(MLIRArmNeon2dToIntr
ArmNeon2dToIntr.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmNeon2dToIntr

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRArmNeon
MLIRPass
MLIRTransforms
MLIRIR
)
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(AffineToStandard)
add_subdirectory(ArmNeon2dToIntr)
add_subdirectory(AsyncToLLVM)
add_subdirectory(ComplexToLLVM)
add_subdirectory(ComplexToStandard)
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Conversion/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ namespace vector {
class VectorDialect;
} // end namespace vector

namespace arm_neon {
class ArmNeonDialect;
} // end namespace arm_neon

#define GEN_PASS_CLASSES
#include "mlir/Conversion/Passes.h.inc"

Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Dialect/ArmNeon/invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// -----

func @a_is_2d(%a : vector<2x2xi32>, %b : vector<4x4xi8>) -> vector<2x2xi32> {
// expected-error@+1 {{operand `a` should be 1-dimensional}}
%0 = arm_neon.2d.sdot %a, %b, %b : vector<4x4xi8>, vector<4x4xi8> to vector<2x2xi32>
return %0 : vector<2x2xi32>
}

// -----

func @b_is_3d(%a : vector<4xi32>, %b : vector<1x4x4xi8>) -> vector<4xi32> {
// expected-error@+1 {{operand `b` should be 2-dimensional}}
%0 = arm_neon.2d.sdot %a, %b, %b : vector<1x4x4xi8>, vector<1x4x4xi8> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

func @b_has_2_columns(%a : vector<4xi32>, %b : vector<4x2xi8>) -> vector<4xi32> {
// expected-error@+1 {{operand `b` should have 4 columns}}
%0 = arm_neon.2d.sdot %a, %b, %b : vector<4x2xi8>, vector<4x2xi8> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi8>) -> vector<4xi32> {
// expected-error@+1 {{operand `b` should have as many rows as the size of operand `a`}}
%0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32>
return %0 : vector<4xi32>
}
17 changes: 17 additions & 0 deletions mlir/test/Target/LLVMIR/arm-neon-2d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: mlir-opt -arm-neon-2d-to-intr %s | FileCheck %s

// CHECK-LABEL: arm_neon_sdot2d_4x4_i8i8
func @arm_neon_sdot2d_4x4_i8i8(%a: vector<4xi32>, %b: vector<4x4xi8>, %c: vector<4x4xi8>) -> vector<4xi32> {
// CHECK: arm_neon.intr.sdot %{{.*}}, %{{.*}}, %{{.*}} : vector<16xi8>, vector<16xi8> to vector<4xi32>
// CHECK-NEXT: return %{{.*}} : vector<4xi32>
%0 = arm_neon.2d.sdot %a, %b, %c : vector<4x4xi8>, vector<4x4xi8> to vector<4xi32>
return %0 : vector<4xi32>
}

// CHECK-LABEL: arm_neon_sdot2d_2x4_i8i8
func @arm_neon_sdot2d_2x4_i8i8(%a: vector<2xi32>, %b: vector<2x4xi8>, %c: vector<2x4xi8>) -> vector<2xi32> {
// CHECK: arm_neon.intr.sdot %{{.*}}, %{{.*}}, %{{.*}} : vector<8xi8>, vector<8xi8> to vector<2xi32>
// CHECK-NEXT: return %{{.*}} : vector<2xi32>
%0 = arm_neon.2d.sdot %a, %b, %c : vector<2x4xi8>, vector<2x4xi8> to vector<2xi32>
return %0 : vector<2xi32>
}

0 comments on commit 20daeda

Please sign in to comment.