Skip to content

Commit

Permalink
[mlir][complex] Add complex.conj op
Browse files Browse the repository at this point in the history
Add complex.conj op to calculate the complex conjugate which is widely used for the mathematical operation on the complex space.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D127181
  • Loading branch information
Lewuathe authored and pifon2a committed Jun 7, 2022
1 parent 15d82c6 commit 62a34f6
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 4 deletions.
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
Expand Up @@ -564,4 +564,24 @@ def TanOp : ComplexUnaryOp<"tan", [SameOperandsAndResultType]> {
let results = (outs Complex<AnyFloat>:$result);
}

//===----------------------------------------------------------------------===//
// Conj
//===----------------------------------------------------------------------===//

def ConjOp : ComplexUnaryOp<"conj", [SameOperandsAndResultType]> {
let summary = "Calculate the complex conjugate";
let description = [{
The `conj` op takes a single complex number and computes the
complex conjugate.

Example:

```mlir
%a = complex.conj %b: complex<f32>
```
}];

let results = (outs Complex<AnyFloat>:$result);
}

#endif // COMPLEX_OPS
31 changes: 27 additions & 4 deletions mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
Expand Up @@ -885,6 +885,27 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
}
};

struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
using OpConversionPattern<complex::ConjOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto type = adaptor.getComplex().getType().cast<ComplexType>();
auto elementType = type.getElementType().cast<FloatType>();
Value real =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
Value imag =
rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);

rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);

return success();
}
};

} // namespace

void mlir::populateComplexToStandardConversionPatterns(
Expand All @@ -893,23 +914,25 @@ void mlir::populateComplexToStandardConversionPatterns(
patterns.add<
AbsOpConversion,
Atan2OpConversion,
ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
ConjOpConversion,
CosOpConversion,
DivOpConversion,
ExpOpConversion,
Expm1OpConversion,
LogOpConversion,
Log1pOpConversion,
LogOpConversion,
MulOpConversion,
NegOpConversion,
SignOpConversion,
SinOpConversion,
SqrtOpConversion,
TanOpConversion,
TanhOpConversion>(patterns.getContext());
TanhOpConversion
>(patterns.getContext());
// clang-format on
}

Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Expand Up @@ -663,3 +663,17 @@ func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
%sqrt = complex.sqrt %arg : complex<f32>
return %sqrt : complex<f32>
}

// -----

// CHECK-LABEL: func @complex_conj
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_conj(%arg: complex<f32>) -> complex<f32> {
%conj = complex.conj %arg: complex<f32>
return %conj : complex<f32>
}
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[NEG_IMAG:.*]] = arith.negf %[[IMAG]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[NEG_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>

0 comments on commit 62a34f6

Please sign in to comment.