diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 0aa1de5fa5d9a1..9c3c4d96a301ef 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -570,39 +570,37 @@ struct Log1pOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); - arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); + arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create(adaptor.getComplex()); - Value imag = b.create(adaptor.getComplex()); + Value real = b.create(elementType, adaptor.getComplex()); + Value imag = b.create(elementType, adaptor.getComplex()); Value half = b.create(elementType, b.getFloatAttr(elementType, 0.5)); Value one = b.create(elementType, b.getFloatAttr(elementType, 1)); - Value realPlusOne = b.create(real, one, fmf); - Value absRealPlusOne = b.create(realPlusOne, fmf); - Value absImag = b.create(imag, fmf); - - Value maxAbs = b.create(absRealPlusOne, absImag, fmf); - Value minAbs = b.create(absRealPlusOne, absImag, fmf); - - Value maxAbsOfRealPlusOneAndImagMinusOne = b.create( - b.create(arith::CmpFPredicate::OGT, realPlusOne, absImag, - fmf), - real, b.create(maxAbs, one, fmf)); - Value minMaxRatio = b.create(minAbs, maxAbs, fmf); - Value logOfMaxAbsOfRealPlusOneAndImag = - b.create(maxAbsOfRealPlusOneAndImagMinusOne, fmf); - Value logOfSqrtPart = b.create( - b.create(minMaxRatio, minMaxRatio, fmf), fmf); - Value r = b.create( - b.create(half, logOfSqrtPart, fmf), - logOfMaxAbsOfRealPlusOneAndImag, fmf); - Value resultReal = b.create( - b.create(arith::CmpFPredicate::UNO, r, r, fmf), minAbs, - r); - Value resultImag = b.create(imag, realPlusOne, fmf); + Value two = b.create(elementType, + b.getFloatAttr(elementType, 2)); + + // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) + // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1) + // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1) + Value sumSq = b.create(real, real, fmf.getValue()); + sumSq = b.create( + sumSq, b.create(real, two, fmf.getValue()), + fmf.getValue()); + sumSq = b.create( + sumSq, b.create(imag, imag, fmf.getValue()), + fmf.getValue()); + Value logSumSq = + b.create(elementType, sumSq, fmf.getValue()); + Value resultReal = b.create(logSumSq, half, fmf.getValue()); + + Value realPlusOne = b.create(real, one, fmf.getValue()); + + Value resultImag = + b.create(elementType, imag, realPlusOne, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 43918904a09f40..f5d9499eadda48 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -300,22 +300,15 @@ func.func @complex_log1p(%arg: complex) -> complex { // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex // CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 +// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32 +// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32 +// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 +// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32 +// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32 +// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32 // CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32 -// CHECK: %[[ABS_REAL_PLUS_ONE:.*]] = math.absf %[[REAL_PLUS_ONE]] : f32 -// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32 -// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32 -// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32 -// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32 -// CHECK: %[[MAX_MINUS_ONE:.*]] = arith.subf %[[MAX]], %cst_0 : f32 -// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %0, %[[MAX_MINUS_ONE]] : f32 -// CHECK: %[[MIN_MAX_RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32 -// CHECK: %[[LOG_1:.*]] = math.log1p %[[SELECT]] : f32 -// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[MIN_MAX_RATIO]], %[[MIN_MAX_RATIO]] : f32 -// CHECK: %[[LOG_SQ:.*]] = math.log1p %[[RATIO_SQ]] : f32 -// CHECK: %[[HALF_LOG_SQ:.*]] = arith.mulf %cst, %[[LOG_SQ]] : f32 -// CHECK: %[[R:.*]] = arith.addf %[[HALF_LOG_SQ]], %[[LOG_1]] : f32 -// CHECK: %[[ISNAN:.*]] = arith.cmpf uno, %[[R]], %[[R]] : f32 -// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[ISNAN]], %[[MIN]], %[[R]] : f32 // CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex @@ -970,22 +963,15 @@ func.func @complex_log1p_with_fmf(%arg: complex) -> complex { // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex // CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath : f32 -// CHECK: %[[ABS_REAL_PLUS_ONE:.*]] = math.absf %[[REAL_PLUS_ONE]] fastmath : f32 -// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath : f32 -// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath : f32 -// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath : f32 -// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath : f32 -// CHECK: %[[MAX_MINUS_ONE:.*]] = arith.subf %[[MAX]], %cst_0 fastmath : f32 -// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %0, %[[MAX_MINUS_ONE]] : f32 -// CHECK: %[[MIN_MAX_RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath : f32 -// CHECK: %[[LOG_1:.*]] = math.log1p %[[SELECT]] fastmath : f32 -// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[MIN_MAX_RATIO]], %[[MIN_MAX_RATIO]] fastmath : f32 -// CHECK: %[[LOG_SQ:.*]] = math.log1p %[[RATIO_SQ]] fastmath : f32 -// CHECK: %[[HALF_LOG_SQ:.*]] = arith.mulf %cst, %[[LOG_SQ]] fastmath : f32 -// CHECK: %[[R:.*]] = arith.addf %[[HALF_LOG_SQ]], %[[LOG_1]] fastmath : f32 -// CHECK: %[[ISNAN:.*]] = arith.cmpf uno, %[[R]], %[[R]] fastmath : f32 -// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[ISNAN]], %[[MIN]], %[[R]] : f32 +// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath : f32 +// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] fastmath : f32 +// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] fastmath : f32 +// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath : f32 +// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] fastmath : f32 +// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] fastmath : f32 +// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] fastmath : f32 +// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath : f32 // CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex