diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index a116242dd07818..d43b1e5dc1b2c7 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -216,6 +216,28 @@ def LogOp : ComplexUnaryOp<"log", [SameOperandsAndResultType]> { let results = (outs Complex:$result); } +//===----------------------------------------------------------------------===// +// Log1pOp +//===----------------------------------------------------------------------===// + +def Log1pOp : ComplexUnaryOp<"log1p", [SameOperandsAndResultType]> { + let summary = "computes natural logarithm of a complex number"; + let description = [{ + The `log` op takes a single complex number and computes the natural + logarithm of one plus the given value, i.e. `log(1 + x)` or `log_e(1 + x)`, + where `x` is the input value. `e` denotes Euler's number and is + approximately equal to 2.718281. + + Example: + + ```mlir + %a = complex.log1p %b : complex + ``` + }]; + + let results = (outs Complex:$result); +} + //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 018882ae948935..4d3d52213e55c4 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -337,6 +337,28 @@ struct LogOpConversion : public OpConversionPattern { } }; +struct Log1pOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::Log1pOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + complex::Log1pOp::Adaptor transformed(operands); + auto type = transformed.complex().getType().cast(); + auto elementType = type.getElementType().cast(); + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + Value real = b.create(elementType, transformed.complex()); + Value imag = b.create(elementType, transformed.complex()); + Value one = + b.create(elementType, b.getFloatAttr(elementType, 1)); + Value realPlusOne = b.create(real, one); + Value newComplex = b.create(type, realPlusOne, imag); + rewriter.replaceOpWithNewOp(op, type, newComplex); + return success(); + } +}; + struct MulOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -535,6 +557,7 @@ void mlir::populateComplexToStandardConversionPatterns( DivOpConversion, ExpOpConversion, LogOpConversion, + Log1pOpConversion, MulOpConversion, NegOpConversion, SignOpConversion>(patterns.getContext()); @@ -558,8 +581,9 @@ void ConvertComplexToStandardPass::runOnFunction() { target.addLegalDialect(); target.addIllegalOp(); + complex::ExpOp, complex::LogOp, complex::Log1pOp, + complex::MulOp, complex::NegOp, complex::NotEqualOp, + complex::SignOp>(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 95e6854ffa4313..765d79c0bb8ca7 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -173,6 +173,30 @@ func @complex_log(%arg: complex) -> complex { // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-LABEL: func @complex_log1p +// CHECK-SAME: %[[ARG:.*]]: complex +func @complex_log1p(%arg: complex) -> complex { + %log1p = complex.log1p %arg: complex + return %log1p : complex +} +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[ONE:.*]] = constant 1.000000e+00 : f32 +// CHECK: %[[REAL_PLUS_ONE:.*]] = addf %[[REAL]], %[[ONE]] : f32 +// CHECK: %[[NEW_COMPLEX:.*]] = complex.create %[[REAL_PLUS_ONE]], %[[IMAG]] : complex +// CHECK: %[[REAL:.*]] = complex.re %[[NEW_COMPLEX]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[NEW_COMPLEX]] : complex +// CHECK: %[[SQR_REAL:.*]] = mulf %[[REAL]], %[[REAL]] : f32 +// CHECK: %[[SQR_IMAG:.*]] = mulf %[[IMAG]], %[[IMAG]] : f32 +// CHECK: %[[SQ_NORM:.*]] = addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 +// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32 +// CHECK: %[[REAL2:.*]] = complex.re %[[NEW_COMPLEX]] : complex +// CHECK: %[[IMAG2:.*]] = complex.im %[[NEW_COMPLEX]] : complex +// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex + // CHECK-LABEL: func @complex_mul // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func @complex_mul(%lhs: complex, %rhs: complex) -> complex { diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir index 74b45b8ae230a5..3fc0e9299c0fd7 100644 --- a/mlir/test/Dialect/Complex/ops.mlir +++ b/mlir/test/Dialect/Complex/ops.mlir @@ -32,6 +32,9 @@ func @ops(%f: f32) { // CHECK: complex.log %[[C]] : complex %log = complex.log %complex : complex + // CHECK: complex.log1p %[[C]] : complex + %log1p = complex.log1p %complex : complex + // CHECK: complex.mul %[[C]], %[[C]] : complex %prod = complex.mul %complex, %complex : complex