Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][complex] Prevent underflow in complex.abs #79786

Merged
merged 3 commits into from
Jan 31, 2024

Conversation

Lewuathe
Copy link
Member

The previous PR was not correct on the way to handle the negative value. It is necessary to take the absolute value of the given real (or imaginary) part to be multiplied with the sqrt part.

See: #76316

@llvmbot llvmbot added mlir mlir:complex MLIR complex dialect labels Jan 29, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 29, 2024

@llvm/pr-subscribers-mlir-complex

@llvm/pr-subscribers-mlir

Author: Kai Sasaki (Lewuathe)

Changes

The previous PR was not correct on the way to handle the negative value. It is necessary to take the absolute value of the given real (or imaginary) part to be multiplied with the sqrt part.

See: #76316


Full diff: https://github.com/llvm/llvm-project/pull/79786.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+44-14)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+103-22)
  • (modified) mlir/test/Conversion/ComplexToStandard/full-conversion.mlir (+23-4)
  • (modified) mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir (+48)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 4c9dad9e2c17312..6e0eddab0e3aae0 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -26,29 +26,59 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
+// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto type = op.getType();
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
-    Value real =
-        rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
-    Value imag =
-        rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
-    Value realSqr =
-        rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
-    Value imagSqr =
-        rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
-    Value sqNorm =
-        rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
-
-    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
+    Type elementType = op.getType();
+    Value arg = adaptor.getComplex();
+
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+    Value one = b.create<arith::ConstantOp>(elementType,
+                                            b.getFloatAttr(elementType, 1.0));
+
+    Value real = b.create<complex::ReOp>(elementType, arg);
+    Value imag = b.create<complex::ImOp>(elementType, arg);
+
+    Value realIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+    Value imagIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+
+    // Real > Imag
+    Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
+    Value imagSq =
+        b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
+    Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
+    Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
+    Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
+    Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue());
+
+    // Real <= Imag
+    Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
+    Value realSq =
+        b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
+    Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
+    Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
+    Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue());
+    Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue());
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(
+        op, realIsZero, imag,
+        b.create<arith::SelectOp>(
+            imagIsZero, real,
+            b.create<arith::SelectOp>(
+                b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
+                absImag, absReal)));
+
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8fa29ea43854a4b..de519f6a706ab03 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -7,13 +7,30 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg: complex<f32>
   return %abs : f32
 }
+
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: return %[[NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
+// CHECK: return %[[ABS3]] : f32
 
 // -----
 
@@ -241,12 +258,28 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg: complex<f32>
   return %log : complex<f32>
 }
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -469,12 +502,28 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL2]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG2]], %[[ABS2]] : f32
 // CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
 // CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
@@ -716,13 +765,29 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
   return %abs : f32
 }
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: return %[[NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
+// CHECK: return %[[ABS3]] : f32
 
 // -----
 
@@ -807,12 +872,28 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg fastmath<nnan,contract> : complex<f32>
   return %log : complex<f32>
 }
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index 9983dd46f094334..c01f3698c73868c 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -6,12 +6,31 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg: complex<f32>
   return %abs : f32
 }
+// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
 // CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
 // CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
-// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]]  : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]]  : f32
-// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]]  : f32
-// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32
+
+// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]]  : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32
+// CHECK: %[[REAL_ABS:.*]] = llvm.intr.fabs(%[[REAL]]) : (f32) -> f32
+// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
+
+// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]]  : f32
+// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]])  : (f32) -> f32
+// CHECK: %[[IMAG_ABS:.*]] = llvm.intr.fabs(%[[IMAG]]) : (f32) -> f32
+// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG_ABS]]  : f32
+
+// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32
+// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL]], %[[ABS1]] : i1, f32
+// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG]], %[[ABS2]] : i1, f32
 // CHECK: llvm.return %[[NORM]] : f32
 
 // CHECK-LABEL: llvm.func @complex_eq
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index 349b92a7aefa2e3..efdf057f08df271 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -106,6 +106,27 @@ func.func @angle(%arg: complex<f32>) -> f32 {
   func.return %angle : f32
 }
 
