[mlir][tosa] Add tosa.pad to linalg.pad operation
Lowers tosa.pad to its linalg equivalent (linalg.pad_tensor) for
floating-point, integer, and quantized values.

Differential Revision: https://reviews.llvm.org/D98990
rsuderman committed Mar 23, 2021
1 parent 77b4230 commit 4157a07
Showing 4 changed files with 125 additions and 5 deletions.
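
For orientation, here is a sketch of the lowering, distilled from the new test cases at the end of this commit (the SSA names %low0/%low1/%high0/%high1 are illustrative placeholders, not names the pass emits). A float pad such as

  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)

becomes a linalg.pad_tensor whose region yields the pad constant:

  %cst = constant 0.000000e+00 : f32
  %1 = linalg.pad_tensor %arg0 low[%low0, %low1] high[%high0, %high1] {
  ^bb0(%arg1: index, %arg2: index):
    linalg.yield %cst : f32
  } : tensor<1x2xf32> to tensor<4x9xf32>

The low/high operands are extracted from the 2x2 padding tensor with tensor.extract and cast to index type with index_cast.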
1 change: 1 addition & 0 deletions mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
   MLIRMath
   MLIRMemRef
   MLIRPass
+  MLIRTensor
   MLIRTosa
   MLIRTosaTransforms
   MLIRSupport
80 changes: 77 additions & 3 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1155,7 +1156,79 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
         op, resultTy, genericOp.getResult(0),
         rewriter.getI64ArrayAttr(resultTy.getShape()));
     return success();
   }
 };
 
+class PadConverter : public OpRewritePattern<tosa::PadOp> {
+public:
+  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::PadOp padOp,
+                                PatternRewriter &rewriter) const final {
+    auto loc = padOp.getLoc();
+    auto input = padOp.input1();
+    auto padding = padOp.padding();
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType paddingTy = padding.getType().cast<ShapedType>();
+    Type elementTy = inputTy.getElementType();
+    int64_t rank = inputTy.getRank();
+
+    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "Pad converter requires static shaped input / padding values.");
+    }
+
+    Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value highIndex =
+        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
+
+    SmallVector<OpFoldResult, 3> lowValues;
+    SmallVector<OpFoldResult, 3> highValues;
+
+    lowValues.reserve(rank);
+    highValues.reserve(rank);
+
+    for (int i = 0; i < rank; i++) {
+      Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
+      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padding, ValueRange({inputIndex, lowIndex}));
+      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padding, ValueRange({inputIndex, highIndex}));
+
+      lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
+                                                  lowVal);
+      highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
+                                                   highVal);
+
+      lowValues.push_back(lowVal);
+      highValues.push_back(highVal);
+    }
+
+    Attribute constantAttr;
+    if (elementTy.isa<FloatType>())
+      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
+      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
+      auto value = padOp.quantization_info().getValue().input_zp().getValue();
+      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
+    }
+
+    if (!constantAttr) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "tosa.pad to linalg lowering encountered an unknown element type");
+    }
+
+    Value constant = rewriter.create<ConstantOp>(loc, constantAttr);
+
+    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
+        padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);
+
+    rewriter.replaceOp(padOp, newPadOp.getResult());
+    return success();
+  }
+};
@@ -1187,7 +1260,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       IdentityNConverter<tosa::IdentityOp>,
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
-      ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
-      RescaleConverter, ReverseConverter, TileConverter, TransposeConverter,
-      MatMulConverter, FullyConnectedConverter>(patterns->getContext());
+      ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, PadConverter,
+      ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter,
+      TransposeConverter, MatMulConverter, FullyConnectedConverter>(
+      patterns->getContext());
 }
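
A note on how the converter above consumes the padding operand: for each input dimension i, padding[i, 0] is read as the low padding and padding[i, 1] as the high padding, so each output dimension is low + input + high. Worked through for the test input below, tensor<1x2xf32> with padding dense<[[1, 2], [3, 4]]>:

  dim 0: 1 + 1 + 2 = 4
  dim 1: 3 + 2 + 4 = 9

which is why the result type is tensor<4x9xf32>.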
6 changes: 4 additions & 2 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
@@ -33,14 +34,15 @@ struct TosaToLinalgOnTensors
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect, math::MathDialect,
-                    memref::MemRefDialect, StandardOpsDialect>();
+                    memref::MemRefDialect, StandardOpsDialect,
+                    tensor::TensorDialect>();
   }
 
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
-                           StandardOpsDialect>();
+                           StandardOpsDialect, tensor::TensorDialect>();
     target.addIllegalDialect<tosa::TosaDialect>();
 
     // Not every TOSA op can be legalized to linalg.
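
With the dialect registrations above in place, the lowering can be exercised through this pass. A hedged invocation sketch, assuming the pass keeps its registered name tosa-to-linalg-on-tensors and using a hypothetical input file pad.mlir:

  mlir-opt -pass-pipeline="func(tosa-to-linalg-on-tensors)" pad.mlir

The test file below is driven the same way through the lit/FileCheck harness.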
43 changes: 43 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -702,3 +702,46 @@ func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: ten
   %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>) -> (tensor<5x6xf32>)
   return %0 : tensor<5x6xf32>
 }
+
+// -----
+
+func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[INDEX0:%.+]] = constant 0 : index
+  // CHECK: [[INDEX1:%.+]] = constant 1 : index
+  // CHECK: [[ROW0:%.+]] = constant 0 : index
+  // CHECK: [[LOW0:%.+]] = tensor.extract %cst{{\[}}[[ROW0]], [[INDEX0]]]
+  // CHECK: [[HIGH0:%.+]] = tensor.extract %cst{{\[}}[[ROW0]], [[INDEX1]]]
+  // CHECK: [[LOW0_IDX:%.+]] = index_cast %0
+  // CHECK: [[HIGH0_IDX:%.+]] = index_cast %1
+  // CHECK: [[ROW1:%.+]] = constant 1 : index
+  // CHECK: [[LOW1:%.+]] = tensor.extract %cst{{\[}}%c1_1, %c0]
+  // CHECK: [[HIGH1:%.+]] = tensor.extract %cst{{\[}}%c1_1, %c1]
+  // CHECK: [[LOW1_IDX:%.+]] = index_cast [[LOW1]]
+  // CHECK: [[HIGH1_IDX:%.+]] = index_cast [[HIGH1]]
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: %8 = linalg.pad_tensor %arg0 low{{\[}}[[LOW0_IDX]], [[LOW1_IDX]]] high{{\[}}[[HIGH0_IDX]], [[HIGH1_IDX]]] {
+  // CHECK: ^bb0(%arg1: index, %arg2: index):  // no predecessors
+  // CHECK:   linalg.yield [[CST]]
+  // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+  return %1 : tensor<4x9xf32>
+}
+
+func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[CST:%.+]] = constant 0 : i32
+  // CHECK: linalg.pad_tensor
+  // CHECK: linalg.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}
+
+func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[CST:%.+]] = constant 42 : i32
+  // CHECK: linalg.pad_tensor
+  // CHECK: linalg.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0) { quantization_info = { input_zp = 42 : i32}} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}
