diff --git a/llvm/include/llvm/IR/ConstantFPRange.h b/llvm/include/llvm/IR/ConstantFPRange.h index face5da80ddc0..d47f6c02c883d 100644 --- a/llvm/include/llvm/IR/ConstantFPRange.h +++ b/llvm/include/llvm/IR/ConstantFPRange.h @@ -216,6 +216,12 @@ class [[nodiscard]] ConstantFPRange { /// Get the range without infinities. It is useful when we apply ninf flag to /// range of operands/results. LLVM_ABI ConstantFPRange getWithoutInf() const; + + /// Return a new range in the specified format with the specified rounding + /// mode. + LLVM_ABI ConstantFPRange + cast(const fltSemantics &DstSem, + APFloat::roundingMode RM = APFloat::rmNearestTiesToEven) const; }; inline raw_ostream &operator<<(raw_ostream &OS, const ConstantFPRange &CR) { diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp index 2477e22aef085..070e833f4d1c0 100644 --- a/llvm/lib/IR/ConstantFPRange.cpp +++ b/llvm/lib/IR/ConstantFPRange.cpp @@ -326,6 +326,8 @@ std::optional ConstantFPRange::getSignBit() const { } bool ConstantFPRange::operator==(const ConstantFPRange &CR) const { + assert(&getSemantics() == &CR.getSemantics() && + "Should only use the same semantics"); if (MayBeSNaN != CR.MayBeSNaN || MayBeQNaN != CR.MayBeQNaN) return false; return Lower.bitwiseIsEqual(CR.Lower) && Upper.bitwiseIsEqual(CR.Upper); @@ -425,3 +427,20 @@ ConstantFPRange ConstantFPRange::getWithoutInf() const { return ConstantFPRange(std::move(NewLower), std::move(NewUpper), MayBeQNaN, MayBeSNaN); } + +ConstantFPRange ConstantFPRange::cast(const fltSemantics &DstSem, + APFloat::roundingMode RM) const { + bool LosesInfo; + APFloat NewLower = Lower; + APFloat NewUpper = Upper; + // For conservative, return full range if conversion is invalid. + if (NewLower.convert(DstSem, RM, &LosesInfo) == APFloat::opInvalidOp || + NewLower.isNaN()) + return getFull(DstSem); + if (NewUpper.convert(DstSem, RM, &LosesInfo) == APFloat::opInvalidOp || + NewUpper.isNaN()) + return getFull(DstSem); + return ConstantFPRange(std::move(NewLower), std::move(NewUpper), + /*MayBeQNaNVal=*/MayBeQNaN || MayBeSNaN, + /*MayBeSNaNVal=*/false); +} diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp index 5bc516d0dc56c..58a65b9a96ab8 100644 --- a/llvm/unittests/IR/ConstantFPRangeTest.cpp +++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/IR/ConstantFPRange.h" +#include "llvm/ADT/APFloat.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "gtest/gtest.h" @@ -818,4 +819,110 @@ TEST_F(ConstantFPRangeTest, getWithout) { APFloat::getLargest(Sem, /*Negative=*/true), APFloat(3.0))); } +TEST_F(ConstantFPRangeTest, cast) { + const fltSemantics &F16Sem = APFloat::IEEEhalf(); + const fltSemantics &BF16Sem = APFloat::BFloat(); + const fltSemantics &F32Sem = APFloat::IEEEsingle(); + const fltSemantics &F8NanOnlySem = APFloat::Float8E4M3FN(); + // normal -> normal (exact) + EXPECT_EQ(ConstantFPRange::getNonNaN(APFloat(1.0), APFloat(2.0)).cast(F32Sem), + ConstantFPRange::getNonNaN(APFloat(1.0f), APFloat(2.0f))); + EXPECT_EQ( + ConstantFPRange::getNonNaN(APFloat(-2.0f), APFloat(-1.0f)).cast(Sem), + ConstantFPRange::getNonNaN(APFloat(-2.0), APFloat(-1.0))); + // normal -> normal (inexact) + EXPECT_EQ( + ConstantFPRange::getNonNaN(APFloat(3.141592653589793), + APFloat(6.283185307179586)) + .cast(F32Sem), + ConstantFPRange::getNonNaN(APFloat(3.14159274f), APFloat(6.28318548f))); + // normal -> subnormal + EXPECT_EQ(ConstantFPRange::getNonNaN(APFloat(-5e-8), APFloat(5e-8)) + .cast(F16Sem) + .classify(), + fcSubnormal | fcZero); + // normal -> zero + EXPECT_EQ(ConstantFPRange::getNonNaN( + APFloat::getSmallestNormalized(Sem, /*Negative=*/true), + APFloat::getSmallestNormalized(Sem, /*Negative=*/false)) + .cast(F32Sem) + .classify(), + fcZero); + // normal -> inf + EXPECT_EQ(ConstantFPRange::getNonNaN(APFloat(-65536.0), APFloat(65536.0)) + .cast(F16Sem), + ConstantFPRange::getNonNaN(F16Sem)); + // nan -> qnan + EXPECT_EQ( + ConstantFPRange::getNaNOnly(Sem, /*MayBeQNaN=*/true, /*MayBeSNaN=*/false) + .cast(F32Sem), + ConstantFPRange::getNaNOnly(F32Sem, /*MayBeQNaN=*/true, + /*MayBeSNaN=*/false)); + EXPECT_EQ( + ConstantFPRange::getNaNOnly(Sem, /*MayBeQNaN=*/false, /*MayBeSNaN=*/true) + .cast(F32Sem), + ConstantFPRange::getNaNOnly(F32Sem, /*MayBeQNaN=*/true, + /*MayBeSNaN=*/false)); + EXPECT_EQ( + ConstantFPRange::getNaNOnly(Sem, /*MayBeQNaN=*/true, /*MayBeSNaN=*/true) + .cast(F32Sem), + ConstantFPRange::getNaNOnly(F32Sem, /*MayBeQNaN=*/true, + /*MayBeSNaN=*/false)); + // For BF16 -> F32, signaling bit is still lost. + EXPECT_EQ(ConstantFPRange::getNaNOnly(BF16Sem, /*MayBeQNaN=*/true, + /*MayBeSNaN=*/true) + .cast(F32Sem), + ConstantFPRange::getNaNOnly(F32Sem, /*MayBeQNaN=*/true, + /*MayBeSNaN=*/false)); + // inf -> nan only (return full set for now) + EXPECT_EQ(ConstantFPRange::getNonNaN(APFloat::getInf(Sem, /*Negative=*/true), + APFloat::getInf(Sem, /*Negative=*/false)) + .cast(F8NanOnlySem), + ConstantFPRange::getFull(F8NanOnlySem)); + // other rounding modes + EXPECT_EQ( + ConstantFPRange::getNonNaN(APFloat::getSmallest(Sem, /*Negative=*/true), + APFloat::getSmallest(Sem, /*Negative=*/false)) + .cast(F32Sem, APFloat::rmTowardNegative), + ConstantFPRange::getNonNaN( + APFloat::getSmallest(F32Sem, /*Negative=*/true), + APFloat::getZero(F32Sem, /*Negative=*/false))); + EXPECT_EQ( + ConstantFPRange::getNonNaN(APFloat::getSmallest(Sem, /*Negative=*/true), + APFloat::getSmallest(Sem, /*Negative=*/false)) + .cast(F32Sem, APFloat::rmTowardPositive), + ConstantFPRange::getNonNaN( + APFloat::getZero(F32Sem, /*Negative=*/true), + APFloat::getSmallest(F32Sem, /*Negative=*/false))); + EXPECT_EQ( + ConstantFPRange::getNonNaN( + APFloat::getSmallestNormalized(Sem, /*Negative=*/true), + APFloat::getSmallestNormalized(Sem, /*Negative=*/false)) + .cast(F32Sem, APFloat::rmTowardZero), + ConstantFPRange::getNonNaN(APFloat::getZero(F32Sem, /*Negative=*/true), + APFloat::getZero(F32Sem, /*Negative=*/false))); + + EnumerateValuesInConstantFPRange( + ConstantFPRange::getFull(APFloat::Float8E4M3()), + [&](const APFloat &V) { + bool LosesInfo = false; + + APFloat DoubleV = V; + DoubleV.convert(Sem, APFloat::rmNearestTiesToEven, &LosesInfo); + ConstantFPRange DoubleCR = ConstantFPRange(V).cast(Sem); + EXPECT_TRUE(DoubleCR.contains(DoubleV)) + << "Casting " << V << " to double failed. " << DoubleCR + << " doesn't contain " << DoubleV; + + auto &FP4Sem = APFloat::Float4E2M1FN(); + APFloat FP4V = V; + FP4V.convert(FP4Sem, APFloat::rmNearestTiesToEven, &LosesInfo); + ConstantFPRange FP4CR = ConstantFPRange(V).cast(FP4Sem); + EXPECT_TRUE(FP4CR.contains(FP4V)) + << "Casting " << V << " to FP4E2M1FN failed. " << FP4CR + << " doesn't contain " << FP4V; + }, + /*IgnoreNaNPayload=*/true); +} + } // anonymous namespace