Skip to content

[mlir][arith] Add rounding mode flags to binary arithmetic operations#188458

Merged
matthias-springer merged 8 commits into
mainfrom
users/matthias-springer/arith_rounding
Apr 17, 2026
Merged

[mlir][arith] Add rounding mode flags to binary arithmetic operations#188458
matthias-springer merged 8 commits into
mainfrom
users/matthias-springer/arith_rounding

Conversation

@matthias-springer
Copy link
Copy Markdown
Member

@matthias-springer matthias-springer commented Mar 25, 2026

Add rounding mode flags for addf, subf, mulf, divf. This addresses a TODO in the op description.

The folder now takes into account the specified rounding mode. If no rounding mode is specified, the folders/canonicalizations default to rmNearestTiesToEven. (This behavior has not changed.) This is documented in the top-level arith dialect documentation. The default arith rounding mode applies only to "internal" transformations such as foldings/canonicalizations. In case of an unspecified explicit rounding mode, the runtime behavior is up to the target backend.

Also add a lowering to LLVM intrinsics such as llvm.intr.experimental.constrained.fadd.

Assisted-by: claude-4.6-opus-high

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Mar 25, 2026

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-math
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add rounding mode flags for addf, subf, mulf, divf, remf. This addresses a TODO in the op description.

Also improve the folder to take into account the rounding mode. If no rounding mode is specified, the folder uses rmNearestTiesToEven. (This behavior has not changed. We may want to deactivate folding in the future when no rounding mode is provided and the result cannot be represented exactly.)

Also add a lowering to LLVM intrinsics such as llvm.intr.experimental.constrained.fadd.


Patch is 33.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/188458.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h (+5)
  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+76-14)
  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+40-15)
  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+4-4)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+4-5)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+6-6)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+28-8)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp (+9-9)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+61)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+75)
  • (modified) mlir/test/Dialect/Arith/ops.mlir (+17)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index fccfe4897114e..c0773c9b69b6b 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -157,6 +157,11 @@ class AttrConverterConstrainedFPToLLVM {
       convertedAttr.set(TargetOp::getRoundingModeAttrName(),
                         convertArithRoundingModeAttrToLLVM(arithAttr));
     }
+    // Constrained intrinsics do not support fastmath flags. Remove the
+    // arith fastmath attribute if present.
+    if constexpr (SourceOp::template hasTrait<
+                      arith::ArithFastMathInterface::Trait>())
+      convertedAttr.erase(srcOp.getFastMathAttrName());
     convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
                       getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
   }
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 45cb3cecef3d8..864c947c005fa 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -90,6 +90,37 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
                           attr-dict `:` type($result) }];
 }
 
+// Base class for floating point binary operations with an optional rounding
+// mode.
+class Arith_FloatBinaryOpWithRoundingMode<string mnemonic,
+                                          list<Trait> traits = []> :
+    Arith_BinaryOp<mnemonic,
+      !listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                   DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
+                  traits)>,
+    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+      DefaultValuedAttr<
+        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
+      OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+    Results<(outs FloatLike:$result)> {
+  let builders = [
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs,
+      CArg<"::mlir::arith::FastMathFlags",
+           "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+      build($_builder, $_state, lhs, rhs, fastmath,
+            ::mlir::arith::RoundingModeAttr{});
+    }]>,
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs,
+      "::mlir::arith::FastMathFlagsAttr":$fastmath), [{
+      build($_builder, $_state, lhs, rhs, fastmath,
+            ::mlir::arith::RoundingModeAttr{});
+    }]>,
+  ];
+  let assemblyFormat = [{ $lhs `,` $rhs ($roundingmode^)?
+                          (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($result) }];
+}
+
 // Checks that tensor input and outputs have identical shapes. This is stricker
 // than the verification done in `SameOperandsAndResultShape` that allows for
 // tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
@@ -957,7 +988,7 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
 // AddFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
