Skip to content

Commit

Permalink
[mlir][complex] Canonicalization for consecutive complex.neg
Browse files Browse the repository at this point in the history
Consecutive complex.neg are redundant so that we can canonicalize them to the original operands.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D128781
  • Loading branch information
Lewuathe authored and pifon2a committed Jun 29, 2022
1 parent d8ad018 commit 0180709
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
Expand Up @@ -365,6 +365,8 @@ def NegOp : ComplexUnaryOp<"neg", [SameOperandsAndResultType]> {
}];

let results = (outs Complex<AnyFloat>:$result);

let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
Expand Up @@ -124,6 +124,20 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
return {};
}

//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//

OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "unary op takes 1 operand");

// complex.neg(complex.neg(a)) -> a
if (auto negOp = getOperand().getDefiningOp<NegOp>())
return negOp.getOperand();

return {};
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Complex/canonicalize.mlir
Expand Up @@ -83,4 +83,14 @@ func.func @complex_add_sub_rhs() -> complex<f32> {
%sub = complex.sub %complex1, %complex2 : complex<f32>
%add = complex.add %complex2, %sub : complex<f32>
return %add : complex<f32>
}

// CHECK-LABEL: func @complex_neg_neg
func.func @complex_neg_neg() -> complex<f32> {
%complex1 = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
// CHECK: %[[CPLX:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
// CHECK-NEXT: return %[[CPLX:.*]] : complex<f32>
%neg1 = complex.neg %complex1 : complex<f32>
%neg2 = complex.neg %neg1 : complex<f32>
return %neg2 : complex<f32>
}

0 comments on commit 0180709

Please sign in to comment.