[mlir][tosa] Allow optional TOSA decompositions to be populated separately

Moved all TOSA decomposition patterns so that they can be optionally populated
and used by external rewrites. This avoids decomposing TOSA operations when
backends may benefit from the non-decomposed version.

Reviewed By: rsuderman, mehdi_amini

Differential Revision: https://reviews.llvm.org/D116526
aardeb authored and rsuderman committed Jan 11, 2022
1 parent 0a8d15a commit dfd0708
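
With the patterns exposed as populate functions, a backend can opt into only the decompositions it wants. Below is a minimal sketch of such a consumer pass, assuming the FunctionPass/PassWrapper API current at the time of this commit; the pass name MyBackendPrelowering is hypothetical:

#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// Hypothetical backend pass: decompose transpose convolutions only, keeping
// tosa.conv2d / tosa.depthwise_conv2d intact for direct lowering.
struct MyBackendPrelowering
    : public mlir::PassWrapper<MyBackendPrelowering, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
    (void)mlir::applyPatternsAndFoldGreedily(getFunction(),
                                             std::move(patterns));
  }
};
} // namespace

The same populate calls can equally be appended to an existing conversion's RewritePatternSet.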
Showing 12 changed files with 384 additions and 356 deletions.
11 changes: 9 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -19,11 +19,18 @@
 namespace mlir {
 namespace tosa {

-std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
+// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
+// The rewrites can be selectively added to a conversion pass.
+void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
+void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
+                                        RewritePatternSet &patterns);
+void populateTosaDecomposeDepthwise(MLIRContext *ctx,
+                                    RewritePatternSet &patterns);

 std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
-std::unique_ptr<Pass> createTosaOptimizationPass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
+std::unique_ptr<Pass> createTosaOptionalDecompositions();

 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
24 changes: 5 additions & 19 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -15,21 +15,6 @@

 include "mlir/Pass/PassBase.td"

-def TosaDecomposeTransposeConv : FunctionPass<"tosa-decompose-transpose-conv"> {
-  let summary = "Decompose transpose convolutions into standard convolutions.";
-  let description = [{
-    Pass that uses shape manipulation and convolution operations to transform
-    a transpose convolution into a regular convolution.
-  }];
-
-  let constructor = "createTosaDecomposeTransposeConvPass()";
-  let dependentDialects = [
-    "StandardOpsDialect",
-    "tensor::TensorDialect",
-    "tosa::TosaDialect",
-  ];
-}
-
 def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> {
   let summary = "Propagate shapes across TOSA operations";
   let description = [{
@@ -58,13 +43,14 @@ def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
   let constructor = "createTosaMakeBroadcastablePass()";
 }

-def TosaOptimization : FunctionPass<"tosa-optimization"> {
-  let summary = "TOSA operation optimizations";
+def TosaOptionalDecompositions : FunctionPass<"tosa-optional-decompositions"> {
+  let summary = "Applies Tosa operations optional decompositions";
   let description = [{
-    "Pass to perform optimizations on TOSA operations"
+    Pass to apply the Tosa operations decompositions
+    exposed as populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h
   }];

-  let constructor = "createTosaOptimizationPass()";
+  let constructor = "tosa::createTosaOptionalDecompositions()";
 }

 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -68,6 +68,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 }

 void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(mlir::tosa::createTosaOptionalDecompositions());
   pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
   pm.addNestedPass<FuncOp>(createTosaToLinalgNamed());
   pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
4 changes: 3 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,8 +1,10 @@
 add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
+  TosaDecomposeConv2D.cpp
+  TosaDecomposeDepthwise.cpp
   TosaInferShapes.cpp
   TosaMakeBroadcastable.cpp
-  TosaOptimization.cpp
+  TosaOptionalDecompositions.cpp

   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
115 changes: 115 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -0,0 +1,115 @@
//===- TosaDecomposeConv2D.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
// (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
  explicit Conv2DIsFullyConnected(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.input();
    Value weight = op.weight();
    ShapedType inputType = input.getType().cast<ShapedType>();
    ShapedType weightType = weight.getType().cast<ShapedType>();
    ShapedType resultType = op.getType().cast<ShapedType>();

    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
      return failure();
    }

    // Stride must be 1 for this optimization.
    for (Attribute stride : op.stride().getValue()) {
      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
        return failure();
      }
    }

    // Only works for a 1x1 kernel.
    ArrayRef<int64_t> weightShape = weightType.getShape();
    if (weightShape[1] != 1 || weightShape[2] != 1) {
      return failure();
    }

    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
    ArrayRef<int64_t> inputShape = inputType.getShape();
    llvm::SmallVector<int64_t, 2> revisedInputShape{
        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
    auto revisedInputShapeType = RankedTensorType::get(
        revisedInputShape,
        input.getType().dyn_cast<RankedTensorType>().getElementType());
    auto reshapedInput = rewriter
                             .create<tosa::ReshapeOp>(
                                 op.getLoc(), revisedInputShapeType, input,
                                 rewriter.getI64ArrayAttr(revisedInputShape))
                             .getResult();

    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
                                                     weightShape[3]};
    auto revisedWeightShapeType = RankedTensorType::get(
        revisedWeightShape,
        weight.getType().dyn_cast<RankedTensorType>().getElementType());
    auto reshapedWeight = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedWeightShapeType, weight,
                                  rewriter.getI64ArrayAttr(revisedWeightShape))
                              .getResult();

    // Perform a fully connected network over the reshaped input and weight.
    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
    auto fullyConnectedShapeType = RankedTensorType::get(
        fullyConnectedShape,
        resultType.dyn_cast<ShapedType>().getElementType());

    Value fullyConnectedValue;
    if (op.quantization_info()) {
      fullyConnectedValue =
          rewriter
              .create<tosa::FullyConnectedOp>(
                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
                  reshapedWeight, op.bias(), op.quantization_info().getValue())
              .getResult();
    } else {
      fullyConnectedValue = rewriter
                                .create<tosa::FullyConnectedOp>(
                                    op.getLoc(), fullyConnectedShapeType,
                                    reshapedInput, reshapedWeight, op.bias())
                                .getResult();
    }

    // Reshape output to [N, IH, IW, OC].
    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
                                              inputShape[2], weightShape[0]};
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultType, fullyConnectedValue,
        rewriter.getI64ArrayAttr(outputShape));
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
                                             RewritePatternSet &patterns) {
  patterns.insert<Conv2DIsFullyConnected>(ctx);
}
121 changes: 121 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -0,0 +1,121 @@
//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
// (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
  explicit DepthwiseConv2DIsMul(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.input();
    Value weight = op.weight();
    ShapedType inputType = input.getType().cast<ShapedType>();
    ShapedType weightType = weight.getType().cast<ShapedType>();
    ShapedType resultType = op.output().getType().cast<ShapedType>();
    Type inputEType = inputType.getElementType();

    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
          resultType.hasStaticShape())) {
      return failure();
    }

    // Quantization information needs to still be performed.
    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
      return failure();
    }

    // Stride must be 1 for this optimization.
    for (Attribute stride : op.stride().getValue()) {
      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
        return failure();
      }
    }

    // Only works for a 1x1 kernel.
    ArrayRef<int64_t> weightShape = weightType.getShape();
    if (weightShape[0] != 1 || weightShape[1] != 1) {
      return failure();
    }

    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
    ArrayRef<int64_t> inputShape = inputType.getShape();
    llvm::SmallVector<int64_t, 2> revisedInputShape{
        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
    auto revisedInputShapeType = RankedTensorType::get(
        revisedInputShape,
        input.getType().dyn_cast<RankedTensorType>().getElementType());
    auto reshapedInput = rewriter
                             .create<tosa::ReshapeOp>(
                                 op.getLoc(), revisedInputShapeType, input,
                                 rewriter.getI64ArrayAttr(revisedInputShape))
                             .getResult();

    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
    llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
                                                     weightShape[3]};
    auto revisedWeightShapeType = RankedTensorType::get(
        revisedWeightShape,
        weight.getType().dyn_cast<RankedTensorType>().getElementType());
    auto reshapedWeight = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedWeightShapeType, weight,
                                  rewriter.getI64ArrayAttr(revisedWeightShape))
                              .getResult();

    // Perform an elementwise mul over the reshaped input and weight.
    llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
                                           inputShape[2], inputShape[3],
                                           weightShape[3]};
    auto mulShapeType = RankedTensorType::get(
        mulShape,
        weight.getType().dyn_cast<RankedTensorType>().getElementType());
    Value mulValue =
        rewriter
            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
                                 reshapedWeight, /*shift=*/0)
            .getResult();

    // Reshape output to [N, H, W, C * M].
    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
    auto outputShapeType = RankedTensorType::get(
        outputShape,
        input.getType().dyn_cast<RankedTensorType>().getElementType());
    auto outputValue =
        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
                                         rewriter.getI64ArrayAttr(outputShape));

    // Add in the bias.
    rewriter
        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
                                         op.bias())
        .getResult();
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                                RewritePatternSet &patterns) {
  patterns.insert<DepthwiseConv2DIsMul>(ctx);
}
32 changes: 11 additions & 21 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -7,17 +7,19 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Insert reshape to binary op's input if needed to match rank
+// Decompose TOSA TransposeConv operation to a series of TOSA Ops specifically
+// (1) Convert a Dilated TransposeConv2D to Conv2D including reversing/reshaping
+// etc.. of the weights (2) Convert a Strided TransposeConv2D to Conv2D
+// including transposing/reversing/reshaping etc..
+// of the weights and input/output tensors and reversing/reshaping etc.. of
+// the weights
 //
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR//TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

 using namespace mlir;
 using namespace mlir::tosa;
@@ -369,22 +371,10 @@ class TransposeConvStridedConverter
   }
 };

-/// Pass that enables broadcast by making all input arrays have the same
-/// number of dimensions. Insert RESHAPE operations to lower rank operand
-struct TosaDecomposeTransposeConv
-    : public TosaDecomposeTransposeConvBase<TosaDecomposeTransposeConv> {
-public:
-  void runOnFunction() override {
-    auto func = getFunction();
-    RewritePatternSet patterns(func.getContext());
-    patterns
-        .insert<TransposeConvDilatedConverter, TransposeConvStridedConverter>(
-            func.getContext());
-    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
-  }
-};
 } // namespace

-std::unique_ptr<Pass> mlir::tosa::createTosaDecomposeTransposeConvPass() {
-  return std::make_unique<TosaDecomposeTransposeConv>();
+void mlir::tosa::populateTosaDecomposeTransposeConv(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  patterns.insert<TransposeConvDilatedConverter>(ctx);
+  patterns.insert<TransposeConvStridedConverter>(ctx);
 }
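
For pipelines that previously scheduled the now-removed tosa-decompose-transpose-conv pass, the closest replacement is the new optional-decompositions pass, which applies all three decompositions. A minimal sketch, where buildMyTosaPipeline is a hypothetical registration hook:

#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical pipeline hook, not part of this patch.
void buildMyTosaPipeline(mlir::OpPassManager &pm) {
  // Before: pm.addNestedPass<FuncOp>(createTosaDecomposeTransposeConvPass());
  // After this change, schedule the combined optional decompositions:
  pm.addNestedPass<mlir::FuncOp>(
      mlir::tosa::createTosaOptionalDecompositions());
}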
