Skip to content

Commit

Permalink
[APFloat] Add E4M3B11FNUZ
Browse files Browse the repository at this point in the history
X. Sun et al. (https://dl.acm.org/doi/10.5555/3454287.3454728) published
a paper showing that an FP format with 4 bits of exponent, 3 bits of
significand and an exponent bias of 11 would work quite well for ML
applications.

Google hardware supports a variant of this format where 0x80 is used to
represent NaN, as in the Float8E4M3FNUZ format. Just like the
Float8E4M3FNUZ format, this format does not support -0 and values which
would map to it will become +0.

This format is proposed for inclusion in OpenXLA's StableHLO dialect: openxla/stablehlo#1308

As part of inclusion in that dialect, APFloat needs to know how to
handle this format.
  • Loading branch information
majnemer committed Mar 20, 2023
1 parent edc0355 commit 8406251
Show file tree
Hide file tree
Showing 23 changed files with 366 additions and 150 deletions.
1 change: 1 addition & 0 deletions clang/lib/AST/MicrosoftMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
case APFloat::S_Float8E4M3FN:
case APFloat::S_Float8E5M2FNUZ:
case APFloat::S_Float8E4M3FNUZ:
case APFloat::S_Float8E4M3B11FNUZ:
llvm_unreachable("Tried to mangle unexpected APFloat semantics");
}

Expand Down
10 changes: 10 additions & 0 deletions llvm/include/llvm/ADT/APFloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ struct APFloatBase {
// This format's exponent bias is 8, instead of the 7 (2 ** (4 - 1) - 1)
// that IEEE precedent would imply.
S_Float8E4M3FNUZ,
// 8-bit floating point number mostly following IEEE-754 conventions
// and bit layout S1E4M3 with expanded range and with no infinity or signed
// zero.
// NaN is represented as negative zero. (FN -> Finite, UZ -> unsigned zero).
// This format's exponent bias is 11, instead of the 7 (2 ** (4 - 1) - 1)
// that IEEE precedent would imply.
S_Float8E4M3B11FNUZ,

S_x87DoubleExtended,
S_MaxSemantics = S_x87DoubleExtended,
Expand All @@ -195,6 +202,7 @@ struct APFloatBase {
static const fltSemantics &Float8E5M2FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
static const fltSemantics &x87DoubleExtended() LLVM_READNONE;

/// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
Expand Down Expand Up @@ -590,6 +598,7 @@ class IEEEFloat final : public APFloatBase {
APInt convertFloat8E5M2FNUZAPFloatToAPInt() const;
APInt convertFloat8E4M3FNAPFloatToAPInt() const;
APInt convertFloat8E4M3FNUZAPFloatToAPInt() const;
APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
void initFromAPInt(const fltSemantics *Sem, const APInt &api);
void initFromHalfAPInt(const APInt &api);
void initFromBFloatAPInt(const APInt &api);
Expand All @@ -602,6 +611,7 @@ class IEEEFloat final : public APFloatBase {
void initFromFloat8E5M2FNUZAPInt(const APInt &api);
void initFromFloat8E4M3FNAPInt(const APInt &api);
void initFromFloat8E4M3FNUZAPInt(const APInt &api);
void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);

void assign(const IEEEFloat &);
void copySignificand(const IEEEFloat &);
Expand Down
77 changes: 75 additions & 2 deletions llvm/lib/Support/APFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ enum class fltNonfiniteBehavior {
IEEE754,

// This behavior is present in the Float8ExMyFN* types (Float8E4M3FN,
// Float8E5M2FNUZ, and Float8E4M3FNUZ). There is no representation for Inf,
// and operations that would ordinarily produce Inf produce NaN instead.
// Float8E5M2FNUZ, Float8E4M3FNUZ, and Float8E4M3B11FNUZ). There is no
// representation for Inf, and operations that would ordinarily produce Inf
// produce NaN instead.
// The details of the NaN representation(s) in this form are determined by the
// `fltNanEncoding` enum. We treat all NaNs as quiet, as the available
// encodings do not distinguish between signalling and quiet NaN.
Expand Down Expand Up @@ -138,6 +139,13 @@ struct fltSemantics {
8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
static const fltSemantics semFloat8E4M3FNUZ = {
7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
// Semantics descriptor for the 8-bit E4M3 format with exponent bias 11, no
// infinities, and NaN encoded as negative zero.
// NOTE(review): the fltSemantics declaration is not visible here; the
// per-field notes below assume the initializer order is
// {maxExponent, minExponent, precision, sizeInBits, nonfinite behavior,
// NaN encoding} -- confirm against the struct definition.
static const fltSemantics semFloat8E4M3B11FNUZ = {
// Presumably the largest unbiased exponent: 15 - 11.
4,
// Presumably the smallest normal unbiased exponent: 1 - 11.
-10,
// Presumably the precision: 3 stored significand bits plus the integer bit.
4,
// Presumably the total storage width in bits.
8,
fltNonfiniteBehavior::NanOnly,
fltNanEncoding::NegativeZero};
static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
static const fltSemantics semBogus = {0, 0, 0, 0};

Expand Down Expand Up @@ -201,6 +209,8 @@ struct fltSemantics {
return Float8E4M3FN();
case S_Float8E4M3FNUZ:
return Float8E4M3FNUZ();
case S_Float8E4M3B11FNUZ:
return Float8E4M3B11FNUZ();
case S_x87DoubleExtended:
return x87DoubleExtended();
}
Expand Down Expand Up @@ -229,6 +239,8 @@ struct fltSemantics {
return S_Float8E4M3FN;
else if (&Sem == &llvm::APFloat::Float8E4M3FNUZ())
return S_Float8E4M3FNUZ;
else if (&Sem == &llvm::APFloat::Float8E4M3B11FNUZ())
return S_Float8E4M3B11FNUZ;
else if (&Sem == &llvm::APFloat::x87DoubleExtended())
return S_x87DoubleExtended;
else
Expand Down Expand Up @@ -259,6 +271,9 @@ struct fltSemantics {
// Returns a reference to the file-static semantics object
// semFloat8E4M3FNUZ.
const fltSemantics &APFloatBase::Float8E4M3FNUZ() {
return semFloat8E4M3FNUZ;
}
// Returns a reference to the file-static semantics object
// semFloat8E4M3B11FNUZ (the bias-11 E4M3 variant with NaN stored as
// negative zero).
const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
return semFloat8E4M3B11FNUZ;
}
// Returns a reference to the file-static semantics object
// semX87DoubleExtended.
const fltSemantics &APFloatBase::x87DoubleExtended() {
return semX87DoubleExtended;
}
Expand Down Expand Up @@ -3709,6 +3724,33 @@ APInt IEEEFloat::convertFloat8E4M3FNUZAPFloatToAPInt() const {
(mysignificand & 0x7)));
}

/// Packs this value into its raw 8-bit Float8E4M3B11FNUZ bit pattern:
/// bit 7 is the sign, bits 6..3 hold the exponent field (bias 11) and
/// bits 2..0 hold the stored significand bits.
APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const {
  assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ);
  assert(partCount() == 1);

  // Zero and infinity both leave the exponent and significand fields at zero.
  uint32_t biasedExp = 0;
  uint32_t fraction = 0;

  if (isFiniteNonZero()) {
    biasedExp = exponent + 11; // bias
    fraction = static_cast<uint32_t>(*significandParts());
    // A value at the lowest exponent whose integer bit is clear is a
    // denormal; it is stored with an exponent field of zero.
    const bool integerBitSet = (fraction & 0x8) != 0;
    if (biasedExp == 1 && !integerBitSet)
      biasedExp = 0;
  } else if (category != fcZero && category != fcInfinity) {
    assert(category == fcNaN && "Unknown category!");
    // NaN keeps a zero exponent field and forwards the stored significand.
    fraction = static_cast<uint32_t>(*significandParts());
  }

  const uint32_t signField = (sign & 1) << 7;
  const uint32_t exponentField = (biasedExp & 0xf) << 3;
  const uint32_t fractionField = fraction & 0x7;
  return APInt(8, signField | exponentField | fractionField);
}

// This function creates an APInt that is just a bit map of the floating
// point constant as it would appear in memory. It is not a conversion,
// and treating the result as a normal integer is unlikely to be useful.
Expand Down Expand Up @@ -3744,6 +3786,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FNUZ)
return convertFloat8E4M3FNUZAPFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ)
return convertFloat8E4M3B11FNUZAPFloatToAPInt();

assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
"unknown format!");
return convertF80LongDoubleAPFloatToAPInt();
Expand Down Expand Up @@ -4077,6 +4122,32 @@ void IEEEFloat::initFromFloat8E4M3FNUZAPInt(const APInt &api) {
}
}

