-
Notifications
You must be signed in to change notification settings - Fork 11k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR][Float8] Add two kinds float8 IR type #89900
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-llvm-ir Author: None (JinjinLi868) ChangesSupport two classes Float8(float8e5m2 and float8e4m3fn) IR type for Patch is 46.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89900.diff 23 Files Affected:
diff --git a/llvm/docs/BitCodeFormat.rst b/llvm/docs/BitCodeFormat.rst
index 46af2e421a258c..bd9b1f87422585 100644
--- a/llvm/docs/BitCodeFormat.rst
+++ b/llvm/docs/BitCodeFormat.rst
@@ -1139,6 +1139,22 @@ TYPE_CODE_VOID Record
The ``VOID`` record (code 2) adds a ``void`` type to the type table.
+TYPE_CODE_Float8E5M2 Record
+^^^^^^^^^^^^^^^^^^^^^
+
+``[Float8E5M2]``
+
+The ``Float8E5M2`` record (code 27) adds a ``float8e5m2`` (8-bit floating point)
+type to the type table.
+
+TYPE_CODE_Float8E4M3FN Record
+^^^^^^^^^^^^^^^^^^^^^
+
+``[Float8E4M3FN]``
+
+The ``Float8E4M3FN`` record (code 28) adds a ``float8e4m3fn`` (8-bit floating
+point) type to the type table.
+
TYPE_CODE_HALF Record
^^^^^^^^^^^^^^^^^^^^^
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 9592929d79feb4..3106dc0cc25d5e 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -3847,6 +3847,14 @@ Floating-Point Types
* - Type
- Description
+ * - ``float8e5m2``
+ - 16-bit floating-point value(2-bit significand)
+
+ * - ``float8e4m3fn``
+ - 16-bit floating-point value(3-bit significand), there are no infinity
+ values, and NaN is represented with the exponent and mantissa bits set
+ to all 1s
+
* - ``half``
- 16-bit floating-point value
@@ -3871,9 +3879,9 @@ Floating-Point Types
* - ``ppc_fp128``
- 128-bit floating-point value (two 64-bits)
-The binary format of half, float, double, and fp128 correspond to the
-IEEE-754-2008 specifications for binary16, binary32, binary64, and binary128
-respectively.
+The binary format of float8e5m2, half, float, double, and fp128 correspond
+to the IEEE-754-2008 specifications for binary8, binary16, binary32, binary64,
+and binary128 respectively.
X86_amx Type
""""""""""""
@@ -4329,20 +4337,23 @@ number of digits. For example, NaN's, infinities, and other special
values are represented in their IEEE hexadecimal format so that assembly
and disassembly do not cause any bits to change in the constants.
-When using the hexadecimal form, constants of types bfloat, half, float, and
-double are represented using the 16-digit form shown above (which matches the
-IEEE754 representation for double); bfloat, half and float values must, however,
-be exactly representable as bfloat, IEEE 754 half, and IEEE 754 single
+When using the hexadecimal form, constants of types float8e5m2, float8e4m3fn,
+bfloat, half, float, and double are represented using the 16-digit form shown
+above (which matches the IEEE754 representation for double); float8e5m2,
+float8e4m3fn, bfloat, half and float values must, however, be exactly representable
+as float8e5m2, float8e4m3fn, bfloat, IEEE 754 half, and IEEE 754 single
precision respectively. Hexadecimal format is always used for long double, and
there are three forms of long double. The 80-bit format used by x86 is
represented as ``0xK`` followed by 20 hexadecimal digits. The 128-bit format
used by PowerPC (two adjacent doubles) is represented by ``0xM`` followed by 32
hexadecimal digits. The IEEE 128-bit format is represented by ``0xL`` followed
by 32 hexadecimal digits. Long doubles will only work if they match the long
-double format on your target. The IEEE 16-bit format (half precision) is
-represented by ``0xH`` followed by 4 hexadecimal digits. The bfloat 16-bit
-format is represented by ``0xR`` followed by 4 hexadecimal digits. All
-hexadecimal formats are big-endian (sign bit at the left).
+double format on your target. The IEEE 8-bit format (floate5m2 precision) is
+represented by ``0xS`` followed by 2 hexadecimal digits. The float8e4m3fn 8-bit
+format is represented by ``0xQ`` followed by 2 hexadecimal digits. The IEEE 16-bit
+format (half precision) is represented by ``0xH`` followed by 4 hexadecimal digits.
+The bfloat 16-bit format is represented by ``0xR`` followed by 4 hexadecimal digits.
+All hexadecimal formats are big-endian (sign bit at the left).
There are no constants of type x86_mmx and x86_amx.
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index 0b03f3b36fcdd3..7cc958beccb62d 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -167,6 +167,8 @@ typedef enum {
LLVMBFloatTypeKind, /**< 16 bit brain floating point type */
LLVMX86_AMXTypeKind, /**< X86 AMX */
LLVMTargetExtTypeKind, /**< Target extension type */
+ LLVMFloat8E5M2TypeKind, /**< 8 bit floating point with 2 bit mantissa */
+ LLVMFloat8E4M3FNTypeKind, /**< 8 bit floating point with 3 bit mantissa */
} LLVMTypeKind;
typedef enum {
@@ -1298,6 +1300,17 @@ unsigned LLVMGetIntTypeWidth(LLVMTypeRef IntegerTy);
* @{
*/
+
+/**
+ * Obtain a 8-bit floating point type from a context.
+ */
+LLVMTypeRef LLVMFloat8E5M2TypeInContext(LLVMContextRef C);
+
+/**
+ * Obtain a 8-bit floating point type from a context.
+ */
+LLVMTypeRef LLVMFloat8E4M3FNTypeInContext(LLVMContextRef C);
+
/**
* Obtain a 16-bit floating point type from a context.
*/
@@ -1339,6 +1352,8 @@ LLVMTypeRef LLVMPPCFP128TypeInContext(LLVMContextRef C);
*
* These map to the functions in this group of the same name.
*/
+LLVMTypeRef LLVMFloat8E5M2Type(void);
+LLVMTypeRef LLVMFloat8E4M3FNType(void);
LLVMTypeRef LLVMHalfType(void);
LLVMTypeRef LLVMBFloatType(void);
LLVMTypeRef LLVMFloatType(void);
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 909eb833c601a9..ce6d639c2455c4 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -177,6 +177,9 @@ enum TypeCodes {
TYPE_CODE_OPAQUE_POINTER = 25, // OPAQUE_POINTER: [addrspace]
TYPE_CODE_TARGET_TYPE = 26, // TARGET_TYPE
+
+ TYPE_CODE_Float8E5M2 = 27, // Float8E5M2
+ TYPE_CODE_Float8E4M3FN = 28, // Float8E4M3FN
};
enum OperandBundleTagCode {
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 4290ef4486c6f4..3c82b74a111aa5 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -733,6 +733,7 @@ class ConstantDataArray final : public ConstantDataSequential {
/// number of bits of the type contained in the passed in ArrayRef.
/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
/// that this can return a ConstantAggregateZero object.
+ static Constant *getFP(Type *ElementType, ArrayRef<uint8_t> Elts);
static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
@@ -805,6 +806,7 @@ class ConstantDataVector final : public ConstantDataSequential {
/// number of bits of the type contained in the passed in ArrayRef.
/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
/// that this can return a ConstantAggregateZero object.
+ static Constant *getFP(Type *ElementType, ArrayRef<uint8_t> Elts);
static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h
index d14adfe1590be5..2f0c55d8e758c6 100644
--- a/llvm/include/llvm/IR/DataLayout.h
+++ b/llvm/include/llvm/IR/DataLayout.h
@@ -687,6 +687,9 @@ inline TypeSize DataLayout::getTypeSizeInBits(Type *Ty) const {
return getStructLayout(cast<StructType>(Ty))->getSizeInBits();
case Type::IntegerTyID:
return TypeSize::getFixed(Ty->getIntegerBitWidth());
+ case Type::Float8E5M2TyID:
+ case Type::Float8E4M3FNTyID:
+ return TypeSize::getFixed(8);
case Type::HalfTyID:
case Type::BFloatTyID:
return TypeSize::getFixed(16);
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index b6534a1962a2f5..de981586b4cbe7 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -540,6 +540,16 @@ class IRBuilderBase {
return Type::getIntNTy(Context, N);
}
+ /// Fetch the type representing a 8-bit e5m2 floating point value.
+ Type *getFloat8E5M2Ty() {
+ return Type::getFloat8E5M2Ty(Context);
+ }
+
+ /// Fetch the type representing a 8-bit e4m3fn floating point value.
+ Type *getFloat8E4M3FNTy() {
+ return Type::getFloat8E4M3FNTy(Context);
+ }
+
/// Fetch the type representing a 16-bit floating point value.
Type *getHalfTy() {
return Type::getHalfTy(Context);
diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index 1f0133c08e7d60..bf9f63d2cdda4c 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -55,6 +55,8 @@ class Type {
// PrimitiveTypes
HalfTyID = 0, ///< 16-bit floating point type
BFloatTyID, ///< 16-bit floating point type (7-bit significand)
+ Float8E5M2TyID, ///< 8-bit floating type (5 Bit exponent)
+ Float8E4M3FNTyID, ///< 8-bit floating type (4 Bit exponent)
FloatTyID, ///< 32-bit floating point type
DoubleTyID, ///< 64-bit floating point type
X86_FP80TyID, ///< 80-bit floating point type (X87)
@@ -139,6 +141,17 @@ class Type {
/// Return true if this is 'void'.
bool isVoidTy() const { return getTypeID() == VoidTyID; }
+ /// Return true if this is 'F8E5M2'.
+ bool isFloat8E5M2Ty() const { return getTypeID() == Float8E5M2TyID; }
+
+ /// Return true if this is 'F8E4M3FN'.
+ bool isFloat8E4M3FNTy() const { return getTypeID() == Float8E4M3FNTyID; }
+
+ /// Return true if this is an 8-bit float type.
+ bool is8BitFPTy() const {
+ return getTypeID() == Float8E5M2TyID || getTypeID() == Float8E4M3FNTyID;
+ }
+
/// Return true if this is 'half', a 16-bit IEEE fp type.
bool isHalfTy() const { return getTypeID() == HalfTyID; }
@@ -174,6 +187,8 @@ class Type {
case FloatTyID:
case HalfTyID:
case BFloatTyID:
+ case Float8E5M2TyID:
+ case Float8E4M3FNTyID:
case FP128TyID:
return true;
default:
@@ -445,6 +460,8 @@ class Type {
//
static Type *getVoidTy(LLVMContext &C);
static Type *getLabelTy(LLVMContext &C);
+ static Type *getFloat8E5M2Ty(LLVMContext &C);
+ static Type *getFloat8E4M3FNTy(LLVMContext &C);
static Type *getHalfTy(LLVMContext &C);
static Type *getBFloatTy(LLVMContext &C);
static Type *getFloatTy(LLVMContext &C);
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 8ded07ffd8bd25..aabd6262304f16 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -825,20 +825,22 @@ lltok::Kind LLLexer::LexIdentifier() {
} \
} while (false)
- TYPEKEYWORD("void", Type::getVoidTy(Context));
- TYPEKEYWORD("half", Type::getHalfTy(Context));
- TYPEKEYWORD("bfloat", Type::getBFloatTy(Context));
- TYPEKEYWORD("float", Type::getFloatTy(Context));
- TYPEKEYWORD("double", Type::getDoubleTy(Context));
- TYPEKEYWORD("x86_fp80", Type::getX86_FP80Ty(Context));
- TYPEKEYWORD("fp128", Type::getFP128Ty(Context));
- TYPEKEYWORD("ppc_fp128", Type::getPPC_FP128Ty(Context));
- TYPEKEYWORD("label", Type::getLabelTy(Context));
- TYPEKEYWORD("metadata", Type::getMetadataTy(Context));
- TYPEKEYWORD("x86_mmx", Type::getX86_MMXTy(Context));
- TYPEKEYWORD("x86_amx", Type::getX86_AMXTy(Context));
- TYPEKEYWORD("token", Type::getTokenTy(Context));
- TYPEKEYWORD("ptr", PointerType::getUnqual(Context));
+ TYPEKEYWORD("void", Type::getVoidTy(Context));
+ TYPEKEYWORD("float8e5m2", Type::getFloat8E5M2Ty(Context));
+ TYPEKEYWORD("float8e4m3fn", Type::getFloat8E4M3FNTy(Context));
+ TYPEKEYWORD("half", Type::getHalfTy(Context));
+ TYPEKEYWORD("bfloat", Type::getBFloatTy(Context));
+ TYPEKEYWORD("float", Type::getFloatTy(Context));
+ TYPEKEYWORD("double", Type::getDoubleTy(Context));
+ TYPEKEYWORD("x86_fp80", Type::getX86_FP80Ty(Context));
+ TYPEKEYWORD("fp128", Type::getFP128Ty(Context));
+ TYPEKEYWORD("ppc_fp128", Type::getPPC_FP128Ty(Context));
+ TYPEKEYWORD("label", Type::getLabelTy(Context));
+ TYPEKEYWORD("metadata", Type::getMetadataTy(Context));
+ TYPEKEYWORD("x86_mmx", Type::getX86_MMXTy(Context));
+ TYPEKEYWORD("x86_amx", Type::getX86_AMXTy(Context));
+ TYPEKEYWORD("token", Type::getTokenTy(Context));
+ TYPEKEYWORD("ptr", PointerType::getUnqual(Context));
#undef TYPEKEYWORD
@@ -1006,18 +1008,21 @@ lltok::Kind LLLexer::LexIdentifier() {
/// Lex all tokens that start with a 0x prefix, knowing they match and are not
/// labels.
-/// HexFPConstant 0x[0-9A-Fa-f]+
-/// HexFP80Constant 0xK[0-9A-Fa-f]+
-/// HexFP128Constant 0xL[0-9A-Fa-f]+
-/// HexPPC128Constant 0xM[0-9A-Fa-f]+
-/// HexHalfConstant 0xH[0-9A-Fa-f]+
-/// HexBFloatConstant 0xR[0-9A-Fa-f]+
+/// HexFPConstant 0x[0-9A-Fa-f]+
+/// HexFP80Constant 0xK[0-9A-Fa-f]+
+/// HexFP128Constant 0xL[0-9A-Fa-f]+
+/// HexPPC128Constant 0xM[0-9A-Fa-f]+
+/// HexHalfConstant 0xH[0-9A-Fa-f]+
+/// HexBFloatConstant 0xR[0-9A-Fa-f]+
+/// HexFP8E4M3FNConstant 0xQ[0-9A-Fa-f]+
+/// HexFP8E5M2Constant 0xS[0-9A-Fa-f]+
+
lltok::Kind LLLexer::Lex0x() {
CurPtr = TokStart + 2;
char Kind;
if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' ||
- CurPtr[0] == 'R') {
+ CurPtr[0] == 'R' || CurPtr[0] == 'Q' || CurPtr[0] == 'S') {
Kind = *CurPtr++;
} else {
Kind = 'J';
@@ -1068,6 +1073,16 @@ lltok::Kind LLLexer::Lex0x() {
APFloatVal = APFloat(APFloat::BFloat(),
APInt(16, HexIntToVal(TokStart + 3, CurPtr)));
return lltok::APFloat;
+ case 'Q':
+ // FP8E4M3FN
+ APFloatVal = APFloat(APFloat::Float8E4M3FN(),
+ APInt(8, HexIntToVal(TokStart + 1, CurPtr)));
+ return lltok::APFloat;
+ case 'S':
+ // FP8E5M2
+ APFloatVal = APFloat(APFloat::Float8E5M2(),
+ APInt(8, HexIntToVal(TokStart + 1, CurPtr)));
+ return lltok::APFloat;
}
}
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 63104129f8c2df..d32e154c8baf66 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -5998,13 +5998,20 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
!ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
return error(ID.Loc, "floating point constant invalid for type");
- // The lexer has no type info, so builds all half, bfloat, float, and double
- // FP constants as double. Fix this here. Long double does not need this.
+ // The lexer has no type info, so builds all float8e5m2, float8e4m3fn, half,
+ // bfloat, float, and double FP constants as double. Fix this here. Long
+ // double does not need this.
if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) {
// Check for signaling before potentially converting and losing that info.
bool IsSNAN = ID.APFloatVal.isSignaling();
bool Ignored;
- if (Ty->isHalfTy())
+ if (Ty->isFloat8E5M2Ty())
+ ID.APFloatVal.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven,
+ &Ignored);
+ else if (Ty->isFloat8E4M3FNTy())
+ ID.APFloatVal.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+ &Ignored);
+ else if (Ty->isHalfTy())
ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
&Ignored);
else if (Ty->isBFloatTy())
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 0b7fcd88418894..3e3ec1b2664089 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -2404,6 +2404,12 @@ Error BitcodeReader::parseTypeTableBody() {
case bitc::TYPE_CODE_VOID: // VOID
ResultTy = Type::getVoidTy(Context);
break;
+ case bitc::TYPE_CODE_Float8E4M3FN: // FP8E4M3FN
+ ResultTy = Type::getFloat8E4M3FNTy(Context);
+ break;
+ case bitc::TYPE_CODE_Float8E5M2: // FP8E5M2
+ ResultTy = Type::getFloat8E5M2Ty(Context);
+ break;
case bitc::TYPE_CODE_HALF: // HALF
ResultTy = Type::getHalfTy(Context);
break;
@@ -3138,7 +3144,13 @@ Error BitcodeReader::parseConstants() {
return error("Invalid float const record");
auto *ScalarTy = CurTy->getScalarType();
- if (ScalarTy->isHalfTy())
+ if (ScalarTy->isFloat8E4M3FNTy())
+ V = ConstantFP::get(Context, APFloat(APFloat::Float8E4M3FN(),
+ APInt(8, (uint8_t)Record[0])));
+ else if (ScalarTy->isFloat8E5M2Ty())
+ V = ConstantFP::get(Context, APFloat(APFloat::Float8E5M2(),
+ APInt(8, (uint8_t)Record[0])));
+ else if (ScalarTy->isHalfTy())
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(),
APInt(16, (uint16_t)Record[0])));
else if (ScalarTy->isBFloatTy())
@@ -3234,6 +3246,18 @@ Error BitcodeReader::parseConstants() {
V = ConstantDataVector::get(Context, Elts);
else
V = ConstantDataArray::get(Context, Elts);
+ } else if (EltTy->isFloat8E4M3FNTy()) {
+ SmallVector<uint8_t, 16> Elts(Record.begin(), Record.end());
+ if (isa<VectorType>(CurTy))
+ V = ConstantDataVector::getFP(EltTy, Elts);
+ else
+ V = ConstantDataArray::getFP(EltTy, Elts);
+ } else if (EltTy->isFloat8E5M2Ty()) {
+ SmallVector<uint8_t, 16> Elts(Record.begin(), Record.end());
+ if (isa<VectorType>(CurTy))
+ V = ConstantDataVector::getFP(EltTy, Elts);
+ else
+ V = ConstantDataArray::getFP(EltTy, Elts);
} else if (EltTy->isHalfTy()) {
SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
if (isa<VectorType>(CurTy))
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 6d01e3b4d82189..46b5d0ed9440ee 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -1043,19 +1043,21 @@ void ModuleBitcodeWriter::writeTypeTable() {
unsigned Code = 0;
switch (T->getTypeID()) {
- case Type::VoidTyID: Code = bitc::TYPE_CODE_VOID; break;
- case Type::HalfTyID: Code = bitc::TYPE_CODE_HALF; break;
- case Type::BFloatTyID: Code = bitc::TYPE_CODE_BFLOAT; break;
- case Type::FloatTyID: Code = bitc::TYPE_CODE_FLOAT; break;
- case Type::DoubleTyID: Code = bitc::TYPE_CODE_DOUBLE; break;
- case Type::X86_FP80TyID: Code = bitc::TYPE_CODE_X86_FP80; break;
- case Type::FP128TyID: Code = bitc::TYPE_CODE_FP128; break;
- case Type::PPC_FP128TyID: Code = bitc::TYPE_CODE_PPC_FP128; break;
- case Type::LabelTyID: Code = bitc::TYPE_CODE_LABEL; break;
- case Type::MetadataTyID: Code = bitc::TYPE_CODE_METADATA; break;
- case Type::X86_MMXTyID: Code = bitc::TYPE_CODE_X86_MMX; break;
- case Type::X86_AMXTyID: Code = bitc::TYPE_CODE_X86_AMX; break;
- case Type::TokenTyID: Code = bitc::TYPE_CODE_TOKEN; break;
+ case Type::VoidTyID: Code = bitc::TYPE_CODE_VOID; break;
+ case Type::Float8E4M3FNTyID: Code = bitc::TYPE_CODE_Float8E4M3FN; break;
+ case Type::Float8E5M2TyID: Code = bitc::TYPE_CODE_Float8E5M2; break;
+ case Type::HalfTyID: Code = bitc::TYPE_CODE_HALF; break;
+ case Type::BFloatTyID: Code = bitc::TYPE_CODE_BFLOAT; break;
+ case Type::FloatTyID: Code = bitc::TYPE_CODE_FLOAT; break;
+ case Type::DoubleTyID: Code = bitc::TYPE_CODE_DOUBLE; break;
+ case Type::X86_FP80TyID: Code = bitc::TYPE_CODE_X86_FP80; break;
+ case Type::FP128TyID: Code = bitc::TYPE_CODE_FP128; break;
+ case Type::PPC_FP128TyID: Code = bitc::TYPE_CODE_PPC_FP128; break;
+ case Type::LabelTyID: Code = bitc::TYPE_CODE_LABEL; break;
+ case Type::MetadataTyID: Code = bitc::TYPE_CODE_M...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add some tests for invalid constants in the IR parser? In particular, I'd like to make sure that the float8e4m3fn is handled properly in the double-hex-constant mode, given that it lacks infinities and has but one NaN value.
llvm/docs/LangRef.rst
Outdated
@@ -3847,6 +3847,14 @@ Floating-Point Types | |||
* - Type | |||
- Description | |||
|
|||
* - ``float8e5m2`` | |||
- 16-bit floating-point value(2-bit significand) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are 8-bit types, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry, my bad, changed
IEEE-754-2008 specifications for binary16, binary32, binary64, and binary128 | ||
respectively. | ||
The binary format of float8e5m2, half, float, double, and fp128 correspond | ||
to the IEEE-754-2008 specifications for binary8, binary16, binary32, binary64, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IEEE 754-2008 (nor IEEE 754-2019, for that matter) doesn't define a binary8 type. And the table for binaryk actually applies to k >= 128 and k % 32 == 0. And even if you were to ignore that restriction, the formulas you would get for k =8 gives a p of 9 bits and a w of -1 bits.
@@ -174,6 +187,8 @@ class Type { | |||
case FloatTyID: | |||
case HalfTyID: | |||
case BFloatTyID: | |||
case Float8E5M2TyID: | |||
case Float8E4M3FNTyID: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This type definitely doesn't qualify--it's lacking infinity and sNaN for one thing. I'm not familiar enough with these types to ascertain the answer for Float8E5M2.
* - ``float8e4m3fn`` | ||
- 16-bit floating-point value(3-bit significand), there are no infinity | ||
values, and NaN is represented with the exponent and mantissa bits set | ||
to all 1s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any references to the definition of the type you can provide, like bfloat
provides references to hardware ISAs that contain more detail?
(Presumably the reason to finally get around to adding these types in LLVM IR is to enable hardware instructions, so references to hardware ISAs are ideal).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the reference is that https://arxiv.org/pdf/2209.05433.
i am not sure that put the website on the Lang�Ref.rst is better or not.
format is represented by ``0xQ`` followed by 2 hexadecimal digits. The IEEE 16-bit | ||
format (half precision) is represented by ``0xH`` followed by 4 hexadecimal digits. | ||
The bfloat 16-bit format is represented by ``0xR`` followed by 4 hexadecimal digits. | ||
All hexadecimal formats are big-endian (sign bit at the left). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This paragraph is getting long and hard to follow. Perhaps it should be a table or list instead?
(As an aside, I don't really like the policy of representing smaller-than-double types as hex doubles, and the need to have different prefixes for different float types is annoying. IMHO, we should just change the float constant part of the LLVM IR, but that is orthogonal to this change.)
Support two classes Float8(float8e5m2 and float8e4m3fn) IR type for ML. Float8e5m2 has a 5-bit exponent and a 2-bit mantissa and behaves like an IEEE 754 floating point IR type. Float8e4m3fn has a 4-bit exponent and a 3-bit mantissa.
cc6299f
to
6866553
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're going to need a new RFC for this. Please check https://discourse.llvm.org/t/add-llvm-type-support-for-fp8-data-types-f8e4m3-and-f8e5m2/67598 and https://reviews.llvm.org/D140088 to get some background on why this has been rejected in this past.
done |
so, we look to accept more stable data types rather than the unstable data types that come from machine learning? @nikic |
The support surface area for something like this is really quite large. There's a lot of legalization work to support fundamental operations, which is kind of pointless if there is no underlying support on any target for them. Would first class, generic intrinsics for converting between each of these formats to other float types cover the real use cases? |
@JinjinLi868 reverse-ping? |
Support two classes Float8(float8e5m2 and float8e4m3fn) IR type for
ML. Float8e5m2 has a 5-bit exponent and a 2-bit mantissa and behaves
like an IEEE 754 floating point IR type. Float8e4m3fn has a 4-bit
exponent and a 3-bit mantissa.
The series patches(for IR. MVT. intrinsic):
#89900
#89901
#89902