Skip to content

[mlir][math] Use APFloat::SemanticsToEnum in constant folding#193914

Merged
CoTinker merged 1 commit into
llvm:mainfrom
CoTinker:semantic
Apr 24, 2026
Merged

[mlir][math] Use APFloat::SemanticsToEnum in constant folding#193914
CoTinker merged 1 commit into
llvm:mainfrom
CoTinker:semantic

Conversation

@CoTinker
Copy link
Copy Markdown
Contributor

Refactor constant folding in the Math dialect to use APFloat::SemanticsToEnum() instead of getSizeInBits() when checking floating-point semantics. Inferring semantics from bitwidth is fragile: different formats may share the same bit width but have distinct semantics, leading to incorrect dispatch. SemanticsToEnum() matches on the exact semantics descriptor, making the intent explicit and ensuring correct dispatch.

Refactor constant folding in the Math dialect to use APFloat::SemanticsToEnum() instead of getSizeInBits() when checking floating-point semantics. Inferring semantics from bitwidth     is fragile: different formats may share the same bit width but have distinct semantics, leading to incorrect dispatch. SemanticsToEnum() matches on the exact semantics descriptor, making the intent explicit and ensuring correct dispatch.
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 24, 2026

@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

Refactor constant folding in the Math dialect to use APFloat::SemanticsToEnum() instead of getSizeInBits() when checking floating-point semantics. Inferring semantics from bitwidth is fragile: different formats may share the same bit width but have distinct semantics, leading to incorrect dispatch. SemanticsToEnum() matches on the exact semantics descriptor, making the intent explicit and ensuring correct dispatch.


Full diff: https://github.com/llvm/llvm-project/pull/193914.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+90-90)
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bec95f58260be 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -62,10 +62,10 @@ OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(acos(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(acosf(a.convertToFloat()));
         default:
           return {};
@@ -80,10 +80,10 @@ OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(acosh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(acoshf(a.convertToFloat()));
         default:
           return {};
@@ -98,10 +98,10 @@ OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(asin(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(asinf(a.convertToFloat()));
         default:
           return {};
@@ -116,10 +116,10 @@ OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(asinh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(asinhf(a.convertToFloat()));
         default:
           return {};
@@ -134,10 +134,10 @@ OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(atan(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(atanf(a.convertToFloat()));
         default:
           return {};
@@ -152,10 +152,10 @@ OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(atanh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(atanhf(a.convertToFloat()));
         default:
           return {};
@@ -174,15 +174,14 @@ OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
         if (a.isZero() && b.isZero())
           return llvm::APFloat::getNaN(a.getSemantics());
 
-        if (a.getSizeInBits(a.getSemantics()) == 64 &&
-            b.getSizeInBits(b.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32 &&
-            b.getSizeInBits(b.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -219,10 +218,10 @@ OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(cos(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(cosf(a.convertToFloat()));
         default:
           return {};
@@ -237,10 +236,10 @@ OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(cosh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(coshf(a.convertToFloat()));
         default:
           return {};
@@ -255,10 +254,10 @@ OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(sin(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(sinf(a.convertToFloat()));
         default:
           return {};
@@ -273,10 +272,10 @@ OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(sinh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(sinhf(a.convertToFloat()));
         default:
           return {};
@@ -331,10 +330,10 @@ OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(erf(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(erff(a.convertToFloat()));
         default:
           return {};
@@ -424,13 +423,14 @@ OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        if (a.getSizeInBits(a.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(log(a.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(logf(a.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -444,13 +444,14 @@ OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        if (a.getSizeInBits(a.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(log2(a.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(log2f(a.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -464,10 +465,10 @@ OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(log10(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(log10f(a.convertToFloat()));
         default:
           return {};
@@ -482,12 +483,12 @@ OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
 OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           if ((a + APFloat(1.0)).isNegative())
             return {};
           return APFloat(log1p(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           if ((a + APFloat(1.0f)).isNegative())
             return {};
           return APFloat(log1pf(a.convertToFloat()));
@@ -505,15 +506,14 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
   return constFoldBinaryOpConditional<FloatAttr>(
       adaptor.getOperands(),
       [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
-        if (a.getSizeInBits(a.getSemantics()) == 64 &&
-            b.getSizeInBits(b.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32 &&
-            b.getSizeInBits(b.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -528,10 +528,10 @@ OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
           return {};
 
         APFloat one(a.getSemantics(), 1);
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return one / APFloat(sqrt(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return one / APFloat(sqrtf(a.convertToFloat()));
         default:
           return {};
@@ -549,10 +549,10 @@ OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(sqrt(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(sqrtf(a.convertToFloat()));
         default:
           return {};
@@ -567,10 +567,10 @@ OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(exp(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(expf(a.convertToFloat()));
         default:
           return {};
@@ -585,10 +585,10 @@ OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(exp2(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(exp2f(a.convertToFloat()));
         default:
           return {};
@@ -603,10 +603,10 @@ OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
 OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(expm1(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(expm1f(a.convertToFloat()));
         default:
           return {};
@@ -685,10 +685,10 @@ OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(tan(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(tanf(a.convertToFloat()));
         default:
           return {};
@@ -703,10 +703,10 @@ OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(tanh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(tanhf(a.convertToFloat()));
         default:
           return {};
@@ -747,10 +747,10 @@ OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(round(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(roundf(a.convertToFloat()));
         default:
           return {};
@@ -765,10 +765,10 @@ OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(trunc(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(truncf(a.convertToFloat()));
         default:
           return {};

@CoTinker CoTinker merged commit 170f030 into llvm:main Apr 24, 2026
13 checks passed
@CoTinker CoTinker deleted the semantic branch April 24, 2026 08:56
yingopq pushed a commit to yingopq/llvm-project that referenced this pull request Apr 29, 2026
…93914)

Refactor constant folding in the Math dialect to use APFloat::SemanticsToEnum() instead of getSizeInBits() when checking
floating-point semantics. Inferring semantics from bitwidth is fragile: different formats may share the same bit width but have distinct semantics, leading to incorrect dispatch. SemanticsToEnum() matches on the exact semantics descriptor, making the intent explicit and ensuring correct dispatch.
KHicketts pushed a commit to KHicketts/llvm-project that referenced this pull request Apr 30, 2026
…93914)

Refactor constant folding in the Math dialect to use APFloat::SemanticsToEnum() instead of getSizeInBits() when checking
floating-point semantics. Inferring semantics from bitwidth is fragile: different formats may share the same bit width but have distinct semantics, leading to incorrect dispatch. SemanticsToEnum() matches on the exact semantics descriptor, making the intent explicit and ensuring correct dispatch.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants