Skip to content

[mlir][complex] Emit fma for contracted complex.mul lowering#196248

Merged
ParkHanbum merged 2 commits into
llvm:mainfrom
ParkHanbum:mlir_complex_mul_fmf
May 15, 2026
Merged

[mlir][complex] Emit fma for contracted complex.mul lowering#196248
ParkHanbum merged 2 commits into
llvm:mainfrom
ParkHanbum:mlir_complex_mul_fmf

Conversation

@ParkHanbum
Copy link
Copy Markdown
Contributor

@ParkHanbum ParkHanbum commented May 7, 2026

When complex.mul has fastmath, lower it using explicit fused
multiply-add operations for the real and imaginary components.

The lowering changes from:

real = ar * br - ai * bi
imag = ai * br + ar * bi

expressed as mul/sub/add, to:

real = fma(ar, br, -(ai * bi))
imag = fma(ar, bi, ai * br)

This is only applied when contraction is allowed. Non-contracted complex.mul
continues to lower to separate fmul/fsub/fadd operations.

Fixed: #196246

@llvmorg-github-actions
Copy link
Copy Markdown

@llvm/pr-subscribers-mlir

Author: hanbeom (ParkHanbum)

Changes

When complex.mul has fastmath<contract>, lower it using explicit fused multiply-add operations for the real and imaginary components.

The lowering changes from:

real = ar * br - ai * bi
imag = ai * br + ar * bi

expressed as mul/sub/add, to:

real = fma(ar, br, -(ai * bi))
imag = fma(ar, bi, ai * br)

This is only applied when contraction is allowed. Non-contracted complex.mul continues to lower to separate fmul/fsub/fadd operations.

Fixed: #196246


Full diff: https://github.com/llvm/llvm-project/pull/196248.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp (+24-7)
  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+29-12)
  • (modified) mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir (+5-6)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+17-22)
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index cdcb3cba55752..327a6678f9aed 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -279,13 +279,30 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
     Value lhsRe = arg.lhs.real();
     Value lhsIm = arg.lhs.imag();
 
-    Value real = LLVM::FSubOp::create(
-        rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
-        LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
-
-    Value imag = LLVM::FAddOp::create(
-        rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
-        LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
+    Value real;
+    Value imag;
+    if (arith::bitEnumContainsAll(complexFMFAttr.getValue(),
+                                  arith::FastMathFlags::contract)) {
+      Value lhsImagTimesRhsImag =
+          LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf);
+      Value negLhsImagTimesRhsImag =
+          LLVM::FNegOp::create(rewriter, loc, lhsImagTimesRhsImag, fmf);
+      real = LLVM::FMAOp::create(rewriter, loc, lhsRe, rhsRe,
+                                 negLhsImagTimesRhsImag, fmf);
+
+      Value lhsImagTimesRhsReal =
+          LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf);
+      imag = LLVM::FMAOp::create(rewriter, loc, lhsRe, rhsIm,
+                                 lhsImagTimesRhsReal, fmf);
+    } else {
+      real = LLVM::FSubOp::create(
+          rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
+          LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
+
+      imag = LLVM::FAddOp::create(
+          rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
+          LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
+    }
 
     result.setReal(rewriter, loc, real);
     result.setImaginary(rewriter, loc, imag);
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9e46b7d78baca..4dcc9a2c23d77 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -549,18 +549,35 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
     Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
     Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
     Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
-    Value lhsRealTimesRhsReal =
-        arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
-    Value lhsImagTimesRhsImag =
-        arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
-    Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
-                                       lhsImagTimesRhsImag, fmfValue);
-    Value lhsImagTimesRhsReal =
-        arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
-    Value lhsRealTimesRhsImag =
-        arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
-    Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
-                                       lhsRealTimesRhsImag, fmfValue);
+    Value real;
+    Value imag;
+    if (arith::bitEnumContainsAll(fmfValue, arith::FastMathFlags::contract)) {
+      Value lhsImagTimesRhsImag =
+          arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+      Value negLhsImagTimesRhsImag =
+          arith::NegFOp::create(b, lhsImagTimesRhsImag, fmfValue);
+      real = math::FmaOp::create(b, lhsReal, rhsReal,
+                                 negLhsImagTimesRhsImag, fmfValue);
+
+      Value lhsImagTimesRhsReal =
+          arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
+      imag =
+          math::FmaOp::create(b, lhsReal, rhsImag, lhsImagTimesRhsReal,
+                              fmfValue);
+    } else {
+      Value lhsRealTimesRhsReal =
+          arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
+      Value lhsImagTimesRhsImag =
+          arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+      real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
+                                   lhsImagTimesRhsImag, fmfValue);
+      Value lhsImagTimesRhsReal =
+          arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
+      Value lhsRealTimesRhsImag =
+          arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
+      imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
+                                   lhsRealTimesRhsImag, fmfValue);
+    }
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
     return success();
   }
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index 4d2c12a56eaca..ccbe075d8afc5 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -245,13 +245,12 @@ func.func @complex_div_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
 // CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]]
 // CHECK: %[[RESULT_0:.*]] = llvm.mlir.poison : ![[C_TY]]
 
-// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[REAL_TMP:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[NEG_REAL_TMP:.*]] = llvm.fneg %[[REAL_TMP]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[REAL:.*]] = llvm.intr.fma(%[[LHS_RE]], %[[RHS_RE]], %[[NEG_REAL_TMP]]) {fastmathFlags = #llvm.fastmath<contract, afn>} : (f32, f32, f32) -> f32
 
-// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[IMAG_TMP:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[IMAG:.*]] = llvm.intr.fma(%[[LHS_RE]], %[[RHS_IM]], %[[IMAG_TMP]]) {fastmathFlags = #llvm.fastmath<contract, afn>} : (f32, f32, f32) -> f32
 
 // CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
 // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 7a82236b0656e..5f8838f8433b7 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -942,12 +942,11 @@ func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
 // CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
 // CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
 // CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_TMP:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_REAL_TMP:.*]] = arith.negf %[[REAL_TMP]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL:.*]] = math.fma %[[LHS_REAL]], %[[RHS_REAL]], %[[NEG_REAL_TMP]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_TMP:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG:.*]] = math.fma %[[LHS_REAL]], %[[RHS_IMAG]], %[[IMAG_TMP]] fastmath<nnan,contract> : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
@@ -964,23 +963,21 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
 // CHECK: %[[VAR2:.*]] = complex.im %arg1 : complex<f32>
 // CHECK: %[[VAR4:.*]] = complex.re %arg1 : complex<f32>
 // CHECK: %[[VAR6:.*]] = complex.im %arg1 : complex<f32>
-// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR0]], %[[VAR4]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR10:.*]] = arith.mulf %[[VAR2]], %[[VAR6]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR12:.*]] = arith.subf %[[VAR8]], %[[VAR10]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR10:.*]] = arith.negf %[[VAR10]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR12:.*]] = math.fma %[[VAR0]], %[[VAR4]], %[[NEG_VAR10]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR13:.*]] = arith.mulf %[[VAR2]], %[[VAR4]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR15:.*]] = arith.mulf %[[VAR0]], %[[VAR6]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR17:.*]] = arith.addf %[[VAR13]], %[[VAR15]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR17:.*]] = math.fma %[[VAR0]], %[[VAR6]], %[[VAR13]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR89:.*]] = complex.create %[[VAR12]], %[[VAR17]] : complex<f32>
 // CHECK: %[[VAR90:.*]] = complex.re %arg0 : complex<f32>
 // CHECK: %[[VAR92:.*]] = complex.im %arg0 : complex<f32>
 // CHECK: %[[VAR94:.*]] = complex.re %arg0 : complex<f32>
 // CHECK: %[[VAR96:.*]] = complex.im %arg0 : complex<f32>
-// CHECK: %[[VAR98:.*]] = arith.mulf %[[VAR90]], %[[VAR94]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR100:.*]] = arith.mulf %[[VAR92]], %[[VAR96]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR102:.*]] = arith.subf %[[VAR98]], %[[VAR100]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR100:.*]] = arith.negf %[[VAR100]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR102:.*]] = math.fma %[[VAR90]], %[[VAR94]], %[[NEG_VAR100]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR103:.*]] = arith.mulf %[[VAR92]], %[[VAR94]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR105:.*]] = arith.mulf %[[VAR90]], %[[VAR96]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR107:.*]] = arith.addf %[[VAR103]], %[[VAR105]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR107:.*]] = math.fma %[[VAR90]], %[[VAR96]], %[[VAR103]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR179:.*]] = complex.create %[[VAR102]], %[[VAR107]] : complex<f32>
 // CHECK: %[[VAR180:.*]] = complex.re %[[VAR89]] : complex<f32>
 // CHECK: %[[VAR181:.*]] = complex.re %[[VAR179]] : complex<f32>
