From dfd070820cbae9d6864a7de20ae81757d199fc61 Mon Sep 17 00:00:00 2001
From: Aaron DeBattista
Date: Tue, 11 Jan 2022 10:16:01 -0800
Subject: [PATCH] [mlir][tosa] Allow optional TOSA decompositions to be populated separately

Moved all TOSA decomposition patterns so that they can be optionally
populated and used by external rewrites. This avoids decomposing TOSA
operations when backends may benefit from the non-decomposed version.

Reviewed By: rsuderman, mehdi_amini

Differential Revision: https://reviews.llvm.org/D116526
---
 .../mlir/Dialect/Tosa/Transforms/Passes.h     |  11 +-
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  24 +-
 .../TosaToLinalg/TosaToLinalgPass.cpp         |   1 +
 .../Dialect/Tosa/Transforms/CMakeLists.txt    |   4 +-
 .../Tosa/Transforms/TosaDecomposeConv2D.cpp   | 115 +++++++++
 .../Transforms/TosaDecomposeDepthwise.cpp     | 121 +++++++++
 .../Transforms/TosaDecomposeTransposeConv.cpp |  32 +--
 .../Tosa/Transforms/TosaOptimization.cpp      | 243 ------------------
 .../Transforms/TosaOptionalDecompositions.cpp |  46 ++++
 ...zation.mlir => tosa-decompose-conv2d.mlir} | 109 +++-----
 .../Tosa/tosa-decompose-depthwise.mlir        |  32 +++
 .../Tosa/tosa-decompose-transpose-conv.mlir   |   2 +-
 12 files changed, 384 insertions(+), 356 deletions(-)
 create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
 create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
 delete mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
 create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
 rename mlir/test/Dialect/Tosa/{operation_optimization.mlir => tosa-decompose-conv2d.mlir} (53%)
 create mode 100644 mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index e94daca0df728..1bdfc2f43bf3b 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -19,11 +19,18 @@
 namespace mlir {
 namespace tosa {
 
-std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
+// Expose rewrite functions that decompose TOSA ops into further TOSA ops.
+// The rewrites can be selectively added to a conversion pass.
+void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
+void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
+                                        RewritePatternSet &patterns);
+void populateTosaDecomposeDepthwise(MLIRContext *ctx,
+                                    RewritePatternSet &patterns);
+
 std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
-std::unique_ptr<Pass> createTosaOptimizationPass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
+std::unique_ptr<Pass> createTosaOptionalDecompositions();
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 4a75482ba832a..fbb3134e00411 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -15,21 +15,6 @@
 
 include "mlir/Pass/PassBase.td"
 
-def TosaDecomposeTransposeConv : FunctionPass<"tosa-decompose-transpose-conv"> {
-  let summary = "Deompose transpose convolutiions into standard convolutions.";
-  let description = [{
-    Pass that uses shape manipulation and convolution operations to transform
-    a transpose convolution into a regular convolution.
-  }];
-
-  let constructor = "createTosaDecomposeTransposeConvPass()";
-  let dependentDialects = [
-    "StandardOpsDialect",
-    "tensor::TensorDialect",
-    "tosa::TosaDialect",
-  ];
-}
-
 def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> {
   let summary = "Propagate shapes across TOSA operations";
   let description = [{
@@ -58,13 +43,14 @@
   let constructor = "createTosaMakeBroadcastablePass()";
 }
 
-def TosaOptimization : FunctionPass<"tosa-optimization"> {
-  let summary = "TOSA operation optimizations";
+def TosaOptionalDecompositions : FunctionPass<"tosa-optional-decompositions"> {
+  let summary = "Applies optional decompositions of TOSA operations";
   let description = [{
-    "Pass to perform optimizations on TOSA operations"
+    Pass to apply the TOSA operation decompositions that are exposed as
+    populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h
   }];
 
-  let constructor = "createTosaOptimizationPass()";
+  let constructor = "tosa::createTosaOptionalDecompositions()";
 }
 
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 3813ba3451374..e75e8d72bc2ea 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -68,6 +68,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 }
 
 void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(mlir::tosa::createTosaOptionalDecompositions());
   pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
   pm.addNestedPass<FuncOp>(createTosaToLinalgNamed());
   pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 016575fc7735b..a24d9cb65cbd2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,8 +1,10 @@
 add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
+  TosaDecomposeConv2D.cpp
+  TosaDecomposeDepthwise.cpp
   TosaInferShapes.cpp
   TosaMakeBroadcastable.cpp
-  TosaOptimization.cpp
+  TosaOptionalDecompositions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
new file mode 100644
index 0000000000000..4c412f987899e
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -0,0 +1,115 @@
+//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Decompose the TOSA Conv2D operation into a series of TOSA ops, specifically:
+// (1) Convert a 1x1 convolution to a Reshape -> FullyConnected -> Reshape.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
+  explicit Conv2DIsFullyConnected(MLIRContext *context)
+      : OpRewritePattern(context) {}
+
+  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.getType().cast<ShapedType>();
+
+    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
+      return failure();
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[1] != 1 || weightShape[2] != 1) {
+      return failure();
+    }
+
+    // Reshape input to [N, IH, IW, IC] -> [N * IH * IW, IC].
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    llvm::SmallVector<int64_t, 2> revisedInputShape{
+        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
+    auto revisedInputShapeType = RankedTensorType::get(
+        revisedInputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [OC, KH, KW, IC] -> [OC, IC].
+    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
+                                                     weightShape[3]};
+    auto revisedWeightShapeType = RankedTensorType::get(
+        revisedWeightShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform a fully connected network over the reshaped input and weight.
+    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
+        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
+    auto fullyConnectedShapeType = RankedTensorType::get(
+        fullyConnectedShape,
+        resultType.dyn_cast<ShapedType>().getElementType());
+
+    Value fullyConnectedValue;
+    if (op.quantization_info()) {
+      fullyConnectedValue =
+          rewriter
+              .create<tosa::FullyConnectedOp>(
+                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
+                  reshapedWeight, op.bias(), op.quantization_info().getValue())
+              .getResult();
+    } else {
+      fullyConnectedValue = rewriter
+                                .create<tosa::FullyConnectedOp>(
+                                    op.getLoc(), fullyConnectedShapeType,
+                                    reshapedInput, reshapedWeight, op.bias())
+                                .getResult();
+    }
+
+    // Reshape output to [N, IH, IW, OC].
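+    // The fully_connected result has shape [N * IH * IW, OC]; the reshape
+    // below splits the flattened rows back into the [N, IH, IW] spatial
+    // dimensions, with OC taken from weightShape[0].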
+    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
+                                              inputShape[2], weightShape[0]};
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, resultType, fullyConnectedValue,
+        rewriter.getI64ArrayAttr(outputShape));
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
+                                             RewritePatternSet &patterns) {
+  patterns.insert<Conv2DIsFullyConnected>(ctx);
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
new file mode 100644
index 0000000000000..685f97353d746
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -0,0 +1,121 @@
+//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Decompose the TOSA DepthwiseConv2D operation into a series of TOSA ops,
+// specifically:
+// (1) Convert a 1x1 depthwise convolution to Reshape -> Mul -> Reshape -> Add.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
+  explicit DepthwiseConv2DIsMul(MLIRContext *context)
+      : OpRewritePattern(context) {}
+
+  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.output().getType().cast<ShapedType>();
+    Type inputEType = inputType.getElementType();
+
+    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
+          resultType.hasStaticShape())) {
+      return failure();
+    }
+
+    // This decomposition only supports non-quantized, floating-point values.
+    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
+      return failure();
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[0] != 1 || weightShape[1] != 1) {
+      return failure();
+    }
+
+    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    llvm::SmallVector<int64_t, 5> revisedInputShape{
+        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
+    auto revisedInputShapeType = RankedTensorType::get(
+        revisedInputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
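+    // With the input expanded to [N, H, W, C, 1] and the weights reshaped to
+    // [1, 1, 1, C, M], the elementwise mul below broadcasts the M multiplier
+    // weights over every input channel, which is exactly what a 1x1
+    // depthwise convolution with stride 1 computes.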
+    llvm::SmallVector<int64_t, 5> revisedWeightShape{1, 1, 1, weightShape[2],
+                                                     weightShape[3]};
+    auto revisedWeightShapeType = RankedTensorType::get(
+        revisedWeightShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform an elementwise mul over the reshaped input and weight.
+    llvm::SmallVector<int64_t, 5> mulShape{inputShape[0], inputShape[1],
+                                           inputShape[2], inputShape[3],
+                                           weightShape[3]};
+    auto mulShapeType = RankedTensorType::get(
+        mulShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    Value mulValue =
+        rewriter
+            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
+                                 reshapedWeight, /*shift=*/0)
+            .getResult();
+
+    // Reshape output to [N, H, W, C * M].
+    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
+    auto outputShapeType = RankedTensorType::get(
+        outputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto outputValue =
+        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
+                                         rewriter.getI64ArrayAttr(outputShape));
+
+    // Add in the bias.
+    rewriter
+        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
+                                         op.bias())
+        .getResult();
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
+                                                RewritePatternSet &patterns) {
+  patterns.insert<DepthwiseConv2DIsMul>(ctx);
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 341e78d527925..330add9e248ea 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -7,17 +7,19 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Insert reshape to binary op's input if needed to match rank
+// Decompose the TOSA TransposeConv operation into a series of TOSA ops,
+// specifically:
+// (1) Convert a dilated TransposeConv2D to Conv2D, including reversing and
+//     reshaping the weights.
+// (2) Convert a strided TransposeConv2D to Conv2D, including transposing,
+//     reversing, and reshaping the weights and the input/output tensors.
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR//TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -369,22 +371,10 @@
   }
 };
 
-/// Pass that enables broadcast by making all input arrays have the same
-/// number of dimensions. Insert RESHAPE operations to lower rank operand
-struct TosaDecomposeTransposeConv
-    : public TosaDecomposeTransposeConvBase<TosaDecomposeTransposeConv> {
-public:
-  void runOnFunction() override {
-    auto func = getFunction();
-    RewritePatternSet patterns(func.getContext());
-    patterns
-        .insert<TransposeConvDilatedConverter, TransposeConvStridedConverter>(
-            func.getContext());
-    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
-  }
-};
 } // namespace
 
-std::unique_ptr<Pass> mlir::tosa::createTosaDecomposeTransposeConvPass() {
-  return std::make_unique<TosaDecomposeTransposeConv>();
+void mlir::tosa::populateTosaDecomposeTransposeConv(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  patterns.insert<TransposeConvDilatedConverter>(ctx);
+  patterns.insert<TransposeConvStridedConverter>(ctx);
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
deleted file mode 100644
index 9a19b63ed1983..0000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
+++ /dev/null
@@ -1,243 +0,0 @@
-//===- TosaOptimization.cpp ------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Pass to perform optimizations on TOSA operations
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/DataFlowAnalysis.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-#define PASS_NAME "tosa-optimization"
-#define DEBUG_TYPE PASS_NAME
-
-namespace {
-
-struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
-  explicit Conv2DIsFullyConnected(MLIRContext *context)
-      : OpRewritePattern(context) {}
-
-  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType weightType = weight.getType().cast<ShapedType>();
-    ShapedType resultType = op.getType().cast<ShapedType>();
-
-    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
-      return failure();
-    }
-
-    // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
-        return failure();
-      }
-    }
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1) {
-      return failure();
-    }
-
-    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
-    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
-
-    // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
-        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
-    auto fullyConnectedShapeType = RankedTensorType::get(
-        fullyConnectedShape,
-        resultType.dyn_cast<ShapedType>().getElementType());
-
-    Value fullyConnectedValue;
-    if (op.quantization_info()) {
-      fullyConnectedValue =
-          rewriter
-              .create<tosa::FullyConnectedOp>(
-                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.bias(), op.quantization_info().getValue())
-              .getResult();
-    } else {
-      fullyConnectedValue = rewriter
-                                .create<tosa::FullyConnectedOp>(
-                                    op.getLoc(), fullyConnectedShapeType,
-                                    reshapedInput, reshapedWeight, op.bias())
-                                .getResult();
-    }
-
-    // Reshape output to [N, IH, IW, OC].
-    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
-                                              inputShape[2], weightShape[0]};
-    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultType, fullyConnectedValue,
-        rewriter.getI64ArrayAttr(outputShape));
-    return success();
-  }
-};
-
-struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
-  explicit DepthwiseConv2DIsMul(MLIRContext *context)
-      : OpRewritePattern(context) {}
-
-  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType weightType = weight.getType().cast<ShapedType>();
-    ShapedType resultType = op.output().getType().cast<ShapedType>();
-    Type inputEType = inputType.getElementType();
-
-    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
-          resultType.hasStaticShape())) {
-      return failure();
-    }
-
-    // Quantization information needs to still be performed.
-    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
-      return failure();
-    }
-
-    // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
-        return failure();
-      }
-    }
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[0] != 1 || weightShape[1] != 1) {
-      return failure();
-    }
-
-    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 5> revisedInputShape{
-        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
-    llvm::SmallVector<int64_t, 5> revisedWeightShape{1, 1, 1, weightShape[2],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
-
-    // Perform an elementwise mul over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 5> mulShape{inputShape[0], inputShape[1],
-                                           inputShape[2], inputShape[3],
-                                           weightShape[3]};
-    auto mulShapeType = RankedTensorType::get(
-        mulShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    Value mulValue =
-        rewriter
-            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
-                                 reshapedWeight, /*shift=*/0)
-            .getResult();
-
-    // Reshape output to [N, H, W, C * M].
-    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
-    auto outputShapeType = RankedTensorType::get(
-        outputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto outputValue =
-        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
-                                         rewriter.getI64ArrayAttr(outputShape));
-
-    // Add in the bias.
-    rewriter
-        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
-                                         op.bias())
-        .getResult();
-    return success();
-  }
-};
-
-class TosaOptimization : public PassWrapper<TosaOptimization, FunctionPass> {
-public:
-  explicit TosaOptimization() = default;
-  void runOnFunction() override;
-
-  StringRef getArgument() const final { return PASS_NAME; }
-  StringRef getDescription() const final {
-    return "Applies TOSA Operation Optimizations";
-  }
-};
-
-void TosaOptimization::runOnFunction() {
-  OwningRewritePatternList patterns(&getContext());
-
-  patterns.insert<Conv2DIsFullyConnected>(&getContext());
-  patterns.insert<DepthwiseConv2DIsMul>(&getContext());
-
-  auto func = getFunction();
-  if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) {
-    signalPassFailure();
-  }
-}
-
-} // namespace
-
-std::unique_ptr<Pass> mlir::tosa::createTosaOptimizationPass() {
-  return std::make_unique<TosaOptimization>();
-}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
new file mode 100644
index 0000000000000..50fd635c8a46a
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
@@ -0,0 +1,46 @@
+//===- TosaOptionalDecompositions.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pass to apply the TOSA operation decompositions exposed as populate
+// functions in include/mlir/Dialect/Tosa/Transforms/Passes.h
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TosaOptionalDecompositions
+    : public TosaOptionalDecompositionsBase<TosaOptionalDecompositions> {
+  void runOnFunction() {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    auto func = getFunction();
+
+    mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns);
+    mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
+    mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns);
+
+    if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaOptionalDecompositions() {
+  return std::make_unique<TosaOptionalDecompositions>();
+}
diff --git a/mlir/test/Dialect/Tosa/operation_optimization.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
similarity index 53%
rename from mlir/test/Dialect/Tosa/operation_optimization.mlir
rename to mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index aa65b96bad4ef..cd9864f0c04f7 100644
--- a/mlir/test/Dialect/Tosa/operation_optimization.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -1,69 +1,40 @@
-// RUN: mlir-opt --split-input-file --tosa-optimization %s | FileCheck %s
-
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected
-func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
-  // CHECK-NOT: "tosa.conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
-  // CHECK-SAME: -> tensor<400x2xf32>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
-  // CHECK-SAME: -> tensor<3x2xf32>
-  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
-  // CHECK-SAME: -> tensor<400x3xf32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
-  // CHECK-SAME: -> tensor<4x10x10x3xf32>
-  // CHECK: return %[[VAR3]]
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
-  return %0 : tensor<4x10x10x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected_quant
-func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
-  // CHECK-NOT: "tosa.conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
-  // CHECK-SAME: -> tensor<400x2xi8>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
-  // CHECK-SAME: -> tensor<3x2xi8>
-  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
-  // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
-  // CHECK-SAME: -> tensor<400x3xi32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
-  // CHECK-SAME: -> tensor<4x10x10x3xi32>
-  // CHECK: return %[[VAR3]]
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
-  return %0 : tensor<4x10x10x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv2d_as_mul
-func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
-  // CHECK-NOT: "tosa.depthwise_conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
-  // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
-  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
-  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
-  // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
-  // CHECK-SAME: -> tensor<4x10x10x6xf32>
-  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
-  // CHECK-SAME: -> tensor<4x10x10x6xf32>
-  // CHECK: return %[[VAR4]]
-  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
-  return %0 : tensor<4x10x10x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv2d_as_mul_q
-func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
-  // CHECK: "tosa.depthwise_conv2d"
-  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
-  return %0 : tensor<4x10x10x6xi32>
-}
-
-// -----
+// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @conv2d_as_fully_connected
+func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
+  // CHECK-NOT: "tosa.conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
+  // CHECK-SAME: -> tensor<400x2xf32>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+  // CHECK-SAME: -> tensor<3x2xf32>
+  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
+  // CHECK-SAME: -> tensor<400x3xf32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
+  // CHECK-SAME: -> tensor<4x10x10x3xf32>
+  // CHECK: return %[[VAR3]]
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
+  return %0 : tensor<4x10x10x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_as_fully_connected_quant
+func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
+  // CHECK-NOT: "tosa.conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
+  // CHECK-SAME: -> tensor<400x2xi8>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+  // CHECK-SAME: -> tensor<3x2xi8>
+  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
+  // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
+  // CHECK-SAME: -> tensor<400x3xi32>
%[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} + // CHECK-SAME: -> tensor<4x10x10x3xi32> + // CHECK: return %[[VAR3]] + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> + return %0 : tensor<4x10x10x3xi32> +} + +// ----- diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir new file mode 100644 index 0000000000000..e6370d7a8314b --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s + +// ----- + +// CHECK-LABEL: @depthwise_conv2d_as_mul +func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { + // CHECK-NOT: "tosa.depthwise_conv2d" + // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]} + // CHECK-SAME: -> tensor<4x10x10x2x1xf32> + // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]} + // CHECK-SAME: -> tensor<1x1x1x2x3xf32> + // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]]) + // CHECK-SAME: -> tensor<4x10x10x2x3xf32> + // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]} + // CHECK-SAME: -> tensor<4x10x10x6xf32> + // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2) + // CHECK-SAME: -> tensor<4x10x10x6xf32> + // CHECK: return %[[VAR4]] + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> + return %0 : tensor<4x10x10x6xf32> +} + +// ----- + +// CHECK-LABEL: @depthwise_conv2d_as_mul_q +func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> { + // CHECK: "tosa.depthwise_conv2d" + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> + return %0 : tensor<4x10x10x6xi32> +} + +// ----- diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir index 627622ba796e3..d0e9e5e17e84f 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --split-input-file --tosa-decompose-transpose-conv %s | FileCheck %s +// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s // CHECK-LABEL: @transpose_conv2d func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {