|
15 | 15 | #include "mlir/Dialect/Math/IR/Math.h"
|
16 | 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h"
|
17 | 17 | #include "mlir/Dialect/StandardOps/IR/Ops.h"
|
| 18 | +#include "mlir/Dialect/Tensor/IR/Tensor.h" |
18 | 19 | #include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
19 | 20 | #include "mlir/IR/Matchers.h"
|
20 | 21 | #include "mlir/IR/PatternMatch.h"
|
@@ -1155,7 +1156,79 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
|
1155 | 1156 | rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
1156 | 1157 | op, resultTy, genericOp.getResult(0),
|
1157 | 1158 | rewriter.getI64ArrayAttr(resultTy.getShape()));
|
| 1159 | + return success(); |
| 1160 | + } |
| 1161 | +}; |
| 1162 | + |
| 1163 | +class PadConverter : public OpRewritePattern<tosa::PadOp> { |
| 1164 | +public: |
| 1165 | + using OpRewritePattern<tosa::PadOp>::OpRewritePattern; |
| 1166 | + |
| 1167 | + LogicalResult matchAndRewrite(tosa::PadOp padOp, |
| 1168 | + PatternRewriter &rewriter) const final { |
| 1169 | + auto loc = padOp.getLoc(); |
| 1170 | + auto input = padOp.input1(); |
| 1171 | + auto padding = padOp.padding(); |
| 1172 | + |
| 1173 | + ShapedType inputTy = input.getType().cast<ShapedType>(); |
| 1174 | + ShapedType paddingTy = padding.getType().cast<ShapedType>(); |
| 1175 | + Type elementTy = inputTy.getElementType(); |
| 1176 | + int64_t rank = inputTy.getRank(); |
| 1177 | + |
| 1178 | + if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) { |
| 1179 | + return rewriter.notifyMatchFailure( |
| 1180 | + padOp, |
| 1181 | + "Pad converter requires static shaped input / padding values."); |
| 1182 | + } |
| 1183 | + |
| 1184 | + Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0)); |
| 1185 | + Value highIndex = |
| 1186 | + rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1)); |
| 1187 | + |
| 1188 | + SmallVector<OpFoldResult, 3> lowValues; |
| 1189 | + SmallVector<OpFoldResult, 3> highValues; |
| 1190 | + |
| 1191 | + lowValues.reserve(rank); |
| 1192 | + highValues.reserve(rank); |
| 1193 | + |
| 1194 | + for (int i = 0; i < rank; i++) { |
| 1195 | + Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i); |
| 1196 | + Value lowVal = rewriter.createOrFold<tensor::ExtractOp>( |
| 1197 | + loc, padding, ValueRange({inputIndex, lowIndex})); |
| 1198 | + Value highVal = rewriter.createOrFold<tensor::ExtractOp>( |
| 1199 | + loc, padding, ValueRange({inputIndex, highIndex})); |
| 1200 | + |
| 1201 | + lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(), |
| 1202 | + lowVal); |
| 1203 | + highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(), |
| 1204 | + highVal); |
| 1205 | + |
| 1206 | + lowValues.push_back(lowVal); |
| 1207 | + highValues.push_back(highVal); |
| 1208 | + } |
| 1209 | + |
| 1210 | + Attribute constantAttr; |
| 1211 | + if (elementTy.isa<FloatType>()) |
| 1212 | + constantAttr = rewriter.getFloatAttr(elementTy, 0.0); |
| 1213 | + else if (elementTy.isa<IntegerType>() && !padOp.quantization_info()) |
| 1214 | + constantAttr = rewriter.getIntegerAttr(elementTy, 0); |
| 1215 | + else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) { |
| 1216 | + auto value = padOp.quantization_info().getValue().input_zp().getValue(); |
| 1217 | + constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue()); |
| 1218 | + } |
| 1219 | + |
| 1220 | + if (!constantAttr) { |
| 1221 | + return rewriter.notifyMatchFailure( |
| 1222 | + padOp, |
| 1223 | + "tosa.pad to linalg lowering encountered an unknown element type"); |
| 1224 | + } |
| 1225 | + |
| 1226 | + Value constant = rewriter.create<ConstantOp>(loc, constantAttr); |
| 1227 | + |
| 1228 | + auto newPadOp = linalg::PadTensorOp::createPadScalarOp( |
| 1229 | + padOp.getType(), input, constant, lowValues, highValues, loc, rewriter); |
1158 | 1230 |
|
| 1231 | + rewriter.replaceOp(padOp, newPadOp.getResult()); |
1159 | 1232 | return success();
|
1160 | 1233 | }
|
1161 | 1234 | };
|
@@ -1187,7 +1260,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
|
1187 | 1260 | IdentityNConverter<tosa::IdentityOp>,
|
1188 | 1261 | IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
|
1189 | 1262 | ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
|
1190 |
| - ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter, |
1191 |
| - RescaleConverter, ReverseConverter, TileConverter, TransposeConverter, |
1192 |
| - MatMulConverter, FullyConnectedConverter>(patterns->getContext()); |
| 1263 | + ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, PadConverter, |
| 1264 | + ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter, |
| 1265 | + TransposeConverter, MatMulConverter, FullyConnectedConverter>( |
| 1266 | + patterns->getContext()); |
1193 | 1267 | }
|
0 commit comments