@@ -1043,12 +1040,11 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
 // CHECK: %[[VAR232:.*]] = complex.im %[[VAR229]] : complex<f32>
 // CHECK: %[[VAR234:.*]] = complex.re %arg0 : complex<f32>
 // CHECK: %[[VAR236:.*]] = complex.im %arg0 : complex<f32>
-// CHECK: %[[VAR238:.*]] = arith.mulf %[[VAR230]], %[[VAR234]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR240:.*]] = arith.mulf %[[VAR232]], %[[VAR236]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR242:.*]] = arith.subf %[[VAR238]], %[[VAR240]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR240:.*]] = arith.negf %[[VAR240]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR242:.*]] = math.fma %[[VAR230]], %[[VAR234]], %[[NEG_VAR240]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR243:.*]] = arith.mulf %[[VAR232]], %[[VAR234]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR245:.*]] = arith.mulf %[[VAR230]], %[[VAR236]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR247:.*]] = arith.addf %[[VAR243]], %[[VAR245]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR247:.*]] = math.fma %[[VAR230]], %[[VAR236]], %[[VAR243]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR319:.*]] = complex.create %[[VAR242]], %[[VAR247]] : complex<f32>
 // CHECK: %[[VAR320:.*]] = complex.re %arg1 : complex<f32>
 // CHECK: %[[VAR321:.*]] = complex.re %[[VAR319]] : complex<f32>
@@ -1174,12 +1170,11 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
 // CHECK: %[[VAR444:.*]] = complex.im %[[VAR441]] : complex<f32>
 // CHECK: %[[VAR446:.*]] = complex.re %[[VAR440]] : complex<f32>
 // CHECK: %[[VAR448:.*]] = complex.im %[[VAR440]] : complex<f32>
-// CHECK: %[[VAR450:.*]] = arith.mulf %[[VAR442]], %[[VAR446]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR452:.*]] = arith.mulf %[[VAR444]], %[[VAR448]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR454:.*]] = arith.subf %[[VAR450]], %[[VAR452]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR452:.*]] = arith.negf %[[VAR452]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR454:.*]] = math.fma %[[VAR442]], %[[VAR446]], %[[NEG_VAR452]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR455:.*]] = arith.mulf %[[VAR444]], %[[VAR446]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR457:.*]] = arith.mulf %[[VAR442]], %[[VAR448]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR459:.*]] = arith.addf %[[VAR455]], %[[VAR457]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR459:.*]] = math.fma %[[VAR442]], %[[VAR448]], %[[VAR455]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR531:.*]] = complex.create %[[VAR454]], %[[VAR459]] : complex<f32>
 // CHECK: return %[[VAR531]] : complex<f32>
 

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 7, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.

When complex.mul has fastmath<contract>, lower it using explicit fused
multiply-add operations for the real and imaginary components.

The lowering changes from:

  real = ar * br - ai * bi
  imag = ai * br + ar * bi

expressed as mul/sub/add, to:

  real = fma(ar, br, -(ai * bi))
  imag = fma(ar, bi,  ai * br)

This is only applied when contraction is allowed. Non-contracted complex.mul
continues to lower to separate fmul/fsub/fadd operations.

Fixed: llvm#196246
@ParkHanbum ParkHanbum force-pushed the mlir_complex_mul_fmf branch from f6f42f2 to 180092b Compare May 15, 2026 03:53
@ParkHanbum ParkHanbum force-pushed the mlir_complex_mul_fmf branch from 180092b to b8ff19e Compare May 15, 2026 05:18
@ParkHanbum ParkHanbum merged commit ad7e6ca into llvm:main May 15, 2026
10 checks passed
pedroMVicente pushed a commit to pedroMVicente/llvm-project that referenced this pull request May 19, 2026
…6248)

When complex.mul has fastmath<contract>, lower it using explicit fused
multiply-add operations for the real and imaginary components.

The lowering changes from:

  real = ar * br - ai * bi
  imag = ai * br + ar * bi

expressed as mul/sub/add, to:

  real = fma(ar, br, -(ai * bi))
  imag = fma(ar, bi,  ai * br)

This is only applied when contraction is allowed. Non-contracted
complex.mul
continues to lower to separate fmul/fsub/fadd operations.

Fixed: llvm#196246
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir][complex] Track backend impact of explicit FMA lowering for contracted complex.mul

2 participants