Skip to content

[MLIR][Arith] FastMath extf conversion without NaN checks#180926

Merged
rengolin merged 1 commit into
llvm:mainfrom
rengolin:nnan-extf
Feb 11, 2026
Merged

[MLIR][Arith] FastMath extf conversion without NaN checks#180926
rengolin merged 1 commit into
llvm:mainfrom
rengolin:nnan-extf

Conversation

@rengolin
Copy link
Copy Markdown
Member

This PR allows the expand op converter to consider the NoNaN fastmath attribute to disable the runtime checks for NaNs in E8M0 types. Default behaviour is still the same.

The OCP document provides all-ones as NaN for E8M0, but for pre-MX I8 quantization, the checks for NaNs are prohibitively expensive, especially if the hardware doesn't have native support for that type.

This PR allows the expand op converter to consider the NoNaN fastmath attribute to disable the runtime checks for NaNs in E8M0 types. Default behaviour is still the same.

The OCP document provides all-ones as NaN for E8M0, but for pre-MX I8 quantization, the checks for NaNs are prohibitively expensive if the hardware doesn't have native support for that type.
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 11, 2026

@llvm/pr-subscribers-mlir

Author: Renato Golin (rengolin)

Changes

This PR allows the expand op converter to consider the NoNaN fastmath attribute to disable the runtime checks for NaNs in E8M0 types. Default behaviour is still the same.

The OCP document provides all-ones as NaN for E8M0, but for pre-MX I8 quantization, the checks for NaNs are prohibitively expensive, especially if the hardware doesn't have native support for that type.


Full diff: https://github.com/llvm/llvm-project/pull/180926.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+15-8)
  • (modified) mlir/test/Dialect/Arith/expand-ops.mlir (+20-5)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index c4e81e5dbed21..46f8c1037d47b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -452,18 +452,25 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
     Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
-    // create constants for NaNs
-    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
-    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
     Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
     Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
     Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
 
-    Value isNan =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
-    // select for NaNs
-    f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
+    // If FastMathFlag allows no NaN checks, skip it
+    auto fastMath = op.getFastmathAttr();
+    bool NoNaN = fastMath
+                     ? (fastMath.getValue() & arith::FastMathFlags::nnan) ==
+                           arith::FastMathFlags::nnan
+                     : false;
+    if (!NoNaN) {
+      Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+      Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+      Value isNan =
+          arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
+      // select for NaNs
+      f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
+    }
+
     Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
     if (resultETy.getIntOrFloatBitWidth() < 32) {
       result = arith::TruncFOp::create(b, resultTy, result, nullptr,
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 61e22af31f030..75c4de2168761 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -383,11 +383,11 @@ func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
 
 // CHECK-LABEL: @extf_f8E8M0FNU_to_f32
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
-// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
-// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
-// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
 // CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
 // CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
 // CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
 // CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
 // CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
@@ -395,6 +395,21 @@ func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
 
 // -----
 
+func.func @extf_f8E8M0FNU_to_f32_no_nan(%arg0 : f8E8M0FNU) -> f32 {
+    %0 = arith.extf %arg0 fastmath<nnan> : f8E8M0FNU to f32
+    return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f32_no_nan
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SHLI]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
     %0 = arith.extf %arg0 : f8E8M0FNU to f16
     return %0 : f16
@@ -402,11 +417,11 @@ func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
 
 // CHECK-LABEL: @extf_f8E8M0FNU_to_f16
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
-// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
-// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
 // CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
 // CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
 // CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
 // CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
 // CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
 // CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 11, 2026

@llvm/pr-subscribers-mlir-arith

Author: Renato Golin (rengolin)

Changes

This PR allows the expand op converter to consider the NoNaN fastmath attribute to disable the runtime checks for NaNs in E8M0 types. Default behaviour is still the same.

The OCP document provides all-ones as NaN for E8M0, but for pre-MX I8 quantization, the checks for NaNs are prohibitively expensive, especially if the hardware doesn't have native support for that type.


Full diff: https://github.com/llvm/llvm-project/pull/180926.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+15-8)
  • (modified) mlir/test/Dialect/Arith/expand-ops.mlir (+20-5)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index c4e81e5dbed21..46f8c1037d47b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -452,18 +452,25 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
     Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
-    // create constants for NaNs
-    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
-    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
     Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
     Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
     Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
 
-    Value isNan =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
-    // select for NaNs
-    f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
+    // If FastMathFlag allows no NaN checks, skip it
+    auto fastMath = op.getFastmathAttr();
+    bool NoNaN = fastMath
+                     ? (fastMath.getValue() & arith::FastMathFlags::nnan) ==
+                           arith::FastMathFlags::nnan
+                     : false;
+    if (!NoNaN) {
+      Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+      Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+      Value isNan =
+          arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
+      // select for NaNs
+      f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
+    }
+
     Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
     if (resultETy.getIntOrFloatBitWidth() < 32) {
       result = arith::TruncFOp::create(b, resultTy, result, nullptr,
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 61e22af31f030..75c4de2168761 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -383,11 +383,11 @@ func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
 
 // CHECK-LABEL: @extf_f8E8M0FNU_to_f32
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
-// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
-// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
-// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
 // CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
 // CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
 // CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
 // CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
 // CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
@@ -395,6 +395,21 @@ func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
 
 // -----
 
+func.func @extf_f8E8M0FNU_to_f32_no_nan(%arg0 : f8E8M0FNU) -> f32 {
+    %0 = arith.extf %arg0 fastmath<nnan> : f8E8M0FNU to f32
+    return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f32_no_nan
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SHLI]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
     %0 = arith.extf %arg0 : f8E8M0FNU to f16
     return %0 : f16
@@ -402,11 +417,11 @@ func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
 
 // CHECK-LABEL: @extf_f8E8M0FNU_to_f16
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
-// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
-// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
 // CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
 // CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
 // CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
 // CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
 // CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
 // CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32

@rengolin rengolin merged commit 81e0de2 into llvm:main Feb 11, 2026
13 checks passed
@rengolin rengolin deleted the nnan-extf branch February 11, 2026 12:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants