diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td index b6ec45ced3512..acdb3f200051b 100644 --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -764,6 +764,7 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> { math, contraction, rounding mode, and other controls. }]; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -773,6 +774,7 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> { def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> { let summary = "floating point division operation"; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td index d6b141e88b9a9..647ddfee74d0f 100644 --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td @@ -187,4 +187,22 @@ def OrOfExtSI : Pat<(Arith_OrIOp (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), (Arith_ExtSIOp (Arith_OrIOp $x, $y)), [(Constraint> $x, $y)]>; +//===----------------------------------------------------------------------===// +// MulFOp +//===----------------------------------------------------------------------===// + +// mulf(negf(x), negf(y)) -> mulf(x,y) +def MulFOfNegF : + Pat<(Arith_MulFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_MulFOp $x, $y), + [(Constraint> $x, $y)]>; + +//===----------------------------------------------------------------------===// +// DivFOp +//===----------------------------------------------------------------------===// + +// divf(negf(x), negf(y)) -> divf(x,y) +def DivFOfNegF : + Pat<(Arith_DivFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_DivFOp $x, $y), + [(Constraint> $x, $y)]>; + #endif // ARITHMETIC_PATTERNS diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp index f1caac9943e64..8b9175701ff87 100644 --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -743,6 +743,11 @@ OpFoldResult arith::MulFOp::fold(ArrayRef operands) { operands, [](const APFloat &a, const APFloat &b) { return a * b; }); } +void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + //===----------------------------------------------------------------------===// // DivFOp //===----------------------------------------------------------------------===// @@ -756,6 +761,11 @@ OpFoldResult arith::DivFOp::fold(ArrayRef operands) { operands, [](const APFloat &a, const APFloat &b) { return a / b; }); } +void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + //===----------------------------------------------------------------------===// // RemFOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir index a9472c4264257..2f563eb598d79 100644 --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -948,6 +948,16 @@ func.func @test_mulf(%arg0 : f32) -> (f32, f32, f32, f32) { return %0, %1, %2, %3 : f32, f32, f32, f32 } +// CHECK-LABEL: @test_mulf1( +func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) { + // CHECK-NEXT: %[[X:.+]] = arith.mulf %arg0, %arg1 : f32 + // CHECK-NEXT: return %[[X]] + %0 = arith.negf %arg0 : f32 + %1 = arith.negf %arg1 : f32 + %2 = arith.mulf %0, %1 : f32 + return %2 : f32 +} + // ----- // CHECK-LABEL: @test_divf( @@ -961,6 +971,16 @@ func.func @test_divf(%arg0 : f64) -> (f64, f64) { return %0, %1 : f64, f64 } +// CHECK-LABEL: @test_divf1( +func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) { + // CHECK-NEXT: %[[X:.+]] = arith.divf %arg0, %arg1 : f32 + // CHECK-NEXT: return %[[X]] + %0 = arith.negf %arg0 : f32 + %1 = arith.negf %arg1 : f32 + %2 = arith.divf %0, %1 : f32 + return %2 : f32 +} + // ----- // CHECK-LABEL: @test_cmpf(