+def Arith_AddFOp : Arith_FloatBinaryOpWithRoundingMode<"addf", [Commutative]> {
   let summary = "floating point addition operation";
   let description = [{
     The `addf` operation takes two operands and returns one result, each of
@@ -965,6 +996,9 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
     scalar type, a vector whose element type is a floating point type, or a
     floating point tensor.
 
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
     Example:
 
     ```mlir
@@ -976,10 +1010,10 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
 
     // Tensor addition.
     %x = arith.addf %y, %z : tensor<4x?xbf16>
-    ```
 
-    TODO: In the distant future, this will accept optional attributes for fast
-    math, contraction, rounding mode, and other controls.
+    // Scalar addition with rounding mode.
+    %a = arith.addf %b, %c to_nearest_even : f64
+    ```
   }];
   let hasFolder = 1;
 }
@@ -988,7 +1022,7 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
 // SubFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
+def Arith_SubFOp : Arith_FloatBinaryOpWithRoundingMode<"subf"> {
   let summary = "floating point subtraction operation";
   let description = [{
     The `subf` operation takes two operands and returns one result, each of
@@ -996,6 +1030,9 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
     scalar type, a vector whose element type is a floating point type, or a
     floating point tensor.
 
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
     Example:
 
     ```mlir
@@ -1007,10 +1044,10 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
 
     // Tensor subtraction.
     %x = arith.subf %y, %z : tensor<4x?xbf16>
-    ```
 
-    TODO: In the distant future, this will accept optional attributes for fast
-    math, contraction, rounding mode, and other controls.
+    // Scalar subtraction with rounding mode.
+    %a = arith.subf %b, %c downward : f64
+    ```
   }];
   let hasFolder = 1;
 }
@@ -1139,7 +1176,7 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
 // MulFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
+def Arith_MulFOp : Arith_FloatBinaryOpWithRoundingMode<"mulf", [Commutative]> {
   let summary = "floating point multiplication operation";
   let description = [{
     The `mulf` operation takes two operands and returns one result, each of
@@ -1147,6 +1184,9 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
     scalar type, a vector whose element type is a floating point type, or a
     floating point tensor.
 
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
     Example:
 
     ```mlir
@@ -1158,10 +1198,10 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
 
     // Tensor pointwise multiplication.
     %x = arith.mulf %y, %z : tensor<4x?xbf16>
-    ```
 
-    TODO: In the distant future, this will accept optional attributes for fast
-    math, contraction, rounding mode, and other controls.
+    // Scalar multiplication with rounding mode.
+    %a = arith.mulf %b, %c upward : f64
+    ```
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
@@ -1171,8 +1211,27 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
 // DivFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
+def Arith_DivFOp : Arith_FloatBinaryOpWithRoundingMode<"divf"> {
   let summary = "floating point division operation";
+  let description = [{
+    The `divf` operation takes two operands and returns one result, each of
+    these is required to be the same type. This type may be a floating point
+    scalar type, a vector whose element type is a floating point type, or a
+    floating point tensor.
+
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
+    Example:
+
+    ```mlir
+    // Scalar division.
+    %a = arith.divf %b, %c : f64
+
+    // Scalar division with rounding mode.
+    %a = arith.divf %b, %c toward_zero : f64
+    ```
+  }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
@@ -1181,11 +1240,14 @@ def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
 // RemFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
+def Arith_RemFOp : Arith_FloatBinaryOpWithRoundingMode<"remf"> {
   let summary = "floating point division remainder operation";
   let description = [{
     Returns the floating point division remainder.
     The remainder has the same sign as the dividend (lhs operand).
+
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
   }];
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index a0346ec6f4fb6..9aba2b42926ce 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -81,9 +81,13 @@ struct IdentityBitcastLowering final
 //===----------------------------------------------------------------------===//
 
 using AddFOpLowering =
-    VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::AddFOp, LLVM::ConstrainedFAddIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using AddIOpLowering =
     VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -91,9 +95,13 @@ using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
 using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
 using DivFOpLowering =
-    VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::DivFOp, LLVM::ConstrainedFDivIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using DivSIOpLowering =
     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
 using DivUIOpLowering =
@@ -139,9 +147,13 @@ using MinSIOpLowering =
 using MinUIOpLowering =
     VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
 using MulFOpLowering =
-    VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::MulFOp, LLVM::ConstrainedFMulIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using MulIOpLowering =
     VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -151,9 +163,13 @@ using NegFOpLowering =
                                /*FailOnUnsupportedFP=*/true>;
 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
 using RemFOpLowering =
-    VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedRemFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::RemFOp, LLVM::ConstrainedFRemIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using RemSIOpLowering =
     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
 using RemUIOpLowering =
@@ -170,9 +186,13 @@ using ShRUIOpLowering =
 using SIToFPOpLowering =
     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
 using SubFOpLowering =
-    VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::SubFOp, LLVM::ConstrainedFSubIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using SubIOpLowering =
     VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -690,6 +710,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
   // clang-format off
   patterns.add<
     AddFOpLowering,
+    ConstrainedAddFOpLowering,
     AddIOpLowering,
     AndIOpLowering,
     AddUIExtendedOpLowering,
@@ -698,6 +719,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     CmpFOpLowering,
     CmpIOpLowering,
     DivFOpLowering,
+    ConstrainedDivFOpLowering,
     DivSIOpLowering,
     DivUIOpLowering,
     ExtFOpLowering,
@@ -717,12 +739,14 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     MinSIOpLowering,
     MinUIOpLowering,
     MulFOpLowering,
+    ConstrainedMulFOpLowering,
     MulIOpLowering,
     MulSIExtendedOpLowering,
     MulUIExtendedOpLowering,
     NegFOpLowering,
     OrIOpLowering,
     RemFOpLowering,
+    ConstrainedRemFOpLowering,
     RemSIOpLowering,
     RemUIOpLowering,
     SelectOpLowering,
@@ -732,6 +756,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     ShRUIOpLowering,
     SIToFPOpLowering,
     SubFOpLowering,
+    ConstrainedSubFOpLowering,
     SubIOpLowering,
     TruncFOpLowering,
     ConstrainedTruncFOpLowering,
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9e46b7d78baca..b899220f2e9af 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -182,12 +182,12 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
 
     Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
     Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
-    Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
-                                                realRhs, fmf.getValue());
+    Value resultReal =
+        BinaryStandardOp::create(b, realLhs, realRhs, fmf.getValue());
     Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
     Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
-    Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
-                                                imagRhs, fmf.getValue());
+    Value resultImag =
+        BinaryStandardOp::create(b, imagLhs, imagRhs, fmf.getValue());
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
     return success();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 76346a766f1f7..11b3aabcbfeb4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -120,7 +120,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
     auto one =
         arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
-    return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
+    return arith::DivFOp::create(rewriter, loc, one, args[0]);
   }
 
   // tosa::MulOp
@@ -140,8 +140,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                                           "Cannot have shift value for float");
         return nullptr;
       }
-      return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
-                                   args[1]);
+      return arith::MulFOp::create(rewriter, loc, args[0], args[1]);
     }
 
     if (isa<IntegerType>(elementTy)) {
@@ -538,8 +537,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
     auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
     auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
-    auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
-    return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
+    auto added = arith::AddFOp::create(rewriter, loc, exp, one);
+    return arith::DivFOp::create(rewriter, loc, one, added);
   }
 
   // tosa::CastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index e22fc1d478e4f..488dac54569d5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -449,10 +449,10 @@ def UIToFPOfExtUI :
 //===----------------------------------------------------------------------===//
 
 // mulf(negf(x), negf(y)) -> mulf(x,y)
-// (retain fastmath flags of original mulf)
+// (retain fastmath flags and rounding mode of original mulf)
 def MulFOfNegF :
-    Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
-        (Arith_MulFOp $x, $y, $fmf),
+    Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
+        (Arith_MulFOp $x, $y, $fmf, $rm),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
 //===----------------------------------------------------------------------===//
@@ -460,10 +460,10 @@ def MulFOfNegF :
 //===----------------------------------------------------------------------===//
 
 // divf(negf(x), negf(y)) -> divf(x,y)
-// (retain fastmath flags of original divf)
+// (retain fastmath flags and rounding mode of original divf)
 def DivFOfNegF :
-    Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
-        (Arith_DivFOp $x, $y, $fmf),
+    Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
+        (Arith_DivFOp $x, $y, $fmf, $rm),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
 #endif // ARITH_PATTERNS
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f10a94522350..b00ae7bfc4724 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1107,9 +1107,14 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
     return getLhs();
 
+  auto rm = getRoundingmodeAttr();
   return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return a + b; });
+      adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+        APFloat result(a);
+        result.add(b, rm ? convertAri...
[truncated]

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Mar 25, 2026

🐧 Linux x64 Test Results

  • 7818 tests passed
  • 606 tests skipped

✅ The build succeeded and all tests passed.

Copy link
Copy Markdown
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good cleanup, thanks! Some comments inline.

Comment thread mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Comment thread mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated
Comment thread mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated
Comment thread mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
Comment thread mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp Outdated
Comment thread mlir/lib/Dialect/Arith/IR/ArithOps.cpp Outdated
floating point tensor.

If the value cannot be exactly represented, it is rounded using the
provided rounding mode or the default one if no rounding mode is provided.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would specify what is the default rounding mode in the doc.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is referring to the "default LLVM floating-point environment" now.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should just refer directly to the "default LLVM floating-point environment", since the arithmetic dialect isn't tied to LLVM right now (what if you're targeting a non-LLVM lowering?)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would you phrase it?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So my thought would be to go up into the arith documentation and describe the default floating point environment we use for internal purposes like constant-folding and note that this is the default used by LLVM, but observe that, as in many other languages (C, it looks like Vulkan) leaving the floating point rounding mode unspecified gives you whatever the platform default is.

So basically we'll be doing round-nearest-ties-to-even, no signalling NaNs, denormals enabled, etc. inside the arith dialect, but we're not going to promise that that's what'll happen post-export (because some targets don't make that promise either and it's fine, partly because everyone basically agrees anyway)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right: we need the arith dialect documentation to be self contained basically. We may refer to LLVM in some way, but we can't assume it in "the one environment" where arith operates right now.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the top-level documentation of the arith dialect based on @krzysz00's comment.

@krzysz00 @joker-eph @rengolin Can you double-check the wording? If this is not what you expected, please comment with the exact wording that you'd like to see here.

Copy link
Copy Markdown
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor notes on semantics (/let's add a decoration to modules allowing for fpenv manipulation) but overall it's good to see this

Comment thread mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated
Comment thread mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated
@matthias-springer matthias-springer force-pushed the users/matthias-springer/arith_rounding branch from 3527ba6 to 11a73a1 Compare April 7, 2026 13:18
@matthias-springer
Copy link
Copy Markdown
Member Author

I removed remf from this commit because the APFloat API does not have rounding mode support for that operation.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/arith_rounding branch from c00fb4a to 2b7ce18 Compare April 7, 2026 14:06
@matthias-springer matthias-springer force-pushed the users/matthias-springer/arith_rounding branch from 2b7ce18 to f7f1b3f Compare April 16, 2026 15:02
Copy link
Copy Markdown
Contributor

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable to me, thanks!

@matthias-springer matthias-springer merged commit b2317cc into main Apr 17, 2026
10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/arith_rounding branch April 17, 2026 08:13
alexfh pushed a commit to alexfh/llvm-project that referenced this pull request Apr 18, 2026
…llvm#188458)

Add rounding mode flags for `addf`, `subf`, `mulf`, `divf`. This
addresses a TODO in the op description.

The folder now takes into account the specified rounding mode. If no
rounding mode is specified, the folders/canonicalizations default to
`rmNearestTiesToEven`. (This behavior has not changed.) This is
documented in the top-level arith dialect documentation. The default
arith rounding mode applies only to "internal" transformations such as
foldings/canonicalizations. In case of an unspecified explicit rounding
mode, the runtime behavior is up to the target backend.

Also add a lowering to LLVM intrinsics such as
`llvm.intr.experimental.constrained.fadd`.

Assisted-by: claude-4.6-opus-high
matthias-springer added a commit that referenced this pull request Apr 25, 2026
Rounding modes have recently been added for `arith` FP operations
(#188458). This commit adds rounding modes to `math.fma`, following the
same design as for `arith` FP operations.

If a rounding mode is present, the LLVM lowering produces
`llvm.intr.experimental.constrained.fma`.

In the absence of a rounding mode, the rounding behavior is deferred to
the target backend.

Assisted-by: claude-opus-4.7-thinking-high
yingopq pushed a commit to yingopq/llvm-project that referenced this pull request Apr 29, 2026
Rounding modes have recently been added for `arith` FP operations
(llvm#188458). This commit adds rounding modes to `math.fma`, following the
same design as for `arith` FP operations.

If a rounding mode is present, the LLVM lowering produces
`llvm.intr.experimental.constrained.fma`.

In the absence of a rounding mode, the rounding behavior is deferred to
the target backend.

Assisted-by: claude-opus-4.7-thinking-high
KHicketts pushed a commit to KHicketts/llvm-project that referenced this pull request Apr 30, 2026
…llvm#188458)

Add rounding mode flags for `addf`, `subf`, `mulf`, `divf`. This
addresses a TODO in the op description.

The folder now takes into account the specified rounding mode. If no
rounding mode is specified, the folders/canonicalizations default to
`rmNearestTiesToEven`. (This behavior has not changed.) This is
documented in the top-level arith dialect documentation. The default
arith rounding mode applies only to "internal" transformations such as
foldings/canonicalizations. In case of an unspecified explicit rounding
mode, the runtime behavior is up to the target backend.

Also add a lowering to LLVM intrinsics such as
`llvm.intr.experimental.constrained.fadd`.

Assisted-by: claude-4.6-opus-high
KHicketts pushed a commit to KHicketts/llvm-project that referenced this pull request Apr 30, 2026
Rounding modes have recently been added for `arith` FP operations
(llvm#188458). This commit adds rounding modes to `math.fma`, following the
same design as for `arith` FP operations.

If a rounding mode is present, the LLVM lowering produces
`llvm.intr.experimental.constrained.fma`.

In the absence of a rounding mode, the rounding behavior is deferred to
the target backend.

Assisted-by: claude-opus-4.7-thinking-high
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants