From 8fa2e67979e56db3cc511ff1af920b4fa02fb473 Mon Sep 17 00:00:00 2001 From: Lewuathe Date: Mon, 27 Jun 2022 14:29:15 +0200 Subject: [PATCH] [mlir][complex] complex.arg op to calculate the angle of complex number Add complex.arg op which calculates the angle of complex number. The op name is inspired by the function carg in libm. See: https://sourceware.org/newlib/libm.html#carg Differential Revision: https://reviews.llvm.org/D128531 --- .../mlir/Dialect/Complex/IR/ComplexOps.td | 21 +++++++++ .../ComplexToStandard/ComplexToStandard.cpp | 21 +++++++++ .../convert-to-standard.mlir | 13 +++++ .../Dialect/Complex/CPU/correctness.mlir | 47 ++++++++++++++++++- 4 files changed, 101 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index 21797d32a22d8..f98037c9a515e 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -584,4 +584,25 @@ def ConjOp : ComplexUnaryOp<"conj", [SameOperandsAndResultType]> { let results = (outs Complex:$result); } +//===----------------------------------------------------------------------===// +// AngleOp +//===----------------------------------------------------------------------===// + +def AngleOp : ComplexUnaryOp<"angle", + [TypesMatchWith<"complex element type matches result type", + "complex", "result", + "$_self.cast().getElementType()">]> { + let summary = "computes argument value of a complex number"; + let description = [{ + The `angle` op takes a single complex number and computes its argument value with a branch cut along the negative real axis. + + Example: + + ```mlir + %a = complex.angle %b : complex + ``` + }]; + let results = (outs AnyFloat:$result); +} + #endif // COMPLEX_OPS diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 0a5124ada7a49..b104826b757fa 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -1009,6 +1009,26 @@ struct RsqrtOpConversion : public OpConversionPattern { } }; +struct AngleOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto type = op.getType(); + + Value real = + rewriter.create(loc, type, adaptor.getComplex()); + Value imag = + rewriter.create(loc, type, adaptor.getComplex()); + + rewriter.replaceOpWithNewOp(op, imag, real); + + return success(); + } +}; + } // namespace void mlir::populateComplexToStandardConversionPatterns( @@ -1016,6 +1036,7 @@ void mlir::populateComplexToStandardConversionPatterns( // clang-format off patterns.add< AbsOpConversion, + AngleOpConversion, Atan2OpConversion, BinaryComplexOpConversion, BinaryComplexOpConversion, diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 5b37899075a4f..9aff4ecc80e4b 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -694,3 +694,16 @@ func.func @complex_rsqrt(%arg: complex) -> complex { %rsqrt = complex.rsqrt %arg : complex return %rsqrt : complex } + +// ----- + +// CHECK-LABEL: func.func @complex_angle +// CHECK-SAME: %[[ARG:.*]]: complex +func.func @complex_angle(%arg: complex) -> f32 { + %angle = complex.angle %arg : complex + return %angle : f32 +} +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] : f32 +// CHECK: return %[[RESULT]] : f32 diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir index e2df73ed3c9b0..a7e166906f4c6 100644 --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -82,6 +82,27 @@ func.func @pow(%lhs: complex, %rhs: complex) -> complex { func.return %pow : complex } +func.func @test_element(%input: tensor>, + %func: (complex) -> f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %size = tensor.dim %input, %c0: tensor> + + scf.for %i = %c0 to %size step %c1 { + %elem = tensor.extract %input[%i]: tensor> + + %val = func.call_indirect %func(%elem) : (complex) -> f32 + vector.print %val : f32 + scf.yield + } + func.return +} + +func.func @angle(%arg: complex) -> f32 { + %angle = complex.angle %arg : complex + func.return %angle : f32 +} + func.func @entry() { // complex.sqrt test %sqrt_test = arith.constant dense<[ @@ -251,6 +272,30 @@ func.func @entry() { %conj_func = func.constant @conj : (complex) -> complex call @test_unary(%conj_test_cast, %conj_func) : (tensor>, (complex) -> complex) -> () - + + // complex.angle test + %angle_test = arith.constant dense<[ + (-1.0, -1.0), + // CHECK: -2.356 + (-1.0, 1.0), + // CHECK-NEXT: 2.356 + (0.0, 0.0), + // CHECK-NEXT: 0 + (0.0, 1.0), + // CHECK-NEXT: 1.570 + (1.0, -1.0), + // CHECK-NEXT: -0.785 + (1.0, 0.0), + // CHECK-NEXT: 0 + (1.0, 1.0) + // CHECK-NEXT: 0.785 + ]> : tensor<7xcomplex> + %angle_test_cast = tensor.cast %angle_test + : tensor<7xcomplex> to tensor> + + %angle_func = func.constant @angle : (complex) -> f32 + call @test_element(%angle_test_cast, %angle_func) + : (tensor>, (complex) -> f32) -> () + func.return }