diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt index a44621ec603329..7cc1721bb0fb8a 100644 --- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt +++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg MLIRMath MLIRMemRef MLIRPass + MLIRTensor MLIRTosa MLIRTosaTransforms MLIRSupport diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 12e9e694760c22..a4b6f826feb6c0 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -1155,7 +1156,79 @@ struct TileConverter : public OpConversionPattern { rewriter.replaceOpWithNewOp( op, resultTy, genericOp.getResult(0), rewriter.getI64ArrayAttr(resultTy.getShape())); + return success(); + } +}; + +class PadConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::PadOp padOp, + PatternRewriter &rewriter) const final { + auto loc = padOp.getLoc(); + auto input = padOp.input1(); + auto padding = padOp.padding(); + + ShapedType inputTy = input.getType().cast(); + ShapedType paddingTy = padding.getType().cast(); + Type elementTy = inputTy.getElementType(); + int64_t rank = inputTy.getRank(); + + if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) { + return rewriter.notifyMatchFailure( + padOp, + "Pad converter requires static shaped input / padding values."); + } + + Value lowIndex = rewriter.create(loc, rewriter.getIndexAttr(0)); + Value highIndex = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + SmallVector lowValues; + SmallVector highValues; + + lowValues.reserve(rank); + highValues.reserve(rank); + + for (int i = 0; i < rank; i++) { + Value inputIndex = rewriter.createOrFold(loc, i); + Value lowVal = rewriter.createOrFold( + loc, padding, ValueRange({inputIndex, lowIndex})); + Value highVal = rewriter.createOrFold( + loc, padding, ValueRange({inputIndex, highIndex})); + + lowVal = rewriter.createOrFold(loc, rewriter.getIndexType(), + lowVal); + highVal = rewriter.createOrFold(loc, rewriter.getIndexType(), + highVal); + + lowValues.push_back(lowVal); + highValues.push_back(highVal); + } + + Attribute constantAttr; + if (elementTy.isa()) + constantAttr = rewriter.getFloatAttr(elementTy, 0.0); + else if (elementTy.isa() && !padOp.quantization_info()) + constantAttr = rewriter.getIntegerAttr(elementTy, 0); + else if (elementTy.isa() && padOp.quantization_info()) { + auto value = padOp.quantization_info().getValue().input_zp().getValue(); + constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue()); + } + + if (!constantAttr) { + return rewriter.notifyMatchFailure( + padOp, + "tosa.pad to linalg lowering encountered an unknown element type"); + } + + Value constant = rewriter.create(loc, constantAttr); + + auto newPadOp = linalg::PadTensorOp::createPadScalarOp( + padOp.getType(), input, constant, lowValues, highValues, loc, rewriter); + rewriter.replaceOp(padOp, newPadOp.getResult()); return success(); } }; @@ -1187,7 +1260,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns( IdentityNConverter, IdentityNConverter, ReduceConverter, ReduceConverter, ReduceConverter, - ReduceConverter, ConcatConverter, ReshapeConverter, - RescaleConverter, ReverseConverter, TileConverter, TransposeConverter, - MatMulConverter, FullyConnectedConverter>(patterns->getContext()); + ReduceConverter, ConcatConverter, PadConverter, + ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter, + TransposeConverter, MatMulConverter, FullyConnectedConverter>( + patterns->getContext()); } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index 5c0dbc50c2d75e..baf9e575a47368 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/PassDetail.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" @@ -33,14 +34,15 @@ struct TosaToLinalgOnTensors public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + memref::MemRefDialect, StandardOpsDialect, + tensor::TensorDialect>(); } void runOnFunction() override { RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); + StandardOpsDialect, tensor::TensorDialect>(); target.addIllegalDialect(); // Not every TOSA op can be legalized to linalg. diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 018e9e4d7e5401..39a4f4122924ce 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -702,3 +702,46 @@ func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: ten %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>) -> (tensor<5x6xf32>) return %0 : tensor<5x6xf32> } + +// ----- + +func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { + %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // CHECK: [[INDEX0:%.+]] = constant 0 : index + // CHECK: [[INDEX1:%.+]] = constant 1 : index + // CHECK: [[ROW0:%.+]] = constant 0 : index + // CHECK: [[LOW0:%.+]] = tensor.extract %cst{{\[}}[[ROW0]], [[INDEX0]]] + // CHECK: [[HIGH0:%.+]] = tensor.extract %cst{{\[}}[[ROW0]], [[INDEX1]]] + // CHECK: [[LOW0_IDX:%.+]] = index_cast %0 + // CHECK: [[HIGH0_IDX:%.+]] = index_cast %1 + // CHECK: [[ROW1:%.+]] = constant 1 : index + // CHECK: [[LOW1:%.+]] = tensor.extract %cst{{\[}}%c1_1, %c0] + // CHECK: [[HIGH1:%.+]] = tensor.extract %cst{{\[}}%c1_1, %c1] + // CHECK: [[LOW1_IDX:%.+]] = index_cast [[LOW1]] + // CHECK: [[HIGH1_IDX:%.+]] = index_cast [[HIGH1]] + // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32 + // CHECK: %8 = linalg.pad_tensor %arg0 low{{\[}}[[LOW0_IDX]], [[LOW1_IDX]]] high{{\[}}[[HIGH0_IDX]], [[HIGH1_IDX]]] { + // CHECK: ^bb0(%arg1: index, %arg2: index): // no predecessors + // CHECK: linalg.yield [[CST]] + // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32> + %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>) + return %1 : tensor<4x9xf32> +} + +func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { + %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // CHECK: [[CST:%.+]] = constant 0 : i32 + // CHECK: linalg.pad_tensor + // CHECK: linalg.yield [[CST]] + %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>) + return %1 : tensor<4x9xi32> +} + +func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { + %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + // CHECK: [[CST:%.+]] = constant 42 : i32 + // CHECK: linalg.pad_tensor + // CHECK: linalg.yield [[CST]] + %1 = "tosa.pad"(%arg0, %0) { quantization_info = { input_zp = 42 : i32}} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>) + return %1 : tensor<4x9xi32> +}