Skip to content

Commit 4157a07

Browse files
committed
[mlir][tosa] Add tosa.pad to linalg.pad operation
Lowers from tosa's pad op to the linalg equivalent for floating, integer, and quantized values. Differential Revision: https://reviews.llvm.org/D98990
1 parent 77b4230 commit 4157a07

File tree

4 files changed

+125
-5
lines changed

4 files changed

+125
-5
lines changed

mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
1616
MLIRMath
1717
MLIRMemRef
1818
MLIRPass
19+
MLIRTensor
1920
MLIRTosa
2021
MLIRTosaTransforms
2122
MLIRSupport

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Math/IR/Math.h"
1616
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1717
#include "mlir/Dialect/StandardOps/IR/Ops.h"
18+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1819
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1920
#include "mlir/IR/Matchers.h"
2021
#include "mlir/IR/PatternMatch.h"
@@ -1155,7 +1156,79 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
11551156
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
11561157
op, resultTy, genericOp.getResult(0),
11571158
rewriter.getI64ArrayAttr(resultTy.getShape()));
1159+
return success();
1160+
}
1161+
};
1162+
1163+
/// Lowers tosa.pad to linalg.pad_tensor.
///
/// The tosa.pad `padding` operand is a static 2-D tensor with one
/// [low, high] pair per input dimension; each pair is extracted at rewrite
/// time with tensor.extract and cast to index type.  The scalar pad value is
/// 0.0 for floats, 0 for plain integers, or the input zero point when the op
/// carries quantization_info.
class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.input1();
    auto padding = padOp.padding();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType paddingTy = padding.getType().cast<ShapedType>();
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Dynamic shapes are not handled: the per-dimension extraction below
    // walks a statically known rank, and the padding tensor must be
    // addressable with constant indices.
    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          padOp,
          "Pad converter requires static shaped input / padding values.");
    }

    // Column indices into the Nx2 padding tensor: column 0 holds the
    // low-side pad amount, column 1 the high-side amount.
    Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    // Extract the [low, high] pair for every input dimension and convert the
    // extracted integers to index type, as required by linalg.pad_tensor.
    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                  lowVal);
      highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                   highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    // Pick the scalar value used to fill the padded region.
    Attribute constantAttr;
    if (elementTy.isa<FloatType>())
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
      // Quantized pad fills with the input zero point rather than 0.
      // NOTE(review): getZExtValue() is truncated back into elementTy's width
      // by getIntegerAttr; assumes the zero point fits that width — confirm
      // behavior for negative zero points.
      auto value = padOp.quantization_info().getValue().input_zp().getValue();
      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
    }

    // Element types other than float/integer are not supported.
    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          padOp,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    Value constant = rewriter.create<ConstantOp>(loc, constantAttr);

    // createPadScalarOp builds the linalg.pad_tensor op plus its yield region
    // returning the scalar pad value.
    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
        padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
@@ -1187,7 +1260,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
11871260
IdentityNConverter<tosa::IdentityOp>,
11881261
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
11891262
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
1190-
ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
1191-
RescaleConverter, ReverseConverter, TileConverter, TransposeConverter,
1192-
MatMulConverter, FullyConnectedConverter>(patterns->getContext());
1263+
ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, PadConverter,
1264+
ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter,
1265+
TransposeConverter, MatMulConverter, FullyConnectedConverter>(
1266+
patterns->getContext());
11931267
}

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Math/IR/Math.h"
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/Dialect/StandardOps/IR/Ops.h"
19+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1920
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
2021
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
2122
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
@@ -33,14 +34,15 @@ struct TosaToLinalgOnTensors
3334
public:
3435
  /// Registers every dialect the TOSA-to-linalg patterns may create ops in.
  /// The tensor dialect is included because the pad lowering emits
  /// tensor.extract ops to read the padding amounts.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect, math::MathDialect,
                    memref::MemRefDialect, StandardOpsDialect,
                    tensor::TensorDialect>();
  }
3840

3941
void runOnFunction() override {
4042
RewritePatternSet patterns(&getContext());
4143
ConversionTarget target(getContext());
4244
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
43-
StandardOpsDialect>();
45+
StandardOpsDialect, tensor::TensorDialect>();
4446
target.addIllegalDialect<tosa::TosaDialect>();
4547

4648
// Not every TOSA op can be legalized to linalg.

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,3 +702,46 @@ func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: ten
702702
%0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>) -> (tensor<5x6xf32>)
703703
return %0 : tensor<5x6xf32>
704704
}
705+
706+
// -----
707+
708+
// Full lowering check for tosa.pad on a float tensor: per-dimension [low,
// high] amounts are tensor.extract'ed from the padding constant, index_cast
// to index type, and fed to linalg.pad_tensor with a 0.0 yield value.
// Result shape: 1+1+3 x 2+2+4+1(?) — dims follow low/high pairs in %0.
func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
  // CHECK: [[INDEX0:%.+]] = constant 0 : index
  // CHECK: [[INDEX1:%.+]] = constant 1 : index
  // CHECK: [[ROW0:%.+]] = constant 0 : index
  // CHECK: [[LOW0:%.+]] = tensor.extract %cst{{\[}}[[ROW0]], [[INDEX0]]]
  // CHECK: [[HIGH0:%.+]] = tensor.extract %cst{{\[}}[[ROW0]], [[INDEX1]]]
  // CHECK: [[LOW0_IDX:%.+]] = index_cast %0
  // CHECK: [[HIGH0_IDX:%.+]] = index_cast %1
  // CHECK: [[ROW1:%.+]] = constant 1 : index
  // CHECK: [[LOW1:%.+]] = tensor.extract %cst{{\[}}%c1_1, %c0]
  // CHECK: [[HIGH1:%.+]] = tensor.extract %cst{{\[}}%c1_1, %c1]
  // CHECK: [[LOW1_IDX:%.+]] = index_cast [[LOW1]]
  // CHECK: [[HIGH1_IDX:%.+]] = index_cast [[HIGH1]]
  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
  // CHECK: %8 = linalg.pad_tensor %arg0 low{{\[}}[[LOW0_IDX]], [[LOW1_IDX]]] high{{\[}}[[HIGH0_IDX]], [[HIGH1_IDX]]] {
  // CHECK: ^bb0(%arg1: index, %arg2: index): // no predecessors
  // CHECK: linalg.yield [[CST]]
  // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
  return %1 : tensor<4x9xf32>
}
730+
731+
// Integer tosa.pad without quantization info pads with the integer zero.
// Only the pad value is checked here; the index plumbing is covered by
// @pad_float above.
func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
  // CHECK: [[CST:%.+]] = constant 0 : i32
  // CHECK: linalg.pad_tensor
  // CHECK: linalg.yield [[CST]]
  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
  return %1 : tensor<4x9xi32>
}
739+
740+
// Quantized tosa.pad pads with the input zero point (42 here) taken from
// quantization_info, not with 0.
func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
  // CHECK: [[CST:%.+]] = constant 42 : i32
  // CHECK: linalg.pad_tensor
  // CHECK: linalg.yield [[CST]]
  %1 = "tosa.pad"(%arg0, %0) { quantization_info = { input_zp = 42 : i32}} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
  return %1 : tensor<4x9xi32>
}

0 commit comments

Comments
 (0)