diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index 5ca24398843cb..b76276c726ee5 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -365,6 +365,8 @@ def NegOp : ComplexUnaryOp<"neg", [SameOperandsAndResultType]> { }]; let results = (outs Complex:$result); + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index 0390a00cf6844..00dbb564481ee 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -124,6 +124,20 @@ OpFoldResult AddOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// NegOp +//===----------------------------------------------------------------------===// + +OpFoldResult NegOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "unary op takes 1 operand"); + + // complex.neg(complex.neg(a)) -> a + if (auto negOp = getOperand().getDefiningOp()) + return negOp.getOperand(); + + return {}; +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir index 8bca3232774a1..093e92f6a0308 100644 --- a/mlir/test/Dialect/Complex/canonicalize.mlir +++ b/mlir/test/Dialect/Complex/canonicalize.mlir @@ -83,4 +83,14 @@ func.func @complex_add_sub_rhs() -> complex { %sub = complex.sub %complex1, %complex2 : complex %add = complex.add %complex2, %sub : complex return %add : complex +} + +// CHECK-LABEL: func @complex_neg_neg +func.func @complex_neg_neg() -> complex { + %complex1 = complex.constant [1.0 : f32, 0.0 : f32] : complex + // CHECK: %[[CPLX:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex + // CHECK-NEXT: return %[[CPLX:.*]] : complex + %neg1 = complex.neg %complex1 : complex + %neg2 = complex.neg %neg1 : complex + return %neg2 : complex } \ No newline at end of file