diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index fef8fd210a495..36d0f093c6917 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -516,6 +516,11 @@ OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) { return add.getRhs(); } + // subi(a, subi(a, b)) -> b + if (auto sub = getRhs().getDefiningOp()) + if (getLhs() == sub.getLhs()) + return sub.getRhs(); + return constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) - b; }); @@ -2948,8 +2953,8 @@ std::optional mlir::arith::getNeutralElement(Operation *op) { Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue) { - if (auto attr = -getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue)) + if (auto attr = getIdentityValueAttr(op, resultType, builder, loc, + useOnlyFiniteValue)) return arith::ConstantOp::create(builder, loc, attr); return {}; } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index a3a0dc7adf7cc..9fde39a110473 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1369,6 +1369,16 @@ func.func @subSub0(%arg0: index, %arg1: index) -> index { return %sub2 : index } +// CHECK-LABEL: @subSub1 +// CHECK-SAME: %[[ARG0:.*]]: index, +// CHECK-SAME: %[[ARG1:.*]]: index) +// CHECK: return %[[ARG1]] : index +func.func @subSub1(%arg0: index, %arg1: index) -> index { + %sub1 = arith.subi %arg0, %arg1 : index + %sub2 = arith.subi %arg0, %sub1 : index + return %sub2 : index +} + // CHECK-LABEL: @subSub0Ovf // CHECK: %[[c0:.+]] = arith.constant 0 : index // CHECK: %[[add:.+]] = arith.subi %[[c0]], %arg1 overflow : index