diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index a5b42990625af..711c101b15a52 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -34,6 +34,17 @@ namespace tosa { } // namespace tosa } // namespace mlir +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// +namespace mlir { +namespace tosa { +/// Appends the canonicalization patterns for all the TOSA ops to the `patterns` +void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx, + RewritePatternSet &patterns); +} // namespace tosa +} // namespace mlir + #define GET_OP_CLASSES #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 1bdfc2f43bf3b..9ffccfc948824 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -26,7 +26,10 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns); +void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, + RewritePatternSet &patterns); +std::unique_ptr createTosaLayerwiseConstantFoldPass(); std::unique_ptr createTosaInferShapesPass(); std::unique_ptr createTosaMakeBroadcastablePass(); std::unique_ptr createTosaTestQuantUtilAPIPass(); diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index c3180ec14a325..46bd7a4780e00 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -15,6 +15,15 @@ include "mlir/Pass/PassBase.td" +def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> { + let summary = "Fold layerwise operations on constant tensors"; + let description = [{ + Pass that enables folding of full-layer operations on constant tensors. + }]; + + let constructor = "createTosaLayerwiseConstantFoldPass()"; +} + def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> { let summary = "Propagate shapes across TOSA operations"; let description = [{ diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index a8c610c05a7bc..18f7efe36f503 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -76,6 +76,8 @@ void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm, pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); pm.addNestedPass(tosa::createTosaToLinalgNamed()); pm.addNestedPass(createCanonicalizerPass()); + // TODO: Remove pass that operates on const tensor and enable optionality + pm.addNestedPass(tosa::createTosaLayerwiseConstantFoldPass()); pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); pm.addNestedPass(tosa::createTosaToLinalg()); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 1cf2a8808a07f..4de0c0f1a9ed8 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -94,6 +94,20 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, // Operator Canonicalizers. //===----------------------------------------------------------------------===// +template +void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) { + (void)std::initializer_list{ + 0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...}; +} + +void mlir::tosa::populateTosaOpsCanonicalizationPatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + addOpsCanonicalizations< +#define GET_OP_LIST +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" + >(ctx, patterns); +} + struct ConcatOptimization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -189,70 +203,6 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) { return success(); } -struct ConstantTransposeOptimization - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::TransposeOp op, - PatternRewriter &rewriter) const override { - auto outputType = op.getType().cast(); - ArrayRef outputShape = outputType.getShape(); - // TOSA supports quantized types. - if (!outputType.getElementType().isIntOrIndexOrFloat()) - return failure(); - - DenseElementsAttr inputValues; - if (!matchPattern(op.input1(), m_Constant(&inputValues))) - return failure(); - // Make sure the input is a constant that has a single user. - if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers())) - return failure(); - - DenseIntElementsAttr permAttr; - if (!matchPattern(op.perms(), m_Constant(&permAttr))) - return failure(); - auto permValues = llvm::to_vector<6>(llvm::map_range( - // TOSA allows both 32- and 64-bit integer tensors here. - permAttr.getValues(), - [](const APInt &val) { return val.getZExtValue(); })); - - auto inputType = op.input1().getType().cast(); - ArrayRef inputShape = inputType.getShape(); - int64_t numElements = inputType.getNumElements(); - - SmallVector outputValues; - outputValues.resize(numElements); - - // Transpose the input constant. Because we don't know its rank in advance, - // we need to loop over the range [0, element count) and delinearize the - // index. - auto attrValues = inputValues.getValues(); - for (int srcLinearIndex = 0; srcLinearIndex < numElements; - ++srcLinearIndex) { - SmallVector srcIndices(inputType.getRank(), 0); - int totalCount = srcLinearIndex; - for (int dim = inputType.getRank() - 1; dim >= 0; --dim) { - srcIndices[dim] = totalCount % inputShape[dim]; - totalCount /= inputShape[dim]; - } - - SmallVector dstIndices(outputType.getRank(), 0); - for (int dim = outputType.getRank() - 1; dim >= 0; --dim) - dstIndices[dim] = srcIndices[permValues[dim]]; - - uint64_t dstLinearIndex = dstIndices.front(); - for (int dim = 1; dim < outputType.getRank(); ++dim) - dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; - - outputValues[dstLinearIndex] = attrValues[srcIndices]; - } - - rewriter.replaceOpWithNewOp( - op, outputType, DenseElementsAttr::get(outputType, outputValues)); - return success(); - } -}; - struct NoOpOptimization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -282,7 +232,6 @@ struct NoOpOptimization : public OpRewritePattern { void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); results.add(context); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index e98d3dfe26a70..79979eee9077d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -2,7 +2,9 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaDecomposeTransposeConv.cpp TosaDecomposeConv2D.cpp TosaDecomposeDepthwise.cpp + TosaFoldConstantTranspose.cpp TosaInferShapes.cpp + TosaLayerwiseConstantFoldPass.cpp TosaMakeBroadcastable.cpp TosaOptionalDecompositions.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp index ac8583f7c03e2..ef94e55c855d3 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp @@ -1,4 +1,4 @@ -//===- TosaDecomposeConv2D.cpp ------------------------------------------===// +//===- TosaDecomposeConv2D.cpp --------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 2ce9f24e6d9c9..b4bac42029e49 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -1,5 +1,4 @@ -//===- TosaDecomposeDepthwise.cpp -//------------------------------------------===// +//===- TosaDecomposeDepthwise.cpp -----------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index d6ffa463f31bd..fa6ec91bb2416 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -1,5 +1,4 @@ -//===- TosaDecomposeTransposeConv.cpp -//------------------------------------------===// +//===- TosaDecomposeTransposeConv.cpp -------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp new file mode 100644 index 0000000000000..5f14cf68321af --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp @@ -0,0 +1,91 @@ +//===- TosaFoldConstantTranspose.cpp --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Fold TOSA Transpose operation on constant data +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaFoldConstantTranspose : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::TransposeOp op, + PatternRewriter &rewriter) const override { + auto outputType = op.getType().cast(); + // TOSA supports quantized types. + if (!outputType.getElementType().isIntOrIndexOrFloat()) + return failure(); + + DenseElementsAttr inputValues; + if (!matchPattern(op.input1(), m_Constant(&inputValues))) + return failure(); + // Make sure the input is a constant that has a single user. + if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers())) + return failure(); + + DenseIntElementsAttr permAttr; + if (!matchPattern(op.perms(), m_Constant(&permAttr))) + return failure(); + auto permValues = llvm::to_vector<6>(llvm::map_range( + // TOSA allows both 32- and 64-bit integer tensors here. + permAttr.getValues(), + [](const APInt &val) { return val.getZExtValue(); })); + + auto inputType = op.input1().getType().cast(); + ArrayRef inputShape = inputType.getShape(); + int64_t numElements = inputType.getNumElements(); + + SmallVector outputValues; + outputValues.resize(numElements); + + // Transpose the input constant. Because we don't know its rank in advance, + // we need to loop over the range [0, element count) and delinearize the + // index. + auto attrValues = inputValues.getValues(); + ArrayRef outputShape = outputType.getShape(); + for (int srcLinearIndex = 0; srcLinearIndex < numElements; + ++srcLinearIndex) { + SmallVector srcIndices(inputType.getRank(), 0); + int totalCount = srcLinearIndex; + for (int dim = inputType.getRank() - 1; dim >= 0; --dim) { + srcIndices[dim] = totalCount % inputShape[dim]; + totalCount /= inputShape[dim]; + } + + SmallVector dstIndices(outputType.getRank(), 0); + for (int dim = outputType.getRank() - 1; dim >= 0; --dim) + dstIndices[dim] = srcIndices[permValues[dim]]; + + uint64_t dstLinearIndex = dstIndices.front(); + for (int dim = 1; dim < outputType.getRank(); ++dim) + dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; + + outputValues[dstLinearIndex] = attrValues[srcIndices]; + } + + rewriter.replaceOpWithNewOp( + op, outputType, DenseElementsAttr::get(outputType, outputValues)); + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaFoldConstantTransposePatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + patterns.add(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp index fc55e44a7d373..e75399b7bdd24 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp @@ -1,4 +1,4 @@ -//===- TosaInferShapes.cpp ------------------------------------------===// +//===- TosaInferShapes.cpp ------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp new file mode 100644 index 0000000000000..7cf7ff14eb9ac --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -0,0 +1,43 @@ +//===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements constant folding transformations on TOSA operations +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/PassDetail.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaLayerwiseConstantFoldPass + : public TosaLayerwiseConstantFoldPassBase { + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + auto func = getOperation(); + + mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns); + mlir::tosa::populateTosaOpsCanonicalizationPatterns(ctx, patterns); + + if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::tosa::createTosaLayerwiseConstantFoldPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp index 0bf9eed621107..78b8cb3084afd 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp @@ -1,5 +1,4 @@ -//===- TosaOptionalDecompositions.cpp -//------------------------------------------===// +//===- TosaOptionalDecompositions.cpp -------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 934ca583f330d..62f1adb1e77ac 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -391,104 +391,6 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> { return %0 : tensor<3x8xf32> } -// CHECK-LABEL: @transpose_fold -func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { - // CHECK: return %arg0 - %0 = arith.constant dense<[0, 1]> : tensor<2xi32> - %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32> - return %1 : tensor<3x4xf32> -} - -// CHECK-LABEL: @transpose_nofold -func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { - // CHECK: "tosa.transpose" - %0 = arith.constant dense<[1, 0]> : tensor<2xi32> - %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32> - return %1 : tensor<3x3xf32> -} - -// CHECK-LABEL: @transpose_nofold_shape -func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor { - // CHECK: "tosa.transpose" - %0 = arith.constant dense<[1, 0]> : tensor<2xi32> - %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor - return %1 : tensor -} - -// CHECK-LABEL: @transpose_fold_splat -func.func @transpose_fold_splat() -> tensor<3x2xf32> { - %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: %[[CST:.+]] = "tosa.const"() - // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32> - %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - // CHECK: return %[[CST]] - return %1 : tensor<3x2xf32> -} - -// CHECK-LABEL: @transpose_fold_2d_float -func.func @transpose_fold_2d_float() -> tensor<3x2xf32> { - %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: %[[CST:.+]] = "tosa.const"() - // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32> - %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - // CHECK: return %[[CST]] - return %1 : tensor<3x2xf32> -} - -// CHECK-LABEL: @transpose_fold_4d_int -func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> { - %input = "tosa.const"() {value = dense<[[ - [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], - [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] - ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32> - %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> - // CHECK: %[[CST:.+]] = "tosa.const"() - // CHECK-SAME{LITERAL}: value = dense<[ - // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], - // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], - // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] - // CHECK-SAME{LITERAL}: ]> - %1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32> - // CHECK: return %[[CST]] - return %1 : tensor<3x1x4x2xi32> -} - -// CHECK-LABEL: @transpose_nofold_non_cst_input -func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> { - %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: tosa.transpose - %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - return %1 : tensor<3x2xf32> -} - -// CHECK-LABEL: @transpose_nofold_non_cst_perms -func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> { - %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - // CHECK: tosa.transpose - %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - return %1 : tensor<3x2xf32> -} - -// CHECK-LABEL: @transpose_nofold_multi_users -func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) { - %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: tosa.transpose - %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - return %1, %input : tensor<3x2xf32>, tensor<2x3xf32> -} - -// CHECK-LABEL: @transpose_nofold_quantized_types -func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> { - %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32> - %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8> - // CHECK: tosa.transpose - %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> - return %0: tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> -} - // CHECK-LABEL: @transpose_no_op func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> { // CHECK: return %arg0 diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir new file mode 100644 index 0000000000000..09f8245e771a7 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -0,0 +1,99 @@ +// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s + +// CHECK-LABEL: @transpose_fold +func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { + // CHECK: return %arg0 + %0 = arith.constant dense<[0, 1]> : tensor<2xi32> + %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32> + return %1 : tensor<3x4xf32> +} + +// CHECK-LABEL: @transpose_nofold +func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { + // CHECK: "tosa.transpose" + %0 = arith.constant dense<[1, 0]> : tensor<2xi32> + %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32> + return %1 : tensor<3x3xf32> +} + +// CHECK-LABEL: @transpose_nofold_shape +func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor { + // CHECK: "tosa.transpose" + %0 = arith.constant dense<[1, 0]> : tensor<2xi32> + %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @transpose_fold_splat +func.func @transpose_fold_splat() -> tensor<3x2xf32> { + %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[CST:.+]] = "tosa.const"() + // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32> + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf32> +} + +// CHECK-LABEL: @transpose_fold_2d_float +func.func @transpose_fold_2d_float() -> tensor<3x2xf32> { + %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[CST:.+]] = "tosa.const"() + // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32> + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf32> +} + +// CHECK-LABEL: @transpose_fold_4d_int +func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> { + %input = "tosa.const"() {value = dense<[[ + [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] + ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32> + %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> + // CHECK: %[[CST:.+]] = "tosa.const"() + // CHECK-SAME{LITERAL}: value = dense<[ + // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], + // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], + // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] + // CHECK-SAME{LITERAL}: ]> + %1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32> + // CHECK: return %[[CST]] + return %1 : tensor<3x1x4x2xi32> +} + +// CHECK-LABEL: @transpose_nofold_non_cst_input +func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> { + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: tosa.transpose + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// CHECK-LABEL: @transpose_nofold_non_cst_perms +func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> { + %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + // CHECK: tosa.transpose + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// CHECK-LABEL: @transpose_nofold_multi_users +func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) { + %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: tosa.transpose + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1, %input : tensor<3x2xf32>, tensor<2x3xf32> +} + +// CHECK-LABEL: @transpose_nofold_quantized_types +func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> { + %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32> + %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8> + // CHECK: tosa.transpose + %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> + return %0: tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> +}