[mlir][tosa] Add tosa.max_pool2d lowering to linalg int max pooling additions

Lowers tosa.max_pool2d to equivalent linalg operations. Includes adding
integer max pooling operations to linalg, with corresponding tests.

Differential Revision: https://reviews.llvm.org/D99824
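
For reference, the new pattern rewrites a TOSA max pool into the corresponding linalg named pooling op. A minimal before/after sketch, mirroring the @max_pool test added below (some type annotations elided, as in the CHECK lines):

  // Before: TOSA max pool over an NHWC tensor.
  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]}
      : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>

  // After: fill the output with the smallest f32 value, then apply the named pooling op.
  %cst  = constant -3.40282347E+38 : f32
  %init = linalg.init_tensor [1, 4, 32, 62]
  %fill = linalg.fill(%init, %cst)
  %win  = linalg.init_tensor [3, 3]
  linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
      ins(%arg0, %win : tensor<1x6x34x62xf32>, tensor<3x3xf32>)
      outs(%fill : tensor<1x4x32x62xf32>)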
rsuderman committed Apr 9, 2021
1 parent 4a84b03 commit ceeb5b0
Showing 6 changed files with 424 additions and 16 deletions.
@@ -352,6 +352,51 @@ def pooling_nhwc_sum
ow * strides[1] + kw * dilations[1], c));
}

ods_def<PoolingNHWCMaxI8Op>:
def pooling_nhwc_i8_max
(I: i8(N, H, W, C), K: i8(KH, KW))
-> (O: i8(N, OH, OW, C))
attr(strides: 2xi64, dilations: 2xi64)
{
O(n, oh, ow, c) =
std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
ow * strides[1] + kw * dilations[1], c),
O(n, oh, ow, c)),
I(n, oh * strides[0] + kh * dilations[0],
ow * strides[1] + kw * dilations[1], c),
O(n, oh, ow, c));
}

ods_def<PoolingNHWCMaxI16Op>:
def pooling_nhwc_i16_max
(I: i16(N, H, W, C), K: i16(KH, KW))
-> (O: i16(N, OH, OW, C))
attr(strides: 2xi64, dilations: 2xi64)
{
O(n, oh, ow, c) =
std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
ow * strides[1] + kw * dilations[1], c),
O(n, oh, ow, c)),
I(n, oh * strides[0] + kh * dilations[0],
ow * strides[1] + kw * dilations[1], c),
O(n, oh, ow, c));
}

ods_def<PoolingNHWCMaxI32Op>:
def pooling_nhwc_i32_max
(I: i32(N, H, W, C), K: i32(KH, KW))
-> (O: i32(N, OH, OW, C))
attr(strides: 2xi64, dilations: 2xi64)
{
O(n, oh, ow, c) =
std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
ow * strides[1] + kw * dilations[1], c),
O(n, oh, ow, c)),
I(n, oh * strides[0] + kh * dilations[0],
ow * strides[1] + kw * dilations[1], c),
O(n, oh, ow, c));
}

ods_def<PoolingNHWCMaxFOp>:
def pooling_nhwc_max
(I: f32(N, H, W, C), K: f32(KH, KW))
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
@@ -59,6 +59,15 @@ struct CmpFValueBuilder : public ValueBuilder<CmpFOp> {
using std_cmpf_ogt = CmpFValueBuilder<CmpFPredicate::OGT>;
using std_cmpf_olt = CmpFValueBuilder<CmpFPredicate::OLT>;

template <CmpIPredicate Predicate>
struct CmpIValueBuilder : public ValueBuilder<CmpIOp> {
using ValueBuilder<CmpIOp>::ValueBuilder;
template <typename... Args>
CmpIValueBuilder(Args... args) : ValueBuilder<CmpIOp>(Predicate, args...) {}
};

using std_cmpi_sgt = CmpIValueBuilder<CmpIPredicate::sgt>;

/// Branches into `block` with `operands`.
BranchOp std_br(Block *block, ValueRange operands);

155 changes: 139 additions & 16 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1230,6 +1230,22 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
"Pad converter requires static shaped input / padding values.");
}

Attribute constantAttr;
if (elementTy.isa<FloatType>())
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
auto value = padOp.quantization_info().getValue().input_zp().getValue();
constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
}

if (!constantAttr) {
return rewriter.notifyMatchFailure(
padOp,
"tosa.pad to linalg lowering encountered an unknown element type");
}

Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
Value highIndex =
rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
@@ -1256,22 +1272,6 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
highValues.push_back(highVal);
}

Attribute constantAttr;
if (elementTy.isa<FloatType>())
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
auto value = padOp.quantization_info().getValue().input_zp().getValue();
constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
}

if (!constantAttr) {
return rewriter.notifyMatchFailure(
padOp,
"tosa.pad to linalg lowering encountered an unknown element type");
}

Value constant = rewriter.create<ConstantOp>(loc, constantAttr);

auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
@@ -1523,6 +1523,128 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
}
};

class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
Value input = op.input();
ShapedType inputTy = input.getType().cast<ShapedType>();
Type inElementTy = inputTy.getElementType();

ShapedType resultTy = op.getType().cast<ShapedType>();
Type outElementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();

if (!inputTy.hasStaticShape())
return failure();

// Determine what the initial value needs to be for the max pool op.
Attribute initialAttr;
if (outElementTy.isF32())
initialAttr = rewriter.getFloatAttr(
outElementTy,
APFloat::getLargest(
outElementTy.cast<FloatType>().getFloatSemantics(), true));

if (outElementTy.isa<IntegerType>())
initialAttr = rewriter.getIntegerAttr(
outElementTy,
APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));

if (!initialAttr)
return rewriter.notifyMatchFailure(
op, "Unsupported initial value for tosa.maxpool_2d op");

Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);

SmallVector<int64_t> kernel, stride, pad;
getValuesFromIntArrayAttribute(op.kernel(), kernel);
getValuesFromIntArrayAttribute(op.stride(), stride);
getValuesFromIntArrayAttribute(op.pad(), pad);

Attribute strideAttr = rewriter.getI64VectorAttr(stride);
Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

// If non-zero padding we need to pad the input
if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) {
SmallVector<int64_t, 4> paddedShape;
for (int64_t i = 0; i < rank; i++)
paddedShape.push_back(inputTy.getDimSize(i));

paddedShape[1] += pad[0] + pad[1];
paddedShape[2] += pad[2] + pad[3];

OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
OpFoldResult heightLowPadIndex = rewriter.getIndexAttr(pad[0]);
OpFoldResult heightHighPadIndex = rewriter.getIndexAttr(pad[1]);
OpFoldResult widthLowPadIndex = rewriter.getIndexAttr(pad[2]);
OpFoldResult widthHighPadIndex = rewriter.getIndexAttr(pad[3]);

SmallVector<OpFoldResult, 4> lowIndices = {zeroIndex, heightLowPadIndex,
widthLowPadIndex, zeroIndex};
SmallVector<OpFoldResult, 4> highIndices = {zeroIndex, heightHighPadIndex,
widthHighPadIndex, zeroIndex};

input = linalg::PadTensorOp::createPadScalarOp(
RankedTensorType::get(paddedShape, inElementTy), input,
initialValue, lowIndices, highIndices, loc, rewriter)
.result();
}

Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultTy.getShape(), resultTy.getElementType());

Value filledInitTensor =
rewriter.create<linalg::FillOp>(loc, initTensor, initialValue).result();

Value fakeWindowDims =
rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);

auto createOp = [&](auto *typePtr) -> linalg::LinalgOp {
return cast<linalg::LinalgOp>(
rewriter
.create<std::remove_pointer_t<decltype(typePtr)>>(
loc, ArrayRef<Type>{resultTy},
ValueRange{input, fakeWindowDims}, filledInitTensor,
dilationAttr, strideAttr)
.getOperation());
};

if (inElementTy.isF32()) {
linalg::LinalgOp poolingOp =
createOp(static_cast<linalg::PoolingNHWCMaxFOp *>(nullptr));
rewriter.replaceOp(op, poolingOp->getResult(0));
return success();
}

if (inElementTy.isInteger(8)) {
linalg::LinalgOp poolingOp =
createOp(static_cast<linalg::PoolingNHWCMaxI8Op *>(nullptr));
rewriter.replaceOp(op, poolingOp->getResult(0));
return success();
}

if (inElementTy.isInteger(16)) {
linalg::LinalgOp poolingOp =
createOp(static_cast<linalg::PoolingNHWCMaxI16Op *>(nullptr));
rewriter.replaceOp(op, poolingOp->getResult(0));
return success();
}

if (inElementTy.isInteger(32)) {
linalg::LinalgOp poolingOp =
createOp(static_cast<linalg::PoolingNHWCMaxI32Op *>(nullptr));
rewriter.replaceOp(op, poolingOp->getResult(0));
return success();
}

return failure();
}
};

} // namespace

void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -1579,6 +1701,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
TileConverter,
TransposeConverter,
MatMulConverter,
MaxPool2dConverter,
FullyConnectedConverter>(patterns->getContext());
// clang-format on
}
50 changes: 50 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -873,3 +873,53 @@ func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
%0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi16>, tensor<513xi16>) -> (tensor<6xi32>)
return
}

// -----

// CHECK-LABEL: @max_pool
func @max_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
// CHECK-DAG: [[CONST:%.+]] = constant -3.40282347E+38
// CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 32, 62]
// CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
// CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3]
// CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x32x62xf32>)
%0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>)
return
}

// CHECK-LABEL: @max_pool_padded
func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () {
// CHECK-DAG: [[CONST:%.+]] = constant -3.40282347E+38 : f32
// CHECK-DAG: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 0, 0, 0] high[0, 0, 1, 0]
// CHECK-DAG: linalg.yield [[CONST]]
// CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 33, 62]
// CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
// CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3]
// CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x6x35x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x33x62xf32>)
%0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 1], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x33x62xf32>)
return
}

// CHECK-LABEL: @max_pool_i8
func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
// CHECK: constant -128
// CHECK: linalg.pooling_nhwc_i8_max
%0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi8>) -> (tensor<1x4x32x62xi8>)
return
}

// CHECK-LABEL: @max_pool_i16
func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
// CHECK: constant -32768
// CHECK: linalg.pooling_nhwc_i16_max
%0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi16>) -> (tensor<1x4x32x62xi16>)
return
}

// CHECK-LABEL: @max_pool_i32
func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
// CHECK: constant -2147483648
// CHECK: linalg.pooling_nhwc_i32_max
%0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
return
}
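
The FileCheck expectations above can be exercised through mlir-opt; a sketch of the RUN line, assuming the --tosa-to-linalg-on-tensors flag that drives populateTosaToLinalgOnTensorsConversionPatterns:

  // RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s | FileCheck %s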
78 changes: 78 additions & 0 deletions mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -340,6 +340,84 @@ func @pooling_nhwc_max(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) {

// -----

func @pooling_nhwc_i8_max(%input: memref<?x?x?x?xi8>, %fake: memref<2x3xi8>, %init: memref<?x?x?x?xi8>) {
linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
ins(%input, %fake: memref<?x?x?x?xi8>, memref<2x3xi8>)
outs(%init: memref<?x?x?x?xi8>)
return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// CHECK: func @pooling_nhwc_i8_max

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi8>, memref<2x3xi8>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi8>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8)
// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i8
// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i8
// CHECK-NEXT: linalg.yield %[[RES]] : i8

// -----

func @pooling_nhwc_i16_max(%input: memref<?x?x?x?xi16>, %fake: memref<2x3xi16>, %init: memref<?x?x?x?xi16>) {
linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
ins(%input, %fake: memref<?x?x?x?xi16>, memref<2x3xi16>)
outs(%init: memref<?x?x?x?xi16>)
return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// CHECK: func @pooling_nhwc_i16_max

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi16>, memref<2x3xi16>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi16>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i16, %[[BBARG1:.+]]: i16, %[[BBARG2:.+]]: i16)
// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i16
// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i16
// CHECK-NEXT: linalg.yield %[[RES]] : i16

// -----

func @pooling_nhwc_i32_max(%input: memref<?x?x?x?xi32>, %fake: memref<2x3xi32>, %init: memref<?x?x?x?xi32>) {
linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
ins(%input, %fake: memref<?x?x?x?xi32>, memref<2x3xi32>)
outs(%init: memref<?x?x?x?xi32>)
return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// CHECK: func @pooling_nhwc_i32_max

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi32>, memref<2x3xi32>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: i32, %[[BBARG2:.+]]: i32)
// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i32
// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i32
// CHECK-NEXT: linalg.yield %[[RES]] : i32

// -----

func @pooling_nhwc_min(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) {
linalg.pooling_nhwc_min {dilations = dense<3> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
ins(%input, %fake: memref<?x?x?x?xf32>, memref<2x3xf32>)