Skip to content

Commit

Permalink
[IR][BFloat] Add BFloat IR type
Browse files Browse the repository at this point in the history
Summary:
The BFloat IR type is introduced to provide support for, initially, the BFloat16
datatype introduced with the Armv8.6 architecture (optional from Armv8.2
onwards). It has an 8-bit exponent and a 7-bit mantissa and behaves like an IEEE
754 floating point IR type.

This is part of a patch series upstreaming Armv8.6 features. Subsequent patches
will upstream intrinsics support and C-lang support for BFloat.

Reviewers: SjoerdMeijer, rjmccall, rsmith, liutianle, RKSimon, craig.topper, jfb, LukeGeeson, sdesmalen, deadalnix, ctetreau

Subscribers: hiraditya, llvm-commits, danielkiss, arphaman, kristof.beyls, dexonsmith

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D78190
  • Loading branch information
stuij committed May 15, 2020
1 parent 7063a83 commit 8c24f33
Show file tree
Hide file tree
Showing 28 changed files with 354 additions and 129 deletions.
6 changes: 3 additions & 3 deletions clang/lib/Sema/SemaOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14936,9 +14936,9 @@ static bool actOnOMPReductionKindClause(
if (auto *ComplexTy = OrigType->getAs<ComplexType>())
Type = ComplexTy->getElementType();
if (Type->isRealFloatingType()) {
llvm::APFloat InitValue =
llvm::APFloat::getAllOnesValue(Context.getTypeSize(Type),
/*isIEEE=*/true);
llvm::APFloat InitValue = llvm::APFloat::getAllOnesValue(
Context.getFloatTypeSemantics(Type),
Context.getTypeSize(Type));
Init = FloatingLiteral::Create(Context, InitValue, /*isexact=*/true,
Type, ELoc);
} else if (Type->isScalarType()) {
Expand Down
8 changes: 8 additions & 0 deletions llvm/docs/BitCodeFormat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,14 @@ TYPE_CODE_HALF Record
The ``HALF`` record (code 10) adds a ``half`` (16-bit floating point) type to
the type table.

TYPE_CODE_BFLOAT Record
^^^^^^^^^^^^^^^^^^^^^

``[BFLOAT]``

The ``BFLOAT`` record (code 23) adds a ``bfloat`` (16-bit brain floating point)
type to the type table.

TYPE_CODE_FLOAT Record
^^^^^^^^^^^^^^^^^^^^^^

Expand Down
36 changes: 21 additions & 15 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2963,14 +2963,20 @@ Floating-Point Types
* - ``half``
- 16-bit floating-point value

* - ``bfloat``
- 16-bit "brain" floating-point value (7-bit significand). Provides the
same number of exponent bits as ``float``, so that it matches its dynamic
range, but with greatly reduced precision. Used in Intel's AVX-512 BF16
extensions and Arm's ARMv8.6-A extensions, among others.

* - ``float``
- 32-bit floating-point value

* - ``double``
- 64-bit floating-point value

* - ``fp128``
- 128-bit floating-point value (112-bit mantissa)
- 128-bit floating-point value (112-bit significand)

* - ``x86_fp80``
- 80-bit floating-point value (X87)
Expand Down Expand Up @@ -3303,20 +3309,20 @@ number of digits. For example, NaN's, infinities, and other special
values are represented in their IEEE hexadecimal format so that assembly
and disassembly do not cause any bits to change in the constants.

When using the hexadecimal form, constants of types half, float, and
double are represented using the 16-digit form shown above (which
matches the IEEE754 representation for double); half and float values
must, however, be exactly representable as IEEE 754 half and single
precision, respectively. Hexadecimal format is always used for long
double, and there are three forms of long double. The 80-bit format used
by x86 is represented as ``0xK`` followed by 20 hexadecimal digits. The
128-bit format used by PowerPC (two adjacent doubles) is represented by
``0xM`` followed by 32 hexadecimal digits. The IEEE 128-bit format is
represented by ``0xL`` followed by 32 hexadecimal digits. Long doubles
will only work if they match the long double format on your target.
The IEEE 16-bit format (half precision) is represented by ``0xH``
followed by 4 hexadecimal digits. All hexadecimal formats are big-endian
(sign bit at the left).
When using the hexadecimal form, constants of types bfloat, half, float, and
double are represented using the 16-digit form shown above (which matches the
IEEE754 representation for double); bfloat, half and float values must, however,
be exactly representable as bfloat, IEEE 754 half, and IEEE 754 single
precision respectively. Hexadecimal format is always used for long double, and
there are three forms of long double. The 80-bit format used by x86 is
represented as ``0xK`` followed by 20 hexadecimal digits. The 128-bit format
used by PowerPC (two adjacent doubles) is represented by ``0xM`` followed by 32
hexadecimal digits. The IEEE 128-bit format is represented by ``0xL`` followed
by 32 hexadecimal digits. Long doubles will only work if they match the long
double format on your target. The IEEE 16-bit format (half precision) is
represented by ``0xH`` followed by 4 hexadecimal digits. The bfloat 16-bit
format is represented by ``0xR`` followed by 4 hexadecimal digits. All
hexadecimal formats are big-endian (sign bit at the left).

There are no constants of type x86_mmx.

Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm-c/Core.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ typedef enum {
typedef enum {
LLVMVoidTypeKind, /**< type with no size */
LLVMHalfTypeKind, /**< 16 bit floating point type */
LLVMBFloatTypeKind, /**< 16 bit brain floating point type */
LLVMFloatTypeKind, /**< 32 bit floating point type */
LLVMDoubleTypeKind, /**< 64 bit floating point type */
LLVMX86_FP80TypeKind, /**< 80 bit floating point type (X87) */
Expand Down Expand Up @@ -1163,6 +1164,11 @@ unsigned LLVMGetIntTypeWidth(LLVMTypeRef IntegerTy);
*/
LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C);

/**
* Obtain a 16-bit brain floating point type from a context.
*/
LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C);

/**
* Obtain a 32-bit floating point type from a context.
*/
Expand Down Expand Up @@ -1195,6 +1201,7 @@ LLVMTypeRef LLVMPPCFP128TypeInContext(LLVMContextRef C);
* These map to the functions in this group of the same name.
*/
LLVMTypeRef LLVMHalfType(void);
LLVMTypeRef LLVMBFloatType(void);
LLVMTypeRef LLVMFloatType(void);
LLVMTypeRef LLVMDoubleType(void);
LLVMTypeRef LLVMX86FP80Type(void);
Expand Down
9 changes: 7 additions & 2 deletions llvm/include/llvm/ADT/APFloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ struct APFloatBase {
/// @{
enum Semantics {
S_IEEEhalf,
S_BFloat,
S_IEEEsingle,
S_IEEEdouble,
S_x87DoubleExtended,
Expand All @@ -162,6 +163,7 @@ struct APFloatBase {
static Semantics SemanticsToEnum(const llvm::fltSemantics &Sem);

static const fltSemantics &IEEEhalf() LLVM_READNONE;
static const fltSemantics &BFloat() LLVM_READNONE;
static const fltSemantics &IEEEsingle() LLVM_READNONE;
static const fltSemantics &IEEEdouble() LLVM_READNONE;
static const fltSemantics &IEEEquad() LLVM_READNONE;
Expand Down Expand Up @@ -541,13 +543,15 @@ class IEEEFloat final : public APFloatBase {
/// @}

APInt convertHalfAPFloatToAPInt() const;
APInt convertBFloatAPFloatToAPInt() const;
APInt convertFloatAPFloatToAPInt() const;
APInt convertDoubleAPFloatToAPInt() const;
APInt convertQuadrupleAPFloatToAPInt() const;
APInt convertF80LongDoubleAPFloatToAPInt() const;
APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
void initFromAPInt(const fltSemantics *Sem, const APInt &api);
void initFromHalfAPInt(const APInt &api);
void initFromBFloatAPInt(const APInt &api);
void initFromFloatAPInt(const APInt &api);
void initFromDoubleAPInt(const APInt &api);
void initFromQuadrupleAPInt(const APInt &api);
Expand Down Expand Up @@ -954,9 +958,10 @@ class APFloat : public APFloatBase {

/// Returns a float which is bitcasted from an all one value int.
///
/// \param Semantics - type float semantics
/// \param BitWidth - Select float type
/// \param isIEEE - If 128 bit number, select between PPC and IEEE
static APFloat getAllOnesValue(unsigned BitWidth, bool isIEEE = false);
static APFloat getAllOnesValue(const fltSemantics &Semantics,
unsigned BitWidth);

/// Used to insert APFloat objects, or objects that contain APFloat objects,
/// into FoldingSets.
Expand Down
4 changes: 3 additions & 1 deletion llvm/include/llvm/Bitcode/LLVMBitCodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ enum TypeCodes {

TYPE_CODE_FUNCTION = 21, // FUNCTION: [vararg, retty, paramty x N]

TYPE_CODE_TOKEN = 22 // TOKEN
TYPE_CODE_TOKEN = 22, // TOKEN

TYPE_CODE_BFLOAT = 23 // BRAIN FLOATING POINT
};

enum OperandBundleTagCode {
Expand Down
34 changes: 18 additions & 16 deletions llvm/include/llvm/IR/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -721,14 +721,15 @@ class ConstantDataArray final : public ConstantDataSequential {
return getImpl(Data, Ty);
}

/// getFP() constructors - Return a constant with array type with an element
/// count and element type of float with precision matching the number of
/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
/// double for 64bits) Note that this can return a ConstantAggregateZero
/// object.
static Constant *getFP(LLVMContext &Context, ArrayRef<uint16_t> Elts);
static Constant *getFP(LLVMContext &Context, ArrayRef<uint32_t> Elts);
static Constant *getFP(LLVMContext &Context, ArrayRef<uint64_t> Elts);
/// getFP() constructors - Return a constant of array type with a float
/// element type taken from argument `ElementType', and count taken from
/// argument `Elts'. The amount of bits of the contained type must match the
/// number of bits of the type contained in the passed in ArrayRef.
/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
/// that this can return a ConstantAggregateZero object.
static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);

/// This method constructs a CDS and initializes it with a text string.
/// The default behavior (AddNull==true) causes a null terminator to
Expand Down Expand Up @@ -780,14 +781,15 @@ class ConstantDataVector final : public ConstantDataSequential {
static Constant *get(LLVMContext &Context, ArrayRef<float> Elts);
static Constant *get(LLVMContext &Context, ArrayRef<double> Elts);

/// getFP() constructors - Return a constant with vector type with an element
/// count and element type of float with the precision matching the number of
/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
/// double for 64bits) Note that this can return a ConstantAggregateZero
/// object.
static Constant *getFP(LLVMContext &Context, ArrayRef<uint16_t> Elts);
static Constant *getFP(LLVMContext &Context, ArrayRef<uint32_t> Elts);
static Constant *getFP(LLVMContext &Context, ArrayRef<uint64_t> Elts);
/// getFP() constructors - Return a constant of vector type with a float
/// element type taken from argument `ElementType', and count taken from
/// argument `Elts'. The amount of bits of the contained type must match the
/// number of bits of the type contained in the passed in ArrayRef.
/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
/// that this can return a ConstantAggregateZero object.
static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);

/// Return a ConstantVector with the specified constant in each element.
/// The specified constant has to be a of a compatible type (i8/i16/
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/DataLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ inline TypeSize DataLayout::getTypeSizeInBits(Type *Ty) const {
case Type::IntegerTyID:
return TypeSize::Fixed(Ty->getIntegerBitWidth());
case Type::HalfTyID:
case Type::BFloatTyID:
return TypeSize::Fixed(16);
case Type::FloatTyID:
return TypeSize::Fixed(32);
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/IR/IRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,11 @@ class IRBuilderBase {
return Type::getHalfTy(Context);
}

/// Fetch the type representing a 16-bit brain floating point value.
Type *getBFloatTy() {
return Type::getBFloatTy(Context);
}

/// Fetch the type representing a 32-bit floating point value.
Type *getFloatTy() {
return Type::getFloatTy(Context);
Expand Down
47 changes: 27 additions & 20 deletions llvm/include/llvm/IR/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,28 @@ class Type {
///
enum TypeID {
// PrimitiveTypes - make sure LastPrimitiveTyID stays up to date.
VoidTyID = 0, ///< 0: type with no size
HalfTyID, ///< 1: 16-bit floating point type
FloatTyID, ///< 2: 32-bit floating point type
DoubleTyID, ///< 3: 64-bit floating point type
X86_FP80TyID, ///< 4: 80-bit floating point type (X87)
FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa)
PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC)
LabelTyID, ///< 7: Labels
MetadataTyID, ///< 8: Metadata
X86_MMXTyID, ///< 9: MMX vectors (64 bits, X86 specific)
TokenTyID, ///< 10: Tokens
VoidTyID = 0, ///< 0: type with no size
HalfTyID, ///< 1: 16-bit floating point type
BFloatTyID, ///< 2: 16-bit floating point type (7-bit significand)
FloatTyID, ///< 3: 32-bit floating point type
DoubleTyID, ///< 4: 64-bit floating point type
X86_FP80TyID, ///< 5: 80-bit floating point type (X87)
FP128TyID, ///< 6: 128-bit floating point type (112-bit significand)
PPC_FP128TyID, ///< 7: 128-bit floating point type (two 64-bits, PowerPC)
LabelTyID, ///< 8: Labels
MetadataTyID, ///< 9: Metadata
X86_MMXTyID, ///< 10: MMX vectors (64 bits, X86 specific)
TokenTyID, ///< 11: Tokens

// Derived types... see DerivedTypes.h file.
// Make sure FirstDerivedTyID stays up to date!
IntegerTyID, ///< 11: Arbitrary bit width integers
FunctionTyID, ///< 12: Functions
StructTyID, ///< 13: Structures
ArrayTyID, ///< 14: Arrays
PointerTyID, ///< 15: Pointers
FixedVectorTyID, ///< 16: Fixed width SIMD vector type
ScalableVectorTyID ///< 17: Scalable SIMD vector type
IntegerTyID, ///< 12: Arbitrary bit width integers
FunctionTyID, ///< 13: Functions
StructTyID, ///< 14: Structures
ArrayTyID, ///< 15: Arrays
PointerTyID, ///< 16: Pointers
FixedVectorTyID, ///< 17: Fixed width SIMD vector type
ScalableVectorTyID ///< 18: Scalable SIMD vector type
};

private:
Expand Down Expand Up @@ -140,6 +141,9 @@ class Type {
/// Return true if this is 'half', a 16-bit IEEE fp type.
bool isHalfTy() const { return getTypeID() == HalfTyID; }

/// Return true if this is 'bfloat', a 16-bit bfloat type.
bool isBFloatTy() const { return getTypeID() == BFloatTyID; }

/// Return true if this is 'float', a 32-bit IEEE fp type.
bool isFloatTy() const { return getTypeID() == FloatTyID; }

Expand All @@ -157,15 +161,16 @@ class Type {

/// Return true if this is one of the six floating-point types
bool isFloatingPointTy() const {
return getTypeID() == HalfTyID || getTypeID() == FloatTyID ||
getTypeID() == DoubleTyID ||
return getTypeID() == HalfTyID || getTypeID() == BFloatTyID ||
getTypeID() == FloatTyID || getTypeID() == DoubleTyID ||
getTypeID() == X86_FP80TyID || getTypeID() == FP128TyID ||
getTypeID() == PPC_FP128TyID;
}

const fltSemantics &getFltSemantics() const {
switch (getTypeID()) {
case HalfTyID: return APFloat::IEEEhalf();
case BFloatTyID: return APFloat::BFloat();
case FloatTyID: return APFloat::IEEEsingle();
case DoubleTyID: return APFloat::IEEEdouble();
case X86_FP80TyID: return APFloat::x87DoubleExtended();
Expand Down Expand Up @@ -387,6 +392,7 @@ class Type {
static Type *getVoidTy(LLVMContext &C);
static Type *getLabelTy(LLVMContext &C);
static Type *getHalfTy(LLVMContext &C);
static Type *getBFloatTy(LLVMContext &C);
static Type *getFloatTy(LLVMContext &C);
static Type *getDoubleTy(LLVMContext &C);
static Type *getMetadataTy(LLVMContext &C);
Expand Down Expand Up @@ -422,6 +428,7 @@ class Type {
// types as pointee.
//
static PointerType *getHalfPtrTy(LLVMContext &C, unsigned AS = 0);
static PointerType *getBFloatPtrTy(LLVMContext &C, unsigned AS = 0);
static PointerType *getFloatPtrTy(LLVMContext &C, unsigned AS = 0);
static PointerType *getDoublePtrTy(LLVMContext &C, unsigned AS = 0);
static PointerType *getX86_FP80PtrTy(LLVMContext &C, unsigned AS = 0);
Expand Down
12 changes: 10 additions & 2 deletions llvm/lib/AsmParser/LLLexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,7 @@ lltok::Kind LLLexer::LexIdentifier() {

TYPEKEYWORD("void", Type::getVoidTy(Context));
TYPEKEYWORD("half", Type::getHalfTy(Context));
TYPEKEYWORD("bfloat", Type::getBFloatTy(Context));
TYPEKEYWORD("float", Type::getFloatTy(Context));
TYPEKEYWORD("double", Type::getDoubleTy(Context));
TYPEKEYWORD("x86_fp80", Type::getX86_FP80Ty(Context));
Expand Down Expand Up @@ -985,11 +986,13 @@ lltok::Kind LLLexer::LexIdentifier() {
/// HexFP128Constant 0xL[0-9A-Fa-f]+
/// HexPPC128Constant 0xM[0-9A-Fa-f]+
/// HexHalfConstant 0xH[0-9A-Fa-f]+
/// HexBFloatConstant 0xR[0-9A-Fa-f]+
lltok::Kind LLLexer::Lex0x() {
CurPtr = TokStart + 2;

char Kind;
if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H') {
if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' ||
CurPtr[0] == 'R') {
Kind = *CurPtr++;
} else {
Kind = 'J';
Expand All @@ -1007,7 +1010,7 @@ lltok::Kind LLLexer::Lex0x() {
if (Kind == 'J') {
// HexFPConstant - Floating point constant represented in IEEE format as a
// hexadecimal number for when exponential notation is not precise enough.
// Half, Float, and double only.
// Half, BFloat, Float, and double only.
APFloatVal = APFloat(APFloat::IEEEdouble(),
APInt(64, HexIntToVal(TokStart + 2, CurPtr)));
return lltok::APFloat;
Expand Down Expand Up @@ -1035,6 +1038,11 @@ lltok::Kind LLLexer::Lex0x() {
APFloatVal = APFloat(APFloat::IEEEhalf(),
APInt(16,HexIntToVal(TokStart+3, CurPtr)));
return lltok::APFloat;
case 'R':
// Brain floating point
APFloatVal = APFloat(APFloat::BFloat(),
APInt(16, HexIntToVal(TokStart + 3, CurPtr)));
return lltok::APFloat;
}
}

Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/AsmParser/LLParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5247,13 +5247,16 @@ bool LLParser::ConvertValIDToValue(Type *Ty, ValID &ID, Value *&V,
!ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
return Error(ID.Loc, "floating point constant invalid for type");

// The lexer has no type info, so builds all half, float, and double FP
// constants as double. Fix this here. Long double does not need this.
// The lexer has no type info, so builds all half, bfloat, float, and double
// FP constants as double. Fix this here. Long double does not need this.
if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) {
bool Ignored;
if (Ty->isHalfTy())
ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
&Ignored);
else if (Ty->isBFloatTy())
ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
&Ignored);
else if (Ty->isFloatTy())
ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
&Ignored);
Expand Down

0 comments on commit 8c24f33

Please sign in to comment.