+func.func @test_element_f64(%input: tensor<?xcomplex<f64>>,
+                      %func: (complex<f64>) -> f64) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %size = tensor.dim %input, %c0: tensor<?xcomplex<f64>>
+
+  scf.for %i = %c0 to %size step %c1 {
+    %elem = tensor.extract %input[%i]: tensor<?xcomplex<f64>>
+
+    %val = func.call_indirect %func(%elem) : (complex<f64>) -> f64
+    vector.print %val : f64
+    scf.yield
+  }
+  func.return
+}
+
+func.func @abs(%arg: complex<f64>) -> f64 {
+  %abs = complex.abs %arg : complex<f64>
+  func.return %abs : f64
+}
+
 func.func @entry() {
   // complex.sqrt test
   %sqrt_test = arith.constant dense<[
@@ -300,5 +321,32 @@ func.func @entry() {
   call @test_element(%angle_test_cast, %angle_func)
     : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
 
+  // complex.abs test
+  %abs_test = arith.constant dense<[
+    (1.0, 1.0),
+    // CHECK:  1.414
+    (1.0e300, 1.0e300),
+    // CHECK-NEXT:  1.41421e+300
+    (1.0e-300, 1.0e-300),
+    // CHECK-NEXT:  1.41421e-300
+    (5.0, 0.0),
+    // CHECK-NEXT:  5
+    (0.0, 6.0),
+    // CHECK-NEXT:  6
+    (7.0, 8.0),
+    // CHECK-NEXT:  10.6301
+    (-1.0, -1.0),
+    // CHECK-NEXT: 1.414
+    (-1.0e300, -1.0e300)
+    // CHECK-NEXT:  1.41421e+300
+  ]> : tensor<8xcomplex<f64>>
+  %abs_test_cast = tensor.cast %abs_test
+    :  tensor<8xcomplex<f64>> to tensor<?xcomplex<f64>>
+
+  %abs_func = func.constant @abs : (complex<f64>) -> f64
+
+  call @test_element_f64(%abs_test_cast, %abs_func)
+    : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
+
   func.return
 }

b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to get the absolute value in case of the negative value.

// CHECK-NEXT: 6
(7.0, 8.0),
// CHECK-NEXT: 10.6301
(-1.0, -1.0),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test case with the negative values.

Copy link
Member Author

@Lewuathe Lewuathe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joker-eph @matthias-springer I've fixed a bug in the previous PR that caused the integration test failure in the previous change. Could you review this change when you get a chance?

#76316

@Lewuathe Lewuathe merged commit 4effff2 into llvm:main Jan 31, 2024
4 checks passed
@Lewuathe Lewuathe deleted the avoid-rounding-error-complex-abs branch January 31, 2024 00:59
@Lewuathe
Copy link
Member Author

The build which previously failed is now finished successfully.
https://lab.llvm.org/buildbot/#/builders/264/builds/6256

d0k added a commit that referenced this pull request Jan 31, 2024
This reverts commit 4effff2. It makes
`complex.abs(-1)` return `-1`.
Lewuathe added a commit to Lewuathe/llvm-project that referenced this pull request Feb 8, 2024
The previous PR was not enough about the way to handle the negative value.
It is necessary to take the absolute value of the given real (or
imaginary) part to be multiplied with the sqrt part in the case of
either is zero.

See: llvm#76316
Lewuathe added a commit to Lewuathe/llvm-project that referenced this pull request Feb 9, 2024
The previous PR was not enough about the way to handle the negative value.
It is necessary to take the absolute value of the given real (or
imaginary) part to be multiplied with the sqrt part in the case of
either is zero.

See: llvm#76316
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:complex MLIR complex dialect mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants