[APFloat] Add APFloat semantic support for TF32
This diff adds APFloat support for a semantic that matches the TF32 data type
used by some accelerators (most notably GPUs from both NVIDIA and AMD).

For more information on the TF32 data type, see https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/.
Some intrinsics that support the TF32 data type were added in https://reviews.llvm.org/D122044.

For some discussion on supporting common semantics in `APFloat`, see similar
efforts for 8-bit formats at https://reviews.llvm.org/D146441, as well as
https://discourse.llvm.org/t/rfc-adding-the-amd-graphcore-maybe-others-float8-formats-to-apfloat/67969.

A subsequent diff will extend MLIR to use this data type. (Those changes are
not part of this diff to simplify the review process.)

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D151923
jfurtek authored and joker-eph committed Jun 23, 2023
1 parent 2764322 commit 55c2211
Showing 4 changed files with 142 additions and 1 deletion.
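As a quick illustration of the API surface this commit introduces (a minimal sketch, not part of the diff; it relies only on the FloatTF32() semantics added below and the existing APFloat interface):

#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  // Construct a TF32 value from a decimal string using the new semantics.
  APFloat Val(APFloat::FloatTF32(), "1.5");

  // TF32 occupies 19 bits of storage: 1 sign + 8 exponent + 10 mantissa bits.
  unsigned Bits = APFloat::getSizeInBits(APFloat::FloatTF32());

  // Every TF32 value is exactly representable as a double.
  outs() << "size in bits: " << Bits << ", value: " << Val.convertToDouble() << "\n";
  return 0;
}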
1 change: 1 addition & 0 deletions clang/lib/AST/MicrosoftMangle.cpp
@@ -898,6 +898,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
case APFloat::S_Float8E5M2FNUZ:
case APFloat::S_Float8E4M3FNUZ:
case APFloat::S_Float8E4M3B11FNUZ:
case APFloat::S_FloatTF32:
llvm_unreachable("Tried to mangle unexpected APFloat semantics");
}

7 changes: 7 additions & 0 deletions llvm/include/llvm/ADT/APFloat.h
@@ -184,6 +184,10 @@ struct APFloatBase {
// This format's exponent bias is 11, instead of the 7 (2 ** (4 - 1) - 1)
// that IEEE precedent would imply.
S_Float8E4M3B11FNUZ,
// Floating point number that occupies 32 bits or less of storage, providing
// improved range compared to half (16-bit) formats, at (potentially)
// greater throughput than single precision (32-bit) formats.
S_FloatTF32,

S_x87DoubleExtended,
S_MaxSemantics = S_x87DoubleExtended,
@@ -203,6 +207,7 @@ struct APFloatBase {
static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
static const fltSemantics &FloatTF32() LLVM_READNONE;
static const fltSemantics &x87DoubleExtended() LLVM_READNONE;

/// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -605,6 +610,7 @@ class IEEEFloat final : public APFloatBase {
APInt convertFloat8E4M3FNAPFloatToAPInt() const;
APInt convertFloat8E4M3FNUZAPFloatToAPInt() const;
APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
APInt convertFloatTF32APFloatToAPInt() const;
void initFromAPInt(const fltSemantics *Sem, const APInt &api);
template <const fltSemantics &S> void initFromIEEEAPInt(const APInt &api);
void initFromHalfAPInt(const APInt &api);
@@ -619,6 +625,7 @@ class IEEEFloat final : public APFloatBase {
void initFromFloat8E4M3FNAPInt(const APInt &api);
void initFromFloat8E4M3FNUZAPInt(const APInt &api);
void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
void initFromFloatTF32APInt(const APInt &api);

void assign(const IEEEFloat &);
void copySignificand(const IEEEFloat &);
20 changes: 20 additions & 0 deletions llvm/lib/Support/APFloat.cpp
@@ -138,6 +138,7 @@ static constexpr fltSemantics semFloat8E4M3FNUZ = {
7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19};
static constexpr fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
static constexpr fltSemantics semBogus = {0, 0, 0, 0};

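The semFloatTF32 initializer above follows the field order used by the other fltSemantics constants in this file: maximum exponent (127), minimum exponent (-126), precision in bits including the implicit integer bit (11, i.e. 10 explicit mantissa bits), and total storage size in bits (19). A small sketch (illustrative only, using the public APFloatBase accessors) that reads these properties back:

#include "llvm/ADT/APFloat.h"
#include <cassert>

using namespace llvm;

void checkTF32Semantics() {
  const fltSemantics &TF32 = APFloat::FloatTF32();
  assert(APFloat::semanticsPrecision(TF32) == 11);    // 10 mantissa bits + implicit bit
  assert(APFloat::semanticsMaxExponent(TF32) == 127); // same exponent range as IEEE single
  assert(APFloat::semanticsMinExponent(TF32) == -126);
  assert(APFloat::semanticsSizeInBits(TF32) == 19);   // 1 sign + 8 exponent + 10 mantissa
}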
@@ -203,6 +204,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
return Float8E4M3FNUZ();
case S_Float8E4M3B11FNUZ:
return Float8E4M3B11FNUZ();
case S_FloatTF32:
return FloatTF32();
case S_x87DoubleExtended:
return x87DoubleExtended();
}
@@ -233,6 +236,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
return S_Float8E4M3FNUZ;
else if (&Sem == &llvm::APFloat::Float8E4M3B11FNUZ())
return S_Float8E4M3B11FNUZ;
else if (&Sem == &llvm::APFloat::FloatTF32())
return S_FloatTF32;
else if (&Sem == &llvm::APFloat::x87DoubleExtended())
return S_x87DoubleExtended;
else
@@ -254,6 +259,7 @@ const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; }
const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
return semFloat8E4M3B11FNUZ;
}
const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
const fltSemantics &APFloatBase::x87DoubleExtended() {
return semX87DoubleExtended;
}
@@ -3599,6 +3605,11 @@ APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const {
return convertIEEEFloatToAPInt<semFloat8E4M3B11FNUZ>();
}

APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const {
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloatTF32>();
}

// This function creates an APInt that is just a bit map of the floating
// point constant as it would appear in memory. It is not a conversion,
// and treating the result as a normal integer is unlikely to be useful.
@@ -3637,6 +3648,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ)
return convertFloat8E4M3B11FNUZAPFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloatTF32)
return convertFloatTF32APFloatToAPInt();

assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
"unknown format!");
return convertF80LongDoubleAPFloatToAPInt();
@@ -3840,6 +3854,10 @@ void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) {
initFromIEEEAPInt<semFloat8E4M3B11FNUZ>(api);
}

void IEEEFloat::initFromFloatTF32APInt(const APInt &api) {
initFromIEEEAPInt<semFloatTF32>(api);
}

/// Treat api as containing the bits of a floating point number.
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
assert(api.getBitWidth() == Sem->sizeInBits);
@@ -3867,6 +3885,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E4M3FNUZAPInt(api);
if (Sem == &semFloat8E4M3B11FNUZ)
return initFromFloat8E4M3B11FNUZAPInt(api);
if (Sem == &semFloatTF32)
return initFromFloatTF32APInt(api);

llvm_unreachable(nullptr);
}
115 changes: 114 additions & 1 deletion llvm/unittests/ADT/APFloatTest.cpp
@@ -682,6 +682,26 @@ TEST(APFloatTest, Denormal) {
EXPECT_TRUE(T.isDenormal());
EXPECT_EQ(fcPosSubnormal, T.classify());
}

// Test TF32
{
const char *MinNormalStr = "1.17549435082228750797e-38";
EXPECT_FALSE(APFloat(APFloat::FloatTF32(), MinNormalStr).isDenormal());
EXPECT_FALSE(APFloat(APFloat::FloatTF32(), 0).isDenormal());

APFloat Val2(APFloat::FloatTF32(), 2);
APFloat T(APFloat::FloatTF32(), MinNormalStr);
T.divide(Val2, rdmd);
EXPECT_TRUE(T.isDenormal());
EXPECT_EQ(fcPosSubnormal, T.classify());

const char *NegMinNormalStr = "-1.17549435082228750797e-38";
EXPECT_FALSE(APFloat(APFloat::FloatTF32(), NegMinNormalStr).isDenormal());
APFloat NegT(APFloat::FloatTF32(), NegMinNormalStr);
NegT.divide(Val2, rdmd);
EXPECT_TRUE(NegT.isDenormal());
EXPECT_EQ(fcNegSubnormal, NegT.classify());
}
}

TEST(APFloatTest, IsSmallestNormalized) {
@@ -1350,6 +1370,16 @@ TEST(APFloatTest, makeNaN) {
{ 0x80ULL, APFloat::Float8E4M3B11FNUZ(), false, true, 0xaaULL },
{ 0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, false, 0xaaULL },
{ 0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, true, 0xaaULL },
{ 0x3fe00ULL, APFloat::FloatTF32(), false, false, 0x00000000ULL },
{ 0x7fe00ULL, APFloat::FloatTF32(), false, true, 0x00000000ULL },
{ 0x3feaaULL, APFloat::FloatTF32(), false, false, 0xaaULL },
{ 0x3ffaaULL, APFloat::FloatTF32(), false, false, 0xdaaULL },
{ 0x3ffaaULL, APFloat::FloatTF32(), false, false, 0xfdaaULL },
{ 0x3fd00ULL, APFloat::FloatTF32(), true, false, 0x00000000ULL },
{ 0x7fd00ULL, APFloat::FloatTF32(), true, true, 0x00000000ULL },
{ 0x3fcaaULL, APFloat::FloatTF32(), true, false, 0xaaULL },
{ 0x3fdaaULL, APFloat::FloatTF32(), true, false, 0xfaaULL },
{ 0x3fdaaULL, APFloat::FloatTF32(), true, false, 0x1aaULL },
// clang-format on
};

@@ -1780,6 +1810,8 @@ TEST(APFloatTest, getLargest) {
APFloat::getLargest(APFloat::Float8E5M2FNUZ()).convertToDouble());
EXPECT_EQ(
30, APFloat::getLargest(APFloat::Float8E4M3B11FNUZ()).convertToDouble());
EXPECT_EQ(3.40116213421e+38f,
APFloat::getLargest(APFloat::FloatTF32()).convertToFloat());
}

TEST(APFloatTest, getSmallest) {
@@ -1831,6 +1863,13 @@ TEST(APFloatTest, getSmallest) {
EXPECT_TRUE(test.isFiniteNonZero());
EXPECT_TRUE(test.isDenormal());
EXPECT_TRUE(test.bitwiseIsEqual(expected));

test = APFloat::getSmallest(APFloat::FloatTF32(), true);
expected = APFloat(APFloat::FloatTF32(), "-0x0.004p-126");
EXPECT_TRUE(test.isNegative());
EXPECT_TRUE(test.isFiniteNonZero());
EXPECT_TRUE(test.isDenormal());
EXPECT_TRUE(test.bitwiseIsEqual(expected));
}

TEST(APFloatTest, getSmallestNormalized) {
@@ -1905,6 +1944,14 @@ TEST(APFloatTest, getSmallestNormalized) {
EXPECT_FALSE(test.isDenormal());
EXPECT_TRUE(test.bitwiseIsEqual(expected));
EXPECT_TRUE(test.isSmallestNormalized());

test = APFloat::getSmallestNormalized(APFloat::FloatTF32(), false);
expected = APFloat(APFloat::FloatTF32(), "0x1p-126");
EXPECT_FALSE(test.isNegative());
EXPECT_TRUE(test.isFiniteNonZero());
EXPECT_FALSE(test.isDenormal());
EXPECT_TRUE(test.bitwiseIsEqual(expected));
EXPECT_TRUE(test.isSmallestNormalized());
}

TEST(APFloatTest, getZero) {
@@ -1936,7 +1983,9 @@ TEST(APFloatTest, getZero) {
{&APFloat::Float8E4M3FNUZ(), false, false, {0, 0}, 1},
{&APFloat::Float8E4M3FNUZ(), true, false, {0, 0}, 1},
{&APFloat::Float8E4M3B11FNUZ(), false, false, {0, 0}, 1},
{&APFloat::Float8E4M3B11FNUZ(), true, false, {0, 0}, 1}};
{&APFloat::Float8E4M3B11FNUZ(), true, false, {0, 0}, 1},
{&APFloat::FloatTF32(), false, true, {0, 0}, 1},
{&APFloat::FloatTF32(), true, true, {0x40000ULL, 0}, 1}};
const unsigned NumGetZeroTests = std::size(GetZeroTest);
for (unsigned i = 0; i < NumGetZeroTests; ++i) {
APFloat test = APFloat::getZero(*GetZeroTest[i].semantics,
@@ -6229,6 +6278,34 @@ TEST(APFloatTest, Float8E4M3FNUZToDouble) {
EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
}

TEST(APFloatTest, FloatTF32ToDouble) {
APFloat One(APFloat::FloatTF32(), "1.0");
EXPECT_EQ(1.0, One.convertToDouble());
APFloat PosLargest = APFloat::getLargest(APFloat::FloatTF32(), false);
EXPECT_EQ(3.401162134214653489792616e+38, PosLargest.convertToDouble());
APFloat NegLargest = APFloat::getLargest(APFloat::FloatTF32(), true);
EXPECT_EQ(-3.401162134214653489792616e+38, NegLargest.convertToDouble());
APFloat PosSmallest =
APFloat::getSmallestNormalized(APFloat::FloatTF32(), false);
EXPECT_EQ(1.1754943508222875079687e-38, PosSmallest.convertToDouble());
APFloat NegSmallest =
APFloat::getSmallestNormalized(APFloat::FloatTF32(), true);
EXPECT_EQ(-1.1754943508222875079687e-38, NegSmallest.convertToDouble());

APFloat SmallestDenorm = APFloat::getSmallest(APFloat::FloatTF32(), false);
EXPECT_EQ(1.1479437019748901445007e-41, SmallestDenorm.convertToDouble());
APFloat LargestDenorm(APFloat::FloatTF32(), "0x1.FF8p-127");
EXPECT_EQ(/*0x1.FF8p-127*/ 1.1743464071203126178242e-38,
LargestDenorm.convertToDouble());

APFloat PosInf = APFloat::getInf(APFloat::FloatTF32());
EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble());
APFloat NegInf = APFloat::getInf(APFloat::FloatTF32(), true);
EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble());
APFloat QNaN = APFloat::getQNaN(APFloat::FloatTF32());
EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
}
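The extreme values exercised above follow directly from the semantics: the largest finite TF32 value is (2 - 2^-10) * 2^127 ≈ 3.401162e+38, and the smallest positive denormal is 2^-10 * 2^-126 = 2^-136 ≈ 1.147944e-41. A standalone sketch (not part of the test file) that reproduces those constants:

#include <cmath>
#include <cstdio>

int main() {
  // Largest finite value: all-ones 10-bit mantissa at the maximum exponent (127).
  double Largest = std::ldexp(2.0 - std::ldexp(1.0, -10), 127);
  // Smallest positive denormal: least-significant mantissa bit at the minimum exponent (-126).
  double SmallestDenorm = std::ldexp(1.0, -10 - 126);
  std::printf("%.12e\n%.12e\n", Largest, SmallestDenorm); // ~3.401162134215e+38, ~1.147943701975e-41
  return 0;
}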

TEST(APFloatTest, Float8E5M2FNUZToFloat) {
APFloat PosZero = APFloat::getZero(APFloat::Float8E5M2FNUZ());
APFloat PosZeroToFloat(PosZero.convertToFloat());
@@ -6473,4 +6550,40 @@ TEST(APFloatTest, Float8E4M3FNToFloat) {
EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
}

TEST(APFloatTest, FloatTF32ToFloat) {
APFloat PosZero = APFloat::getZero(APFloat::FloatTF32());
APFloat PosZeroToFloat(PosZero.convertToFloat());
EXPECT_TRUE(PosZeroToFloat.isPosZero());
APFloat NegZero = APFloat::getZero(APFloat::FloatTF32(), true);
APFloat NegZeroToFloat(NegZero.convertToFloat());
EXPECT_TRUE(NegZeroToFloat.isNegZero());

APFloat One(APFloat::FloatTF32(), "1.0");
EXPECT_EQ(1.0F, One.convertToFloat());
APFloat Two(APFloat::FloatTF32(), "2.0");
EXPECT_EQ(2.0F, Two.convertToFloat());

APFloat PosLargest = APFloat::getLargest(APFloat::FloatTF32(), false);
EXPECT_EQ(3.40116213421e+38F, PosLargest.convertToFloat());

APFloat NegLargest = APFloat::getLargest(APFloat::FloatTF32(), true);
EXPECT_EQ(-3.40116213421e+38F, NegLargest.convertToFloat());

APFloat PosSmallest =
APFloat::getSmallestNormalized(APFloat::FloatTF32(), false);
EXPECT_EQ(/*0x1.p-126*/ 1.1754943508222875e-38F,
PosSmallest.convertToFloat());
APFloat NegSmallest =
APFloat::getSmallestNormalized(APFloat::FloatTF32(), true);
EXPECT_EQ(/*-0x1.p-126*/ -1.1754943508222875e-38F,
NegSmallest.convertToFloat());

APFloat SmallestDenorm = APFloat::getSmallest(APFloat::FloatTF32(), false);
EXPECT_TRUE(SmallestDenorm.isDenormal());
EXPECT_EQ(0x0.004p-126, SmallestDenorm.convertToFloat());

APFloat QNaN = APFloat::getQNaN(APFloat::FloatTF32());
EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
}

} // namespace
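As a closing illustration of the 19-bit layout that the bitcastToAPInt()/initFromAPInt() changes round-trip, 1.0 in TF32 has a zero sign bit, a biased exponent of 127, and an all-zero mantissa, i.e. the bit pattern 0x1FC00. A minimal sketch (not part of the commit) of that round trip:

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include <cassert>

using namespace llvm;

void roundTripTF32One() {
  // 0 | 0111'1111 | 00'0000'0000  ->  sign 0, biased exponent 127, mantissa 0 == 1.0
  APInt Bits(19, 0x1FC00);
  APFloat One(APFloat::FloatTF32(), Bits); // dispatches to initFromFloatTF32APInt()
  assert(One.convertToDouble() == 1.0);
  assert(One.bitcastToAPInt() == Bits);    // dispatches to convertFloatTF32APFloatToAPInt()
}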
