Skip to content

Commit

Permalink
Canonicalization for add to no-op if one of the inputs is zero
Browse files Browse the repository at this point in the history
Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D113207
  • Loading branch information
not-jenni authored and rsuderman committed Nov 4, 2021
1 parent 795ff77 commit 07a029c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Expand Up @@ -442,6 +442,8 @@ def Tosa_AddOp : Tosa_Op<"add", [
let results = (outs
Tosa_Tensor:$output
);

let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
49 changes: 49 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Expand Up @@ -289,6 +289,55 @@ void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<NoOpOptimization>(context);
}

struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::AddOp op,
PatternRewriter &rewriter) const override {
auto input1 = op.input1();
auto input2 = op.input2();

DenseElementsAttr input1Attr;
if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
input2.getType() == op.getType()) {
if (input1Attr.getType().getElementType().isa<FloatType>() &&
input1Attr.getSplatValue<APFloat>().isZero()) {
rewriter.replaceOp(op, op.input2());
return success();
}

if (input1Attr.getType().getElementType().isa<IntegerType>() &&
input1Attr.getSplatValue<APInt>().isZero()) {
rewriter.replaceOp(op, op.input2());
return success();
}
}

DenseElementsAttr input2Attr;
if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
input1.getType() == op.getType()) {
if (input2Attr.getType().getElementType().isa<FloatType>() &&
input2Attr.getSplatValue<APFloat>().isZero()) {
rewriter.replaceOp(op, op.input1());
return success();
}

if (input2Attr.getType().getElementType().isa<IntegerType>() &&
input2Attr.getSplatValue<APInt>().isZero()) {
rewriter.replaceOp(op, op.input1());
return success();
}
}

return failure();
}
};

void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<AddZeroOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
Expand Down
32 changes: 32 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Expand Up @@ -9,6 +9,38 @@ func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {

// -----

// CHECK-LABEL: @add_zero_different_shape
func @add_zero_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
// CHECK: tosa.add
%zeros = "tosa.const"() {value = dense<0.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
%1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
return %1 : tensor<4x2x3xf32>
}

// -----

// CHECK-LABEL: @add_zero_float
func @add_zero_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.add
%zeros = "tosa.const"() {value = dense<0.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}

// -----

// CHECK-LABEL: @add_zero_int
func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.add
%zeros = "tosa.const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
%1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
}

// -----

// CHECK-LABEL: @cast_fold
func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
Expand Down

0 comments on commit 07a029c

Please sign in to comment.