void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) {
uint32_t i = (uint32_t)*api.getRawData();
uint32_t myexponent = (i >> 3) & 0xf;
uint32_t mysignificand = i & 0x7;

initialize(&semFloat8E4M3B11FNUZ);
assert(partCount() == 1);

sign = i >> 7;
if (myexponent == 0 && mysignificand == 0 && sign == 0) {
makeZero(sign);
} else if (myexponent == 0 && mysignificand == 0 && sign == 1) {
category = fcNaN;
exponent = exponentNaN();
*significandParts() = mysignificand;
} else {
category = fcNormal;
exponent = myexponent - 11; // bias
*significandParts() = mysignificand;
if (myexponent == 0) // denormal
exponent = -10;
else
*significandParts() |= 0x8; // integer bit
}
}

/// Treat api as containing the bits of a floating point number.
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
assert(api.getBitWidth() == Sem->sizeInBits);
Expand All @@ -4102,6 +4173,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E4M3FNAPInt(api);
if (Sem == &semFloat8E4M3FNUZ)
return initFromFloat8E4M3FNUZAPInt(api);
if (Sem == &semFloat8E4M3B11FNUZ)
return initFromFloat8E4M3B11FNUZAPInt(api);

llvm_unreachable(nullptr);
}
Expand Down
Loading

0 comments on commit 8406251

Please sign in to comment.