diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 6f368604df65a..d6a98a667da32 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1451,6 +1451,17 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { static FailureOr convertFloatValue( APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) { + // Reject special values that are not representable in the target type before + // calling APFloat::convert, which would llvm_unreachable on them. + using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior; + if (sourceValue.isInfinity() && + (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly || + targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)) + return failure(); + if (sourceValue.isNaN() && + targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly) + return failure(); + bool losesInfo = false; auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo); if (losesInfo || status != APFloat::opOK) diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 26b04e4209a43..643e4e076e7c6 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -3527,3 +3527,29 @@ func.func @cmpi_dynamic_shape_no_fold(%arg0: tensor) -> tensor { return %0 : tensor } +// ----- + +// arith.truncf of infinity to a FiniteOnly float type (f4E2M1FN) must not fold, +// since the type has no infinity representation. Previously this would crash +// inside APFloat::convert with llvm_unreachable("semantics don't support inf!"). + +// CHECK-LABEL: @truncf_inf_to_finite_only_no_fold +// CHECK: arith.truncf +func.func @truncf_inf_to_finite_only_no_fold() -> f4E2M1FN { + %inf = arith.constant 0x7F800000 : f32 + %result = arith.truncf %inf : f32 to f4E2M1FN + return %result : f4E2M1FN +} + +// ----- + +// arith.truncf of negative infinity to a FiniteOnly float type must not fold. + +// CHECK-LABEL: @truncf_neg_inf_to_finite_only_no_fold +// CHECK: arith.truncf +func.func @truncf_neg_inf_to_finite_only_no_fold() -> f4E2M1FN { + %neg_inf = arith.constant 0xFF800000 : f32 + %result = arith.truncf %neg_inf : f32 to f4E2M1FN + return %result : f4E2M1FN +} +