diff --git a/llvm/include/llvm/IR/ConstantFPRange.h b/llvm/include/llvm/IR/ConstantFPRange.h index 39dc7c100d6f7..10467cc96f9c6 100644 --- a/llvm/include/llvm/IR/ConstantFPRange.h +++ b/llvm/include/llvm/IR/ConstantFPRange.h @@ -230,6 +230,10 @@ class [[nodiscard]] ConstantFPRange { /// Return a new range representing the possible values resulting /// from a subtraction of a value in this range and a value in \p Other. LLVM_ABI ConstantFPRange sub(const ConstantFPRange &Other) const; + + /// Flush denormal values to zero according to the specified mode. + /// For dynamic mode, we return the union of all possible results. + LLVM_ABI void flushDenormals(DenormalMode::DenormalModeKind Mode); }; inline raw_ostream &operator<<(raw_ostream &OS, const ConstantFPRange &CR) { diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp index 51d2e215c980c..e9c058ee8ec25 100644 --- a/llvm/lib/IR/ConstantFPRange.cpp +++ b/llvm/lib/IR/ConstantFPRange.cpp @@ -8,6 +8,7 @@ #include "llvm/IR/ConstantFPRange.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/FloatingPointMode.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include @@ -506,3 +507,24 @@ ConstantFPRange ConstantFPRange::sub(const ConstantFPRange &Other) const { // fsub X, Y = fadd X, (fneg Y) return add(Other.negate()); } + +void ConstantFPRange::flushDenormals(DenormalMode::DenormalModeKind Mode) { + if (Mode == DenormalMode::IEEE) + return; + FPClassTest Class = classify(); + if (!(Class & fcSubnormal)) + return; + + auto &Sem = getSemantics(); + // PreserveSign: PosSubnormal -> PosZero, NegSubnormal -> NegZero + // PositiveZero: PosSubnormal -> PosZero, NegSubnormal -> PosZero + // Dynamic: PosSubnormal -> PosZero, NegSubnormal -> NegZero/PosZero + bool ZeroLowerNegative = + Mode != DenormalMode::PositiveZero && (Class & fcNegSubnormal); + bool ZeroUpperNegative = + Mode == DenormalMode::PreserveSign && !(Class & fcPosSubnormal); + assert((ZeroLowerNegative || !ZeroUpperNegative) && + "ZeroLower is greater than ZeroUpper."); + Lower = minnum(Lower, APFloat::getZero(Sem, ZeroLowerNegative)); + Upper = maxnum(Upper, APFloat::getZero(Sem, ZeroUpperNegative)); +} diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp index cf9b31cf88480..2431db90a40bd 100644 --- a/llvm/unittests/IR/ConstantFPRangeTest.cpp +++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp @@ -8,6 +8,7 @@ #include "llvm/IR/ConstantFPRange.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/FloatingPointMode.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "gtest/gtest.h" @@ -1065,4 +1066,70 @@ TEST_F(ConstantFPRangeTest, sub) { #endif } +TEST_F(ConstantFPRangeTest, flushDenormals) { + const fltSemantics &FP8Sem = APFloat::Float8E4M3(); + APFloat NormalVal = APFloat::getSmallestNormalized(FP8Sem); + APFloat Subnormal1 = NormalVal; + Subnormal1.next(/*nextDown=*/true); + APFloat Subnormal2 = APFloat::getSmallest(FP8Sem); + APFloat ZeroVal = APFloat::getZero(FP8Sem); + APFloat EdgeValues[8] = {-NormalVal, -Subnormal1, -Subnormal2, -ZeroVal, + ZeroVal, Subnormal2, Subnormal1, NormalVal}; + constexpr DenormalMode::DenormalModeKind Modes[4] = { + DenormalMode::IEEE, DenormalMode::PreserveSign, + DenormalMode::PositiveZero, DenormalMode::Dynamic}; + for (uint32_t I = 0; I != 8; ++I) { + for (uint32_t J = I; J != 8; ++J) { + ConstantFPRange OriginCR = + ConstantFPRange::getNonNaN(EdgeValues[I], EdgeValues[J]); + for (auto Mode : Modes) { + StringRef ModeName = denormalModeKindName(Mode); + ConstantFPRange FlushedCR = OriginCR; + FlushedCR.flushDenormals(Mode); + + ConstantFPRange Expected = ConstantFPRange::getEmpty(FP8Sem); + auto CheckFlushedV = [&](const APFloat &V, const APFloat &FlushedV) { + EXPECT_TRUE(FlushedCR.contains(FlushedV)) + << "Wrong result for flushDenormal(" << V << ", " << ModeName + << "). The result " << FlushedCR << " should contain " + << FlushedV; + if (!Expected.contains(FlushedV)) + Expected = Expected.unionWith(ConstantFPRange(FlushedV)); + }; + EnumerateValuesInConstantFPRange( + OriginCR, + [&](const APFloat &V) { + if (V.isDenormal()) { + switch (Mode) { + case DenormalMode::IEEE: + break; + case DenormalMode::PreserveSign: + CheckFlushedV(V, APFloat::getZero(FP8Sem, V.isNegative())); + break; + case DenormalMode::PositiveZero: + CheckFlushedV(V, APFloat::getZero(FP8Sem)); + break; + case DenormalMode::Dynamic: + // PreserveSign + CheckFlushedV(V, APFloat::getZero(FP8Sem, V.isNegative())); + // PositiveZero + CheckFlushedV(V, APFloat::getZero(FP8Sem)); + break; + default: + llvm_unreachable("unknown denormal mode"); + } + } + // It is not mandated that flushing to zero occurs. + CheckFlushedV(V, V); + }, + /*IgnoreNaNPayload=*/true); + EXPECT_EQ(FlushedCR, Expected) + << "Suboptimal result for flushDenormal(" << OriginCR << ", " + << ModeName << "). Expected " << Expected << ", but got " + << FlushedCR; + } + } + } +} + } // anonymous namespace