Skip to content

Commit

Permalink
[TOSA] Loosen folding restrictions for tosa.add, tosa.sub, tosa.mul
Browse files Browse the repository at this point in the history
Allow folding of different tensor types when the constant tensor is broadcast.
Removed the redundant and incorrect AddZero and MulOne canonicalization patterns.

Reviewed By: rsuderman, eric-k256

Differential Revision: https://reviews.llvm.org/D145738
  • Loading branch information
sjw36 authored and rsuderman committed Mar 30, 2023
1 parent 280ece9 commit 3fbc6fd
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 168 deletions.
2 changes: 0 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Expand Up @@ -477,7 +477,6 @@ def Tosa_AddOp : Tosa_Op<"add", [
Tosa_Tensor:$output
);

let hasCanonicalizer = 1;
let hasFolder = 1;
}

Expand Down Expand Up @@ -796,7 +795,6 @@ def Tosa_MulOp : Tosa_Op<"mul", [
Tosa_Tensor:$output
);

let hasCanonicalizer = 1;
let hasFolder = 1;
}

Expand Down
188 changes: 36 additions & 152 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Expand Up @@ -246,92 +246,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

// Rewrites `tosa.add` with a splat integer-zero constant operand to its other
// operand. Only fires when the non-constant operand already has the op's
// result type, so dropping the add cannot lose an implicit broadcast.
struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AddOp op,
                                PatternRewriter &rewriter) const override {
    // True when `maybeZero` is a splat integer-zero constant and `other`
    // can stand in for the op's result directly.
    auto isAdditiveIdentity = [&](Value maybeZero, Value other) {
      DenseElementsAttr attr;
      if (!matchPattern(maybeZero, m_Constant(&attr)) || !attr.isSplat())
        return false;
      if (other.getType() != op.getType())
        return false;
      return attr.getType().getElementType().isa<IntegerType>() &&
             attr.getSplatValue<APInt>().isZero();
    };

    if (isAdditiveIdentity(op.getInput1(), op.getInput2())) {
      rewriter.replaceOp(op, op.getInput2());
      return success();
    }
    if (isAdditiveIdentity(op.getInput2(), op.getInput1())) {
      rewriter.replaceOp(op, op.getInput1());
      return success();
    }

    return failure();
  }
};

// Registers the canonicalization patterns for tosa.add.
void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<AddZeroOptimization>(context);
}

// Rewrites `tosa.mul` by an identity splat constant to its other operand.
//
// A constant operand is a multiplicative identity for tosa.mul when:
//   * float element type: the splat value is exactly 1.0 (the op's `shift`
//     attribute only applies to integer types), or
//   * integer element type: the splat value equals 1 << shift, since the op
//     computes (a * b) >> shift.
// The non-constant operand must already have the op's result type, so
// dropping the multiply cannot lose an implicit broadcast.
struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MulOp op,
                                PatternRewriter &rewriter) const override {
    // True when `maybeOne` is a splat constant acting as the multiplicative
    // identity and `other` can stand in for the op's result directly.
    auto isMulIdentity = [&](Value maybeOne, Value other) {
      DenseElementsAttr attr;
      if (!matchPattern(maybeOne, m_Constant(&attr)) || !attr.isSplat())
        return false;
      if (other.getType() != op.getType())
        return false;
      Type elemTy = attr.getType().getElementType();
      if (elemTy.isa<FloatType>())
        return attr.getSplatValue<APFloat>().isExactlyValue(1.0);
      if (elemTy.isa<IntegerType>()) {
        // Fix: the integer identity is 1 << shift, not 1. A plain splat-one
        // check (m_One) would mis-fold muls with a non-zero shift, which
        // compute (a * b) >> shift.
        const int64_t shifted = 1LL << op.getShift();
        return attr.getSplatValue<APInt>().getSExtValue() == shifted;
      }
      return false;
    };

    if (isMulIdentity(op.getInput1(), op.getInput2())) {
      rewriter.replaceOp(op, op.getInput2());
      return success();
    }
    if (isMulIdentity(op.getInput2(), op.getInput1())) {
      rewriter.replaceOp(op, op.getInput1());
      return success();
    }

    return failure();
  }
};

// Registers the canonicalization patterns for tosa.mul.
void MulOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<MulOneOptimization>(context);
}

struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -609,44 +523,47 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
return {};
}

/// Returns true when `val` is a splat constant whose value is zero for the
/// given element type. Non-splat, null, and non-numeric element types are
/// conservatively reported as not-zero.
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (!val || !val.isSplat())
    return false;
  if (elemType.isa<FloatType>())
    return val.getSplatValue<APFloat>().isZero();
  if (elemType.isa<IntegerType>())
    return val.getSplatValue<APInt>().isZero();
  return false;
}

