diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index 21797d32a22d8..f98037c9a515e 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -584,4 +584,25 @@ def ConjOp : ComplexUnaryOp<"conj", [SameOperandsAndResultType]> { let results = (outs Complex:$result); } +//===----------------------------------------------------------------------===// +// AngleOp +//===----------------------------------------------------------------------===// + +def AngleOp : ComplexUnaryOp<"angle", + [TypesMatchWith<"complex element type matches result type", + "complex", "result", + "$_self.cast().getElementType()">]> { + let summary = "computes argument value of a complex number"; + let description = [{ + The `angle` op takes a single complex number and computes its argument value with a branch cut along the negative real axis. + + Example: + + ```mlir + %a = complex.angle %b : complex + ``` + }]; + let results = (outs AnyFloat:$result); +} + #endif // COMPLEX_OPS diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 0a5124ada7a49..b104826b757fa 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -1009,6 +1009,26 @@ struct RsqrtOpConversion : public OpConversionPattern { } }; +struct AngleOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto type = op.getType(); + + Value real = + rewriter.create(loc, type, adaptor.getComplex()); + Value imag = + rewriter.create(loc, type, adaptor.getComplex()); + + rewriter.replaceOpWithNewOp(op, imag, real); + + return success(); + } +}; + } // namespace void mlir::populateComplexToStandardConversionPatterns( @@ -1016,6 +1036,7 @@ void mlir::populateComplexToStandardConversionPatterns( // clang-format off patterns.add< AbsOpConversion, + AngleOpConversion, Atan2OpConversion, BinaryComplexOpConversion, BinaryComplexOpConversion, diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 5b37899075a4f..9aff4ecc80e4b 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -694,3 +694,16 @@ func.func @complex_rsqrt(%arg: complex) -> complex { %rsqrt = complex.rsqrt %arg : complex return %rsqrt : complex } + +// ----- + +// CHECK-LABEL: func.func @complex_angle +// CHECK-SAME: %[[ARG:.*]]: complex +func.func @complex_angle(%arg: complex) -> f32 { + %angle = complex.angle %arg : complex + return %angle : f32 +} +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] : f32 +// CHECK: return %[[RESULT]] : f32 diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir index e2df73ed3c9b0..a7e166906f4c6 100644 --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -82,6 +82,27 @@ func.func @pow(%lhs: complex, %rhs: complex) -> complex { func.return %pow : complex } +func.func @test_element(%input: tensor>, + %func: (complex) -> f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %size = tensor.dim %input, %c0: tensor> + + scf.for %i = %c0 to %size step %c1 { + %elem = tensor.extract %input[%i]: tensor> + + %val = func.call_indirect %func(%elem) : (complex) -> f32 + vector.print %val : f32 + scf.yield + } + func.return +} + +func.func @angle(%arg: complex) -> f32 { + %angle = complex.angle %arg : complex + func.return %angle : f32 +} + func.func @entry() { // complex.sqrt test %sqrt_test = arith.constant dense<[ @@ -251,6 +272,30 @@ func.func @entry() { %conj_func = func.constant @conj : (complex) -> complex call @test_unary(%conj_test_cast, %conj_func) : (tensor>, (complex) -> complex) -> () - + + // complex.angle test + %angle_test = arith.constant dense<[ + (-1.0, -1.0), + // CHECK: -2.356 + (-1.0, 1.0), + // CHECK-NEXT: 2.356 + (0.0, 0.0), + // CHECK-NEXT: 0 + (0.0, 1.0), + // CHECK-NEXT: 1.570 + (1.0, -1.0), + // CHECK-NEXT: -0.785 + (1.0, 0.0), + // CHECK-NEXT: 0 + (1.0, 1.0) + // CHECK-NEXT: 0.785 + ]> : tensor<7xcomplex> + %angle_test_cast = tensor.cast %angle_test + : tensor<7xcomplex> to tensor> + + %angle_func = func.constant @angle : (complex) -> f32 + call @test_element(%angle_test_cast, %angle_func) + : (tensor>, (complex) -> f32) -> () + func.return }