[mlir][complex] Emit fma for contracted complex.mul lowering#196248
Merged
Conversation
|
@llvm/pr-subscribers-mlir Author: hanbeom (ParkHanbum) ChangesWhen complex.mul has fastmath<contract>, lower it using explicit fused multiply-add operations for the real and imaginary components. The lowering changes from: real = ar * br - ai * bi expressed as mul/sub/add, to: real = fma(ar, br, -(ai * bi)) This is only applied when contraction is allowed. Non-contracted complex.mul continues to lower to separate fmul/fsub/fadd operations. Fixed: #196246 Full diff: https://github.com/llvm/llvm-project/pull/196248.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index cdcb3cba55752..327a6678f9aed 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -279,13 +279,30 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
Value lhsRe = arg.lhs.real();
Value lhsIm = arg.lhs.imag();
- Value real = LLVM::FSubOp::create(
- rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
- LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
-
- Value imag = LLVM::FAddOp::create(
- rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
- LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
+ Value real;
+ Value imag;
+ if (arith::bitEnumContainsAll(complexFMFAttr.getValue(),
+ arith::FastMathFlags::contract)) {
+ Value lhsImagTimesRhsImag =
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf);
+ Value negLhsImagTimesRhsImag =
+ LLVM::FNegOp::create(rewriter, loc, lhsImagTimesRhsImag, fmf);
+ real = LLVM::FMAOp::create(rewriter, loc, lhsRe, rhsRe,
+ negLhsImagTimesRhsImag, fmf);
+
+ Value lhsImagTimesRhsReal =
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf);
+ imag = LLVM::FMAOp::create(rewriter, loc, lhsRe, rhsIm,
+ lhsImagTimesRhsReal, fmf);
+ } else {
+ real = LLVM::FSubOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
+
+ imag = LLVM::FAddOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
+ }
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9e46b7d78baca..4dcc9a2c23d77 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -549,18 +549,35 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
- Value lhsRealTimesRhsReal =
- arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
- Value lhsImagTimesRhsImag =
- arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
- Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
- lhsImagTimesRhsImag, fmfValue);
- Value lhsImagTimesRhsReal =
- arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
- Value lhsRealTimesRhsImag =
- arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
- Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
- lhsRealTimesRhsImag, fmfValue);
+ Value real;
+ Value imag;
+ if (arith::bitEnumContainsAll(fmfValue, arith::FastMathFlags::contract)) {
+ Value lhsImagTimesRhsImag =
+ arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+ Value negLhsImagTimesRhsImag =
+ arith::NegFOp::create(b, lhsImagTimesRhsImag, fmfValue);
+ real = math::FmaOp::create(b, lhsReal, rhsReal,
+ negLhsImagTimesRhsImag, fmfValue);
+
+ Value lhsImagTimesRhsReal =
+ arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
+ imag =
+ math::FmaOp::create(b, lhsReal, rhsImag, lhsImagTimesRhsReal,
+ fmfValue);
+ } else {
+ Value lhsRealTimesRhsReal =
+ arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
+ Value lhsImagTimesRhsImag =
+ arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+ real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
+ lhsImagTimesRhsImag, fmfValue);
+ Value lhsImagTimesRhsReal =
+ arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
+ Value lhsRealTimesRhsImag =
+ arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
+ imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
+ lhsRealTimesRhsImag, fmfValue);
+ }
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
return success();
}
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index 4d2c12a56eaca..ccbe075d8afc5 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -245,13 +245,12 @@ func.func @complex_div_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]]
// CHECK: %[[RESULT_0:.*]] = llvm.mlir.poison : ![[C_TY]]
-// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[REAL_TMP:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[NEG_REAL_TMP:.*]] = llvm.fneg %[[REAL_TMP]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[REAL:.*]] = llvm.intr.fma(%[[LHS_RE]], %[[RHS_RE]], %[[NEG_REAL_TMP]]) {fastmathFlags = #llvm.fastmath<contract, afn>} : (f32, f32, f32) -> f32
-// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[IMAG_TMP:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[IMAG:.*]] = llvm.intr.fma(%[[LHS_RE]], %[[RHS_IM]], %[[IMAG_TMP]]) {fastmathFlags = #llvm.fastmath<contract, afn>} : (f32, f32, f32) -> f32
// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 7a82236b0656e..5f8838f8433b7 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -942,12 +942,11 @@ func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
// CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
// CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
// CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_TMP:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_REAL_TMP:.*]] = arith.negf %[[REAL_TMP]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL:.*]] = math.fma %[[LHS_REAL]], %[[RHS_REAL]], %[[NEG_REAL_TMP]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_TMP:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG:.*]] = math.fma %[[LHS_REAL]], %[[RHS_IMAG]], %[[IMAG_TMP]] fastmath<nnan,contract> : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
@@ -964,23 +963,21 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[VAR2:.*]] = complex.im %arg1 : complex<f32>
// CHECK: %[[VAR4:.*]] = complex.re %arg1 : complex<f32>
// CHECK: %[[VAR6:.*]] = complex.im %arg1 : complex<f32>
-// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR0]], %[[VAR4]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR10:.*]] = arith.mulf %[[VAR2]], %[[VAR6]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR12:.*]] = arith.subf %[[VAR8]], %[[VAR10]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR10:.*]] = arith.negf %[[VAR10]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR12:.*]] = math.fma %[[VAR0]], %[[VAR4]], %[[NEG_VAR10]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR13:.*]] = arith.mulf %[[VAR2]], %[[VAR4]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR15:.*]] = arith.mulf %[[VAR0]], %[[VAR6]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR17:.*]] = arith.addf %[[VAR13]], %[[VAR15]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR17:.*]] = math.fma %[[VAR0]], %[[VAR6]], %[[VAR13]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR89:.*]] = complex.create %[[VAR12]], %[[VAR17]] : complex<f32>
// CHECK: %[[VAR90:.*]] = complex.re %arg0 : complex<f32>
// CHECK: %[[VAR92:.*]] = complex.im %arg0 : complex<f32>
// CHECK: %[[VAR94:.*]] = complex.re %arg0 : complex<f32>
// CHECK: %[[VAR96:.*]] = complex.im %arg0 : complex<f32>
-// CHECK: %[[VAR98:.*]] = arith.mulf %[[VAR90]], %[[VAR94]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR100:.*]] = arith.mulf %[[VAR92]], %[[VAR96]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR102:.*]] = arith.subf %[[VAR98]], %[[VAR100]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR100:.*]] = arith.negf %[[VAR100]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR102:.*]] = math.fma %[[VAR90]], %[[VAR94]], %[[NEG_VAR100]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR103:.*]] = arith.mulf %[[VAR92]], %[[VAR94]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR105:.*]] = arith.mulf %[[VAR90]], %[[VAR96]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR107:.*]] = arith.addf %[[VAR103]], %[[VAR105]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR107:.*]] = math.fma %[[VAR90]], %[[VAR96]], %[[VAR103]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR179:.*]] = complex.create %[[VAR102]], %[[VAR107]] : complex<f32>
// CHECK: %[[VAR180:.*]] = complex.re %[[VAR89]] : complex<f32>
// CHECK: %[[VAR181:.*]] = complex.re %[[VAR179]] : complex<f32>
@@ -1043,12 +1040,11 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[VAR232:.*]] = complex.im %[[VAR229]] : complex<f32>
// CHECK: %[[VAR234:.*]] = complex.re %arg0 : complex<f32>
// CHECK: %[[VAR236:.*]] = complex.im %arg0 : complex<f32>
-// CHECK: %[[VAR238:.*]] = arith.mulf %[[VAR230]], %[[VAR234]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR240:.*]] = arith.mulf %[[VAR232]], %[[VAR236]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR242:.*]] = arith.subf %[[VAR238]], %[[VAR240]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR240:.*]] = arith.negf %[[VAR240]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR242:.*]] = math.fma %[[VAR230]], %[[VAR234]], %[[NEG_VAR240]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR243:.*]] = arith.mulf %[[VAR232]], %[[VAR234]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR245:.*]] = arith.mulf %[[VAR230]], %[[VAR236]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR247:.*]] = arith.addf %[[VAR243]], %[[VAR245]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR247:.*]] = math.fma %[[VAR230]], %[[VAR236]], %[[VAR243]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR319:.*]] = complex.create %[[VAR242]], %[[VAR247]] : complex<f32>
// CHECK: %[[VAR320:.*]] = complex.re %arg1 : complex<f32>
// CHECK: %[[VAR321:.*]] = complex.re %[[VAR319]] : complex<f32>
@@ -1174,12 +1170,11 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[VAR444:.*]] = complex.im %[[VAR441]] : complex<f32>
// CHECK: %[[VAR446:.*]] = complex.re %[[VAR440]] : complex<f32>
// CHECK: %[[VAR448:.*]] = complex.im %[[VAR440]] : complex<f32>
-// CHECK: %[[VAR450:.*]] = arith.mulf %[[VAR442]], %[[VAR446]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR452:.*]] = arith.mulf %[[VAR444]], %[[VAR448]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR454:.*]] = arith.subf %[[VAR450]], %[[VAR452]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR452:.*]] = arith.negf %[[VAR452]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR454:.*]] = math.fma %[[VAR442]], %[[VAR446]], %[[NEG_VAR452]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR455:.*]] = arith.mulf %[[VAR444]], %[[VAR446]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR457:.*]] = arith.mulf %[[VAR442]], %[[VAR448]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR459:.*]] = arith.addf %[[VAR455]], %[[VAR457]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR459:.*]] = math.fma %[[VAR442]], %[[VAR448]], %[[VAR455]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR531:.*]] = complex.create %[[VAR454]], %[[VAR459]] : complex<f32>
// CHECK: return %[[VAR531]] : complex<f32>
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
akuegel
approved these changes
May 11, 2026
When complex.mul has fastmath<contract>, lower it using explicit fused multiply-add operations for the real and imaginary components. The lowering changes from: real = ar * br - ai * bi imag = ai * br + ar * bi expressed as mul/sub/add, to: real = fma(ar, br, -(ai * bi)) imag = fma(ar, bi, ai * br) This is only applied when contraction is allowed. Non-contracted complex.mul continues to lower to separate fmul/fsub/fadd operations. Fixed: llvm#196246
f6f42f2 to
180092b
Compare
180092b to
b8ff19e
Compare
pedroMVicente
pushed a commit
to pedroMVicente/llvm-project
that referenced
this pull request
May 19, 2026
…6248) When complex.mul has fastmath<contract>, lower it using explicit fused multiply-add operations for the real and imaginary components. The lowering changes from: real = ar * br - ai * bi imag = ai * br + ar * bi expressed as mul/sub/add, to: real = fma(ar, br, -(ai * bi)) imag = fma(ar, bi, ai * br) This is only applied when contraction is allowed. Non-contracted complex.mul continues to lower to separate fmul/fsub/fadd operations. Fixed: llvm#196246
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
When complex.mul has fastmath, lower it using explicit fused
multiply-add operations for the real and imaginary components.
The lowering changes from:
real = ar * br - ai * bi
imag = ai * br + ar * bi
expressed as mul/sub/add, to:
real = fma(ar, br, -(ai * bi))
imag = fma(ar, bi, ai * br)
This is only applied when contraction is allowed. Non-contracted complex.mul
continues to lower to separate fmul/fsub/fadd operations.
Fixed: #196246