Skip to content

Commit

Permalink
[mlir][complex] complex.arg op to calculate the angle of complex number
Browse files Browse the repository at this point in the history
Add complex.arg op which calculates the angle of complex number. The op name is inspired by the function carg in libm.

See: https://sourceware.org/newlib/libm.html#carg

Differential Revision: https://reviews.llvm.org/D128531
  • Loading branch information
Lewuathe authored and pifon2a committed Jun 27, 2022
1 parent 771c46a commit 8fa2e67
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 1 deletion.
21 changes: 21 additions & 0 deletions mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
Expand Up @@ -584,4 +584,25 @@ def ConjOp : ComplexUnaryOp<"conj", [SameOperandsAndResultType]> {
let results = (outs Complex<AnyFloat>:$result);
}

//===----------------------------------------------------------------------===//
// AngleOp
//===----------------------------------------------------------------------===//

def AngleOp : ComplexUnaryOp<"angle",
[TypesMatchWith<"complex element type matches result type",
"complex", "result",
"$_self.cast<ComplexType>().getElementType()">]> {
let summary = "computes argument value of a complex number";
let description = [{
The `angle` op takes a single complex number and computes its argument value with a branch cut along the negative real axis.

Example:

```mlir
%a = complex.angle %b : complex<f32>
```
}];
let results = (outs AnyFloat:$result);
}

#endif // COMPLEX_OPS
21 changes: 21 additions & 0 deletions mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
Expand Up @@ -1009,13 +1009,34 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
}
};

struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
using OpConversionPattern<complex::AngleOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto type = op.getType();

Value real =
rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
Value imag =
rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());

rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);

return success();
}
};

} // namespace

void mlir::populateComplexToStandardConversionPatterns(
RewritePatternSet &patterns) {
// clang-format off
patterns.add<
AbsOpConversion,
AngleOpConversion,
Atan2OpConversion,
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Expand Up @@ -694,3 +694,16 @@ func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
%rsqrt = complex.rsqrt %arg : complex<f32>
return %rsqrt : complex<f32>
}

// -----

// CHECK-LABEL: func.func @complex_angle
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_angle(%arg: complex<f32>) -> f32 {
%angle = complex.angle %arg : complex<f32>
return %angle : f32
}
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] : f32
// CHECK: return %[[RESULT]] : f32
47 changes: 46 additions & 1 deletion mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
Expand Up @@ -82,6 +82,27 @@ func.func @pow(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
func.return %pow : complex<f32>
}

func.func @test_element(%input: tensor<?xcomplex<f32>>,
%func: (complex<f32>) -> f32) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%size = tensor.dim %input, %c0: tensor<?xcomplex<f32>>

scf.for %i = %c0 to %size step %c1 {
%elem = tensor.extract %input[%i]: tensor<?xcomplex<f32>>

%val = func.call_indirect %func(%elem) : (complex<f32>) -> f32
vector.print %val : f32
scf.yield
}
func.return
}

func.func @angle(%arg: complex<f32>) -> f32 {
%angle = complex.angle %arg : complex<f32>
func.return %angle : f32
}

func.func @entry() {
// complex.sqrt test
%sqrt_test = arith.constant dense<[
Expand Down Expand Up @@ -251,6 +272,30 @@ func.func @entry() {
%conj_func = func.constant @conj : (complex<f32>) -> complex<f32>
call @test_unary(%conj_test_cast, %conj_func)
: (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()


// complex.angle test
%angle_test = arith.constant dense<[
(-1.0, -1.0),
// CHECK: -2.356
(-1.0, 1.0),
// CHECK-NEXT: 2.356
(0.0, 0.0),
// CHECK-NEXT: 0
(0.0, 1.0),
// CHECK-NEXT: 1.570
(1.0, -1.0),
// CHECK-NEXT: -0.785
(1.0, 0.0),
// CHECK-NEXT: 0
(1.0, 1.0)
// CHECK-NEXT: 0.785
]> : tensor<7xcomplex<f32>>
%angle_test_cast = tensor.cast %angle_test
: tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

%angle_func = func.constant @angle : (complex<f32>) -> f32
call @test_element(%angle_test_cast, %angle_func)
: (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()

func.return
}

0 comments on commit 8fa2e67

Please sign in to comment.