/// Returns true when `val` is a splat constant equal to the multiplicative
/// identity for the given element type: exactly 1.0 for floats, and
/// 1 << shift for integers (tosa.mul computes (a * b) >> shift). Non-splat,
/// null, and non-numeric element types are conservatively reported as
/// not-one.
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (!val || !val.isSplat())
    return false;
  if (elemType.isa<FloatType>())
    return val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (elemType.isa<IntegerType>())
    return val.getSplatValue<APInt>().getSExtValue() == (1LL << shift);
  return false;
}

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
auto resultTy = getType().dyn_cast<RankedTensorType>();
if (!lhsTy || !rhsTy || !resultTy)
return {};
if (lhsTy != rhsTy)
return {};

auto resultETy = resultTy.getElementType();
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();

if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
if (lhsAttr.getSplatValue<APFloat>().isZero())
return getInput2();
}

if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
if (rhsAttr.getSplatValue<APFloat>().isZero())
return getInput1();
}

if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
if (lhsAttr.getSplatValue<APInt>().isZero())
return getInput2();
}

if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
if (rhsAttr.getSplatValue<APInt>().isZero())
return getInput1();
}
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
return getInput2();

if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
lhsTy);
resultTy);
}

OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
Expand Down Expand Up @@ -724,50 +641,26 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
if (!lhsTy || !rhsTy || !resultTy)
return {};
if (lhsTy != rhsTy)
return {};

auto resultETy = resultTy.getElementType();
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();

if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
auto val = lhsAttr.getSplatValue<APFloat>();
if (val.isZero())
const int64_t shift = resultETy.isa<IntegerType>() ? getShift() : 0;
if (rhsTy == resultTy) {
if (isSplatZero(resultETy, lhsAttr))
return lhsAttr;
if (val.isExactlyValue(1.0))
if (isSplatOne(resultETy, lhsAttr, shift))
return rhs;
}

if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
auto val = rhsAttr.getSplatValue<APFloat>();
if (val.isZero())
return rhsAttr;
if (val.isExactlyValue(1.0))
return lhs;
}

if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
auto val = lhsAttr.getSplatValue<APInt>();
if (val.isZero())
return lhsAttr;
const int64_t shift = getShift();
const int64_t shifted = 1LL << shift;
if (val.getSExtValue() == shifted)
return rhs;
}

if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
auto val = rhsAttr.getSplatValue<APInt>();
const int64_t shift = getShift();
const int64_t shifted = 1LL << shift;
if (val.isZero())
if (lhsTy == resultTy) {
if (isSplatZero(resultETy, rhsAttr))
return rhsAttr;
if (val.getSExtValue() == shifted)
if (isSplatOne(resultETy, rhsAttr, shift))
return lhs;
}

return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift());
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
Expand All @@ -776,28 +669,19 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
if (!lhsTy || !rhsTy || !resultTy)
return {};
if (lhsTy != rhsTy)
return {};

auto resultETy = resultTy.getElementType();
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();

if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
if (rhsAttr.getSplatValue<APFloat>().isZero())
return getInput1();
}

if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
if (rhsAttr.getSplatValue<APInt>().isZero())
return getInput1();
}
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();

if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
lhsTy);
resultTy);
}

namespace {
Expand Down
29 changes: 15 additions & 14 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Expand Up @@ -7,15 +7,15 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}

// CHECK-LABEL: @add_zero_different_shape
func.func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32> {
// CHECK: tosa.add
%zeros = "tosa.const"() {value = dense<0> : tensor<4x2x3xi32>} : () -> tensor<4x2x3xi32>
%1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<4x2x3xi32>) -> tensor<4x2x3xi32>
// CHECK-LABEL: @add_bcast_zero_int
func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
// CHECK-NOT: tosa.add
// CHECK: return %arg0
%zeros = "tosa.const"() {value = dense<0> : tensor<1x1x1xi32>} : () -> tensor<1x1x1xi32>
%1 = "tosa.add"(%arg0, %zeros) : (tensor<4x2x3xi32>, tensor<1x1x1xi32>) -> tensor<4x2x3xi32>
return %1 : tensor<4x2x3xi32>
}


// CHECK-LABEL: @add_zero_int
func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK: return %arg0
Expand Down Expand Up @@ -176,14 +176,6 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
return %1 : tensor<?x?xi32>
}

// Negative test: the splat-one constant's shape differs from the operand's,
// so the multiply performs a broadcast and must NOT be folded away.
// CHECK-LABEL: @mul_one_different_shape
func.func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
// CHECK: tosa.mul
%ones = "tosa.const"() {value = dense<1.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
%1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
return %1 : tensor<4x2x3xf32>
}

// CHECK-LABEL: @mul_one_float
func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
Expand All @@ -193,6 +185,15 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
return %1 : tensor<2x3xf32>
}

// A broadcast splat-one constant on the lhs is still a multiplicative
// identity: the non-constant operand has the result type, so the mul folds
// to it.
// CHECK-LABEL: @mul_bcast_one_float
func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
%ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%1 = "tosa.mul"(%ones, %arg0) {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}

// CHECK-LABEL: @mul_one_int
func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK: return %arg0
Expand Down

0 comments on commit 3fbc6fd

Please sign in to comment.