diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index f98037c9a515e..5ca24398843cb 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -73,6 +73,8 @@ def AddOp : ComplexArithmeticOp<"add"> { %a = complex.add %b, %c : complex ``` }]; + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index 7545d3feec17b..0390a00cf6844 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" using namespace mlir; using namespace mlir::complex; @@ -103,6 +104,26 @@ OpFoldResult ReOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +OpFoldResult AddOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "binary op takes 2 operands"); + + // complex.add(complex.sub(a, b), b) -> a + if (auto sub = getLhs().getDefiningOp()) + if (getRhs() == sub.getRhs()) + return sub.getLhs(); + + // complex.add(b, complex.sub(a, b)) -> a + if (auto sub = getRhs().getDefiningOp()) + if (getLhs() == sub.getRhs()) + return sub.getLhs(); + + return {}; +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir index 2d492a223d4c7..8bca3232774a1 100644 --- a/mlir/test/Dialect/Complex/canonicalize.mlir +++ b/mlir/test/Dialect/Complex/canonicalize.mlir @@ -62,3 +62,25 @@ func.func @imag_of_create_op() -> f32 { %1 = complex.im %complex : complex return %1 : f32 } + +// CHECK-LABEL: func @complex_add_sub_lhs +func.func @complex_add_sub_lhs() -> complex { + %complex1 = complex.constant [1.0 : f32, 0.0 : f32] : complex + %complex2 = complex.constant [0.0 : f32, 2.0 : f32] : complex + // CHECK: %[[CPLX:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex + // CHECK-NEXT: return %[[CPLX:.*]] : complex + %sub = complex.sub %complex1, %complex2 : complex + %add = complex.add %sub, %complex2 : complex + return %add : complex +} + +// CHECK-LABEL: func @complex_add_sub_rhs +func.func @complex_add_sub_rhs() -> complex { + %complex1 = complex.constant [1.0 : f32, 0.0 : f32] : complex + %complex2 = complex.constant [0.0 : f32, 2.0 : f32] : complex + // CHECK: %[[CPLX:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex + // CHECK-NEXT: return %[[CPLX:.*]] : complex + %sub = complex.sub %complex1, %complex2 : complex + %add = complex.add %complex2, %sub : complex + return %add : complex +} \ No newline at end of file