Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][complex] Support Fastmath flag for complex.mulf #74554

Merged

Conversation

Lewuathe
Copy link
Member

@Lewuathe Lewuathe commented Dec 6, 2023

Support fast math flag in the conversion of complex.mulf op to standard dialect.

See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981

@llvmbot llvmbot added the mlir label Dec 6, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 6, 2023

@llvm/pr-subscribers-mlir

Author: Kai Sasaki (Lewuathe)

Changes

Support fast math flag in the conversion of complex.mulf op to standard dialect.

See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981


Full diff: https://github.com/llvm/llvm-project/pull/74554.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+38-28)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+120)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index bf753c7062f36..4c9dad9e2c173 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -569,29 +569,39 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto type = cast<ComplexType>(adaptor.getLhs().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+    auto fmfValue = fmf.getValue();
 
     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
-    Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal);
+    Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
-    Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag);
+    Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
-    Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal);
+    Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
-    Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag);
-
-    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
-    Value lhsRealTimesRhsRealAbs = b.create<math::AbsFOp>(lhsRealTimesRhsReal);
-    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
-    Value lhsImagTimesRhsImagAbs = b.create<math::AbsFOp>(lhsImagTimesRhsImag);
-    Value real =
-        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
-
-    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
-    Value lhsImagTimesRhsRealAbs = b.create<math::AbsFOp>(lhsImagTimesRhsReal);
-    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
-    Value lhsRealTimesRhsImagAbs = b.create<math::AbsFOp>(lhsRealTimesRhsImag);
-    Value imag =
-        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
+    Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
+
+    Value lhsRealTimesRhsReal =
+        b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
+    Value lhsRealTimesRhsRealAbs =
+        b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
+    Value lhsImagTimesRhsImag =
+        b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
+    Value lhsImagTimesRhsImagAbs =
+        b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
+    Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
+                                         lhsImagTimesRhsImag, fmfValue);
+
+    Value lhsImagTimesRhsReal =
+        b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
+    Value lhsImagTimesRhsRealAbs =
+        b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
+    Value lhsRealTimesRhsImag =
+        b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
+    Value lhsRealTimesRhsImagAbs =
+        b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
+    Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
+                                         lhsRealTimesRhsImag, fmfValue);
 
     // Handle cases where the "naive" calculation results in NaN values.
     Value realIsNan =
@@ -717,20 +727,20 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
     recalc = b.create<arith::AndIOp>(isNan, recalc);
 
     // Recalculate real part.
-    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
-    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
-    Value newReal =
-        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
+    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
+    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
+    Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
+                                            lhsImagTimesRhsImag, fmfValue);
     real = b.create<arith::SelectOp>(
-        recalc, b.create<arith::MulFOp>(inf, newReal), real);
+        recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
 
     // Recalculate imag part.
-    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
-    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
-    Value newImag =
-        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
+    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
+    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
+    Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
+                                            lhsRealTimesRhsImag, fmfValue);
     imag = b.create<arith::SelectOp>(
-        recalc, b.create<arith::MulFOp>(inf, newImag), imag);
+        recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
     return success();
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 3af28150fd5c3..8fa29ea43854a 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -845,3 +845,123 @@ func.func @complex_log1p_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_mul_with_fmf
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %mul = complex.mul %lhs, %rhs fastmath<nnan,contract> : complex<f32>
+  return %mul : complex<f32>
+}
+// CHECK: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[LHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[LHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RHS_REAL_ABS:.*]] = math.absf %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RHS_IMAG_ABS:.*]] = math.absf %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+
+// Handle cases where the "naive" calculation results in NaN values.
+// CHECK: %[[REAL_IS_NAN:.*]] = arith.cmpf uno, %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.andi %[[REAL_IS_NAN]], %[[IMAG_IS_NAN]] : i1
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+
+// Case 1. LHS_REAL or LHS_IMAG are infinite.
+// CHECK: %[[LHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IS_INF:.*]] = arith.ori %[[LHS_REAL_IS_INF]], %[[LHS_IMAG_IS_INF]] : i1
+// CHECK:  %[[RHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_REAL]], %[[RHS_REAL]] : f32
+// CHECK: %[[RHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_IMAG]], %[[RHS_IMAG]] : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[LHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[LHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = math.copysign %[[LHS_REAL_IS_INF_FLOAT]], %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_REAL1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[LHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = math.copysign %[[LHS_IMAG_IS_INF_FLOAT]], %[[LHS_IMAG]] : f32
+// CHECK: %[[LHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_IMAG]] : f32
+// CHECK: %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL]] : f32
+// CHECK: %[[RHS_REAL1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL]] : f32
+// CHECK: %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG]] : f32
+// CHECK: %[[RHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG]] : f32
+
+// Case 2. RHS_REAL or RHS_IMAG are infinite.
+// CHECK: %[[RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[RHS_IS_INF:.*]] = arith.ori %[[RHS_REAL_IS_INF]], %[[RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_REAL1]], %[[LHS_REAL1]] : f32
+// CHECK: %[[LHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_IMAG1]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[RHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[RHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = math.copysign %[[RHS_REAL_IS_INF_FLOAT]], %[[RHS_REAL1]] : f32
+// CHECK: %[[RHS_REAL2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_REAL1]] : f32
+// CHECK: %[[RHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[RHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = math.copysign %[[RHS_IMAG_IS_INF_FLOAT]], %[[RHS_IMAG1]] : f32
+// CHECK: %[[RHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_IMAG1]] : f32
+// CHECK: %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL1]] : f32
+// CHECK: %[[LHS_REAL2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL1]] : f32
+// CHECK: %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[LHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[RECALC:.*]] = arith.ori %[[LHS_IS_INF]], %[[RHS_IS_INF]] : i1
+
+// Case 3. One of the pairwise products of left hand side with right hand side
+// is infinite.
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE:.*]] = arith.ori %[[LHS_REAL_TIMES_RHS_REAL_IS_INF]], %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE1:.*]] = arith.ori %[[IS_SPECIAL_CASE]], %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE2:.*]] = arith.ori %[[IS_SPECIAL_CASE1]], %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF]] : i1
+// CHECK: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[NOT_RECALC:.*]] = arith.xori %[[RECALC]], %[[TRUE]] : i1
+// CHECK: %[[IS_SPECIAL_CASE3:.*]] = arith.andi %[[IS_SPECIAL_CASE2]], %[[NOT_RECALC]] : i1
+// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL2]] : f32
+// CHECK: %[[LHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG2]] : f32
+// CHECK: %[[LHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL2]] : f32
+// CHECK: %[[RHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG2]] : f32
+// CHECK: %[[RHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG2]] : f32
+// CHECK: %[[RECALC2:.*]] = arith.ori %[[RECALC]], %[[IS_SPECIAL_CASE3]] : i1
+// CHECK: %[[RECALC3:.*]] = arith.andi %[[IS_NAN]], %[[RECALC2]] : i1
+
+ // Recalculate real part.
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_REAL3]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_IMAG3]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEW_REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEW_REAL_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[FINAL_REAL:.*]] = arith.select %[[RECALC3]], %[[NEW_REAL_TIMES_INF]], %[[REAL]] : f32
+
+// Recalculate imag part.
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_REAL3]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_IMAG3]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEW_IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEW_IMAG_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[FINAL_IMAG:.*]] = arith.select %[[RECALC3]], %[[NEW_IMAG_TIMES_INF]], %[[IMAG]] : f32
+
+// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
\ No newline at end of file

@tpopp tpopp removed their request for review December 12, 2023 11:36
@Lewuathe
Copy link
Member Author

Lewuathe commented Jan 8, 2024

@joker-eph Sorry for bothering you. Can I ask you to review this change when you get a chance? Because you already have taken a look into several changes to support the Fastmath flag for complex dialect.

@Lewuathe
Copy link
Member Author

Lewuathe commented Jan 9, 2024

@joker-eph Thank you!

@Lewuathe Lewuathe merged commit eee71ed into llvm:main Jan 9, 2024
5 checks passed
@Lewuathe Lewuathe deleted the fast-math-flag-support-for-complex-mulf branch January 9, 2024 00:29
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants