Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR][Float8] Add two kinds float8 IR type #89900

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

JinjinLi868
Copy link

@JinjinLi868 JinjinLi868 commented Apr 24, 2024

Support two classes Float8(float8e5m2 and float8e4m3fn) IR type for
ML. Float8e5m2 has a 5-bit exponent and a 2-bit mantissa and behaves
like an IEEE 754 floating point IR type. Float8e4m3fn has a 4-bit
exponent and a 3-bit mantissa.

The series patches(for IR. MVT. intrinsic):
#89900
#89901
#89902

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-llvm-ir

Author: None (JinjinLi868)

Changes

Support two classes Float8(float8e5m2 and float8e4m3fn) IR type for
ML. Float8e5m2 has a 5-bit exponent and a 2-bit mantissa and behaves
like an IEEE 754 floating point IR type. Float8e4m3fn has a 4-bit
exponent and a 3-bit mantissa.


Patch is 46.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89900.diff

23 Files Affected:

  • (modified) llvm/docs/BitCodeFormat.rst (+16)
  • (modified) llvm/docs/LangRef.rst (+22-11)
  • (modified) llvm/include/llvm-c/Core.h (+15)
  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+3)
  • (modified) llvm/include/llvm/IR/Constants.h (+2)
  • (modified) llvm/include/llvm/IR/DataLayout.h (+3)
  • (modified) llvm/include/llvm/IR/IRBuilder.h (+10)
  • (modified) llvm/include/llvm/IR/Type.h (+17)
  • (modified) llvm/lib/AsmParser/LLLexer.cpp (+36-21)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+10-3)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+25-1)
  • (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+17-15)
  • (modified) llvm/lib/CodeGen/MIRParser/MILexer.cpp (+1-1)
  • (modified) llvm/lib/IR/AsmWriter.cpp (+25-14)
  • (modified) llvm/lib/IR/Constants.cpp (+70-10)
  • (modified) llvm/lib/IR/Core.cpp (+18-3)
  • (modified) llvm/lib/IR/DataLayout.cpp (+3)
  • (modified) llvm/lib/IR/Function.cpp (+13-11)
  • (modified) llvm/lib/IR/LLVMContextImpl.cpp (+1)
  • (modified) llvm/lib/IR/LLVMContextImpl.h (+2-2)
  • (modified) llvm/lib/IR/Type.cpp (+30-14)
  • (added) llvm/test/Assembler/float8.ll (+71)
  • (modified) llvm/tools/llvm-c-test/echo.cpp (+5-1)
diff --git a/llvm/docs/BitCodeFormat.rst b/llvm/docs/BitCodeFormat.rst
index 46af2e421a258c..bd9b1f87422585 100644
--- a/llvm/docs/BitCodeFormat.rst
+++ b/llvm/docs/BitCodeFormat.rst
@@ -1139,6 +1139,22 @@ TYPE_CODE_VOID Record
 
 The ``VOID`` record (code 2) adds a ``void`` type to the type table.
 
+TYPE_CODE_Float8E5M2 Record
+^^^^^^^^^^^^^^^^^^^^^
+
+``[Float8E5M2]``
+
+The ``Float8E5M2`` record (code 27) adds a ``float8e5m2`` (8-bit floating point)
+type to the type table.
+
+TYPE_CODE_Float8E4M3FN Record
+^^^^^^^^^^^^^^^^^^^^^
+
+``[Float8E4M3FN]``
+
+The ``Float8E4M3FN`` record (code 28) adds a ``float8e4m3fn`` (8-bit floating
+point) type to the type table.
+
 TYPE_CODE_HALF Record
 ^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 9592929d79feb4..3106dc0cc25d5e 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -3847,6 +3847,14 @@ Floating-Point Types
    * - Type
      - Description
 
+   * - ``float8e5m2``
+     - 16-bit floating-point value(2-bit significand)
+
+   * - ``float8e4m3fn``
+     - 16-bit floating-point value(3-bit significand), there are no infinity
+       values, and NaN is represented with the exponent and mantissa bits set
+       to all 1s
+
    * - ``half``
      - 16-bit floating-point value
 
@@ -3871,9 +3879,9 @@ Floating-Point Types
    * - ``ppc_fp128``
      - 128-bit floating-point value (two 64-bits)
 
-The binary format of half, float, double, and fp128 correspond to the
-IEEE-754-2008 specifications for binary16, binary32, binary64, and binary128
-respectively.
+The binary format of float8e5m2, half, float, double, and fp128 correspond
+to the IEEE-754-2008 specifications for binary8, binary16, binary32, binary64,
+and binary128 respectively.
 
 X86_amx Type
 """"""""""""
@@ -4329,20 +4337,23 @@ number of digits. For example, NaN's, infinities, and other special
 values are represented in their IEEE hexadecimal format so that assembly
 and disassembly do not cause any bits to change in the constants.
 
-When using the hexadecimal form, constants of types bfloat, half, float, and
-double are represented using the 16-digit form shown above (which matches the
-IEEE754 representation for double); bfloat, half and float values must, however,
-be exactly representable as bfloat, IEEE 754 half, and IEEE 754 single
+When using the hexadecimal form, constants of types float8e5m2, float8e4m3fn,
+bfloat, half, float, and double are represented using the 16-digit form shown
+above (which matches the IEEE754 representation for double); float8e5m2,
+float8e4m3fn, bfloat, half and float values must, however, be exactly representable
+as float8e5m2, float8e4m3fn, bfloat, IEEE 754 half, and IEEE 754 single
 precision respectively. Hexadecimal format is always used for long double, and
 there are three forms of long double. The 80-bit format used by x86 is
 represented as ``0xK`` followed by 20 hexadecimal digits. The 128-bit format
 used by PowerPC (two adjacent doubles) is represented by ``0xM`` followed by 32
 hexadecimal digits. The IEEE 128-bit format is represented by ``0xL`` followed
 by 32 hexadecimal digits. Long doubles will only work if they match the long
-double format on your target.  The IEEE 16-bit format (half precision) is
-represented by ``0xH`` followed by 4 hexadecimal digits. The bfloat 16-bit
-format is represented by ``0xR`` followed by 4 hexadecimal digits. All
-hexadecimal formats are big-endian (sign bit at the left).
+double format on your target. The IEEE 8-bit format (floate5m2 precision) is
+represented by ``0xS`` followed by 2 hexadecimal digits. The float8e4m3fn 8-bit
+format is represented by ``0xQ`` followed by 2 hexadecimal digits. The IEEE 16-bit
+format (half precision) is represented by ``0xH`` followed by 4 hexadecimal digits.
+The bfloat 16-bit format is represented by ``0xR`` followed by 4 hexadecimal digits.
+All hexadecimal formats are big-endian (sign bit at the left).
 
 There are no constants of type x86_mmx and x86_amx.
 
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index 0b03f3b36fcdd3..7cc958beccb62d 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -167,6 +167,8 @@ typedef enum {
   LLVMBFloatTypeKind,    /**< 16 bit brain floating point type */
   LLVMX86_AMXTypeKind,   /**< X86 AMX */
   LLVMTargetExtTypeKind, /**< Target extension type */
+  LLVMFloat8E5M2TypeKind, /**< 8 bit floating point with 2 bit mantissa */
+  LLVMFloat8E4M3FNTypeKind, /**< 8 bit floating point with 3 bit mantissa */
 } LLVMTypeKind;
 
 typedef enum {
@@ -1298,6 +1300,17 @@ unsigned LLVMGetIntTypeWidth(LLVMTypeRef IntegerTy);
  * @{
  */
 
+
+/**
+ * Obtain a 8-bit floating point type from a context.
+ */
+LLVMTypeRef LLVMFloat8E5M2TypeInContext(LLVMContextRef C);
+
+/**
+ * Obtain a 8-bit floating point type from a context.
+ */
+LLVMTypeRef LLVMFloat8E4M3FNTypeInContext(LLVMContextRef C);
+
 /**
  * Obtain a 16-bit floating point type from a context.
  */
@@ -1339,6 +1352,8 @@ LLVMTypeRef LLVMPPCFP128TypeInContext(LLVMContextRef C);
  *
  * These map to the functions in this group of the same name.
  */
+LLVMTypeRef LLVMFloat8E5M2Type(void);
+LLVMTypeRef LLVMFloat8E4M3FNType(void);
 LLVMTypeRef LLVMHalfType(void);
 LLVMTypeRef LLVMBFloatType(void);
 LLVMTypeRef LLVMFloatType(void);
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 909eb833c601a9..ce6d639c2455c4 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -177,6 +177,9 @@ enum TypeCodes {
   TYPE_CODE_OPAQUE_POINTER = 25, // OPAQUE_POINTER: [addrspace]
 
   TYPE_CODE_TARGET_TYPE = 26, // TARGET_TYPE
+
+  TYPE_CODE_Float8E5M2 = 27, // Float8E5M2
+  TYPE_CODE_Float8E4M3FN = 28, // Float8E4M3FN
 };
 
 enum OperandBundleTagCode {
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 4290ef4486c6f4..3c82b74a111aa5 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -733,6 +733,7 @@ class ConstantDataArray final : public ConstantDataSequential {
   /// number of bits of the type contained in the passed in ArrayRef.
   /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
   /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint8_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
@@ -805,6 +806,7 @@ class ConstantDataVector final : public ConstantDataSequential {
   /// number of bits of the type contained in the passed in ArrayRef.
   /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
   /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint8_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h
index d14adfe1590be5..2f0c55d8e758c6 100644
--- a/llvm/include/llvm/IR/DataLayout.h
+++ b/llvm/include/llvm/IR/DataLayout.h
@@ -687,6 +687,9 @@ inline TypeSize DataLayout::getTypeSizeInBits(Type *Ty) const {
     return getStructLayout(cast<StructType>(Ty))->getSizeInBits();
   case Type::IntegerTyID:
     return TypeSize::getFixed(Ty->getIntegerBitWidth());
+  case Type::Float8E5M2TyID:
+  case Type::Float8E4M3FNTyID:
+    return TypeSize::getFixed(8);
   case Type::HalfTyID:
   case Type::BFloatTyID:
     return TypeSize::getFixed(16);
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index b6534a1962a2f5..de981586b4cbe7 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -540,6 +540,16 @@ class IRBuilderBase {
     return Type::getIntNTy(Context, N);
   }
 
+  /// Fetch the type representing a 8-bit e5m2 floating point value.
+  Type *getFloat8E5M2Ty() {
+    return Type::getFloat8E5M2Ty(Context);
+  }
+
+  /// Fetch the type representing a 8-bit e4m3fn floating point value.
+  Type *getFloat8E4M3FNTy() {
+    return Type::getFloat8E4M3FNTy(Context);
+  }
+
   /// Fetch the type representing a 16-bit floating point value.
   Type *getHalfTy() {
     return Type::getHalfTy(Context);
diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index 1f0133c08e7d60..bf9f63d2cdda4c 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -55,6 +55,8 @@ class Type {
     // PrimitiveTypes
     HalfTyID = 0,  ///< 16-bit floating point type
     BFloatTyID,    ///< 16-bit floating point type (7-bit significand)
+    Float8E5M2TyID,   ///< 8-bit floating type (5 Bit exponent)
+    Float8E4M3FNTyID, ///< 8-bit floating type (4 Bit exponent)
     FloatTyID,     ///< 32-bit floating point type
     DoubleTyID,    ///< 64-bit floating point type
     X86_FP80TyID,  ///< 80-bit floating point type (X87)
@@ -139,6 +141,17 @@ class Type {
   /// Return true if this is 'void'.
   bool isVoidTy() const { return getTypeID() == VoidTyID; }
 
+  /// Return true if this is 'F8E5M2'.
+  bool isFloat8E5M2Ty() const { return getTypeID() == Float8E5M2TyID; }
+
+  /// Return true if this is 'F8E4M3FN'.
+  bool isFloat8E4M3FNTy() const { return getTypeID() == Float8E4M3FNTyID; }
+
+  /// Return true if this is an 8-bit float type.
+  bool is8BitFPTy() const {
+    return getTypeID() == Float8E5M2TyID || getTypeID() == Float8E4M3FNTyID;
+  }
+
   /// Return true if this is 'half', a 16-bit IEEE fp type.
   bool isHalfTy() const { return getTypeID() == HalfTyID; }
 
@@ -174,6 +187,8 @@ class Type {
     case FloatTyID:
     case HalfTyID:
     case BFloatTyID:
+    case Float8E5M2TyID:
+    case Float8E4M3FNTyID:
     case FP128TyID:
       return true;
     default:
@@ -445,6 +460,8 @@ class Type {
   //
   static Type *getVoidTy(LLVMContext &C);
   static Type *getLabelTy(LLVMContext &C);
+  static Type *getFloat8E5M2Ty(LLVMContext &C);
+  static Type *getFloat8E4M3FNTy(LLVMContext &C);
   static Type *getHalfTy(LLVMContext &C);
   static Type *getBFloatTy(LLVMContext &C);
   static Type *getFloatTy(LLVMContext &C);
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 8ded07ffd8bd25..aabd6262304f16 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -825,20 +825,22 @@ lltok::Kind LLLexer::LexIdentifier() {
     }                                                                          \
   } while (false)
 
-  TYPEKEYWORD("void",      Type::getVoidTy(Context));
-  TYPEKEYWORD("half",      Type::getHalfTy(Context));
-  TYPEKEYWORD("bfloat",    Type::getBFloatTy(Context));
-  TYPEKEYWORD("float",     Type::getFloatTy(Context));
-  TYPEKEYWORD("double",    Type::getDoubleTy(Context));
-  TYPEKEYWORD("x86_fp80",  Type::getX86_FP80Ty(Context));
-  TYPEKEYWORD("fp128",     Type::getFP128Ty(Context));
-  TYPEKEYWORD("ppc_fp128", Type::getPPC_FP128Ty(Context));
-  TYPEKEYWORD("label",     Type::getLabelTy(Context));
-  TYPEKEYWORD("metadata",  Type::getMetadataTy(Context));
-  TYPEKEYWORD("x86_mmx",   Type::getX86_MMXTy(Context));
-  TYPEKEYWORD("x86_amx",   Type::getX86_AMXTy(Context));
-  TYPEKEYWORD("token",     Type::getTokenTy(Context));
-  TYPEKEYWORD("ptr",       PointerType::getUnqual(Context));
+  TYPEKEYWORD("void",          Type::getVoidTy(Context));
+  TYPEKEYWORD("float8e5m2",    Type::getFloat8E5M2Ty(Context));
+  TYPEKEYWORD("float8e4m3fn",  Type::getFloat8E4M3FNTy(Context));
+  TYPEKEYWORD("half",          Type::getHalfTy(Context));
+  TYPEKEYWORD("bfloat",        Type::getBFloatTy(Context));
+  TYPEKEYWORD("float",         Type::getFloatTy(Context));
+  TYPEKEYWORD("double",        Type::getDoubleTy(Context));
+  TYPEKEYWORD("x86_fp80",      Type::getX86_FP80Ty(Context));
+  TYPEKEYWORD("fp128",         Type::getFP128Ty(Context));
+  TYPEKEYWORD("ppc_fp128",     Type::getPPC_FP128Ty(Context));
+  TYPEKEYWORD("label",         Type::getLabelTy(Context));
+  TYPEKEYWORD("metadata",      Type::getMetadataTy(Context));
+  TYPEKEYWORD("x86_mmx",       Type::getX86_MMXTy(Context));
+  TYPEKEYWORD("x86_amx",       Type::getX86_AMXTy(Context));
+  TYPEKEYWORD("token",         Type::getTokenTy(Context));
+  TYPEKEYWORD("ptr",           PointerType::getUnqual(Context));
 
 #undef TYPEKEYWORD
 
@@ -1006,18 +1008,21 @@ lltok::Kind LLLexer::LexIdentifier() {
 
 /// Lex all tokens that start with a 0x prefix, knowing they match and are not
 /// labels.
-///    HexFPConstant     0x[0-9A-Fa-f]+
-///    HexFP80Constant   0xK[0-9A-Fa-f]+
-///    HexFP128Constant  0xL[0-9A-Fa-f]+
-///    HexPPC128Constant 0xM[0-9A-Fa-f]+
-///    HexHalfConstant   0xH[0-9A-Fa-f]+
-///    HexBFloatConstant 0xR[0-9A-Fa-f]+
+///    HexFPConstant         0x[0-9A-Fa-f]+
+///    HexFP80Constant       0xK[0-9A-Fa-f]+
+///    HexFP128Constant      0xL[0-9A-Fa-f]+
+///    HexPPC128Constant     0xM[0-9A-Fa-f]+
+///    HexHalfConstant       0xH[0-9A-Fa-f]+
+///    HexBFloatConstant     0xR[0-9A-Fa-f]+
+///    HexFP8E4M3FNConstant  0xQ[0-9A-Fa-f]+
+///    HexFP8E5M2Constant    0xS[0-9A-Fa-f]+
+
 lltok::Kind LLLexer::Lex0x() {
   CurPtr = TokStart + 2;
 
   char Kind;
   if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' ||
-      CurPtr[0] == 'R') {
+      CurPtr[0] == 'R' || CurPtr[0] == 'Q' || CurPtr[0] == 'S') {
     Kind = *CurPtr++;
   } else {
     Kind = 'J';
@@ -1068,6 +1073,16 @@ lltok::Kind LLLexer::Lex0x() {
     APFloatVal = APFloat(APFloat::BFloat(),
                          APInt(16, HexIntToVal(TokStart + 3, CurPtr)));
     return lltok::APFloat;
+  case 'Q':
+    // FP8E4M3FN
+    APFloatVal = APFloat(APFloat::Float8E4M3FN(),
+                         APInt(8, HexIntToVal(TokStart + 1, CurPtr)));
+    return lltok::APFloat;
+  case 'S':
+    // FP8E5M2
+    APFloatVal = APFloat(APFloat::Float8E5M2(),
+                         APInt(8, HexIntToVal(TokStart + 1, CurPtr)));
+    return lltok::APFloat;
   }
 }
 
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 63104129f8c2df..d32e154c8baf66 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -5998,13 +5998,20 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
         !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
       return error(ID.Loc, "floating point constant invalid for type");
 
-    // The lexer has no type info, so builds all half, bfloat, float, and double
-    // FP constants as double.  Fix this here.  Long double does not need this.
+    // The lexer has no type info, so builds all float8e5m2, float8e4m3fn, half,
+    // bfloat, float, and double FP constants as double.  Fix this here. Long
+    // double does not need this.
     if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) {
       // Check for signaling before potentially converting and losing that info.
       bool IsSNAN = ID.APFloatVal.isSignaling();
       bool Ignored;
-      if (Ty->isHalfTy())
+      if (Ty->isFloat8E5M2Ty())
+        ID.APFloatVal.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven,
+                              &Ignored);
+      else if (Ty->isFloat8E4M3FNTy())
+        ID.APFloatVal.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                              &Ignored);
+      else if (Ty->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
       else if (Ty->isBFloatTy())
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 0b7fcd88418894..3e3ec1b2664089 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -2404,6 +2404,12 @@ Error BitcodeReader::parseTypeTableBody() {
     case bitc::TYPE_CODE_VOID:      // VOID
       ResultTy = Type::getVoidTy(Context);
       break;
+    case bitc::TYPE_CODE_Float8E4M3FN: // FP8E4M3FN
+      ResultTy = Type::getFloat8E4M3FNTy(Context);
+      break;
+    case bitc::TYPE_CODE_Float8E5M2:   // FP8E5M2
+      ResultTy = Type::getFloat8E5M2Ty(Context);
+      break;
     case bitc::TYPE_CODE_HALF:     // HALF
       ResultTy = Type::getHalfTy(Context);
       break;
@@ -3138,7 +3144,13 @@ Error BitcodeReader::parseConstants() {
         return error("Invalid float const record");
 
       auto *ScalarTy = CurTy->getScalarType();
-      if (ScalarTy->isHalfTy())
+      if (ScalarTy->isFloat8E4M3FNTy())
+        V = ConstantFP::get(Context, APFloat(APFloat::Float8E4M3FN(),
+                                             APInt(8, (uint8_t)Record[0])));
+      else if (ScalarTy->isFloat8E5M2Ty())
+        V = ConstantFP::get(Context, APFloat(APFloat::Float8E5M2(),
+                                             APInt(8, (uint8_t)Record[0])));
+      else if (ScalarTy->isHalfTy())
         V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(),
                                            APInt(16, (uint16_t)Record[0])));
       else if (ScalarTy->isBFloatTy())
@@ -3234,6 +3246,18 @@ Error BitcodeReader::parseConstants() {
           V = ConstantDataVector::get(Context, Elts);
         else
           V = ConstantDataArray::get(Context, Elts);
+      } else if (EltTy->isFloat8E4M3FNTy()) {
+        SmallVector<uint8_t, 16> Elts(Record.begin(), Record.end());
+        if (isa<VectorType>(CurTy))
+          V = ConstantDataVector::getFP(EltTy, Elts);
+        else
+          V = ConstantDataArray::getFP(EltTy, Elts);
+      } else if (EltTy->isFloat8E5M2Ty()) {
+        SmallVector<uint8_t, 16> Elts(Record.begin(), Record.end());
+        if (isa<VectorType>(CurTy))
+          V = ConstantDataVector::getFP(EltTy, Elts);
+        else
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else if (EltTy->isHalfTy()) {
         SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 6d01e3b4d82189..46b5d0ed9440ee 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -1043,19 +1043,21 @@ void ModuleBitcodeWriter::writeTypeTable() {
     unsigned Code = 0;
 
     switch (T->getTypeID()) {
-    case Type::VoidTyID:      Code = bitc::TYPE_CODE_VOID;      break;
-    case Type::HalfTyID:      Code = bitc::TYPE_CODE_HALF;      break;
-    case Type::BFloatTyID:    Code = bitc::TYPE_CODE_BFLOAT;    break;
-    case Type::FloatTyID:     Code = bitc::TYPE_CODE_FLOAT;     break;
-    case Type::DoubleTyID:    Code = bitc::TYPE_CODE_DOUBLE;    break;
-    case Type::X86_FP80TyID:  Code = bitc::TYPE_CODE_X86_FP80;  break;
-    case Type::FP128TyID:     Code = bitc::TYPE_CODE_FP128;     break;
-    case Type::PPC_FP128TyID: Code = bitc::TYPE_CODE_PPC_FP128; break;
-    case Type::LabelTyID:     Code = bitc::TYPE_CODE_LABEL;     break;
-    case Type::MetadataTyID:  Code = bitc::TYPE_CODE_METADATA;  break;
-    case Type::X86_MMXTyID:   Code = bitc::TYPE_CODE_X86_MMX;   break;
-    case Type::X86_AMXTyID:   Code = bitc::TYPE_CODE_X86_AMX;   break;
-    case Type::TokenTyID:     Code = bitc::TYPE_CODE_TOKEN;     break;
+    case Type::VoidTyID:          Code = bitc::TYPE_CODE_VOID;          break;
+    case Type::Float8E4M3FNTyID:  Code = bitc::TYPE_CODE_Float8E4M3FN;  break;
+    case Type::Float8E5M2TyID:    Code = bitc::TYPE_CODE_Float8E5M2;    break;
+    case Type::HalfTyID:          Code = bitc::TYPE_CODE_HALF;          break;
+    case Type::BFloatTyID:        Code = bitc::TYPE_CODE_BFLOAT;        break;
+    case Type::FloatTyID:         Code = bitc::TYPE_CODE_FLOAT;         break;
+    case Type::DoubleTyID:        Code = bitc::TYPE_CODE_DOUBLE;        break;
+    case Type::X86_FP80TyID:      Code = bitc::TYPE_CODE_X86_FP80;      break;
+    case Type::FP128TyID:         Code = bitc::TYPE_CODE_FP128;         break;
+    case Type::PPC_FP128TyID:     Code = bitc::TYPE_CODE_PPC_FP128;     break;
+    case Type::LabelTyID:         Code = bitc::TYPE_CODE_LABEL;         break;
+    case Type::MetadataTyID:      Code = bitc::TYPE_CODE_M...
[truncated]

@JinjinLi868
Copy link
Author

@arsenm

Copy link
Contributor

@jcranmer-intel jcranmer-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some tests for invalid constants in the IR parser? In particular, I'd like to make sure that the float8e4m3fn is handled properly in the double-hex-constant mode, given that it lacks infinities and has but one NaN value.

@@ -3847,6 +3847,14 @@ Floating-Point Types
* - Type
- Description

* - ``float8e5m2``
- 16-bit floating-point value(2-bit significand)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are 8-bit types, no?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, my bad, changed

IEEE-754-2008 specifications for binary16, binary32, binary64, and binary128
respectively.
The binary format of float8e5m2, half, float, double, and fp128 correspond
to the IEEE-754-2008 specifications for binary8, binary16, binary32, binary64,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IEEE 754-2008 (nor IEEE 754-2019, for that matter) doesn't define a binary8 type. And the table for binaryk actually applies to k >= 128 and k % 32 == 0. And even if you were to ignore that restriction, the formulas you would get for k =8 gives a p of 9 bits and a w of -1 bits.

@@ -174,6 +187,8 @@ class Type {
case FloatTyID:
case HalfTyID:
case BFloatTyID:
case Float8E5M2TyID:
case Float8E4M3FNTyID:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type definitely doesn't qualify--it's lacking infinity and sNaN for one thing. I'm not familiar enough with these types to ascertain the answer for Float8E5M2.

* - ``float8e4m3fn``
- 16-bit floating-point value(3-bit significand), there are no infinity
values, and NaN is represented with the exponent and mantissa bits set
to all 1s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any references to the definition of the type you can provide, like bfloat provides references to hardware ISAs that contain more detail?

(Presumably the reason to finally get around to adding these types in LLVM IR is to enable hardware instructions, so references to hardware ISAs are ideal).

Copy link
Author

@JinjinLi868 JinjinLi868 Apr 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the reference is that https://arxiv.org/pdf/2209.05433.
i am not sure that put the website on the Lang�Ref.rst is better or not.

format is represented by ``0xQ`` followed by 2 hexadecimal digits. The IEEE 16-bit
format (half precision) is represented by ``0xH`` followed by 4 hexadecimal digits.
The bfloat 16-bit format is represented by ``0xR`` followed by 4 hexadecimal digits.
All hexadecimal formats are big-endian (sign bit at the left).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This paragraph is getting long and hard to follow. Perhaps it should be a table or list instead?

(As an aside, I don't really like the policy of representing smaller-than-double types as hex doubles, and the need to have different prefixes for different float types is annoying. IMHO, we should just change the float constant part of the LLVM IR, but that is orthogonal to this change.)

Support two classes Float8(float8e5m2 and float8e4m3fn) IR type for
ML. Float8e5m2 has a 5-bit exponent and a 2-bit mantissa and behaves
like an IEEE 754 floating point IR type. Float8e4m3fn has a 4-bit
exponent and a 3-bit mantissa.
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're going to need a new RFC for this. Please check https://discourse.llvm.org/t/add-llvm-type-support-for-fp8-data-types-f8e4m3-and-f8e5m2/67598 and https://reviews.llvm.org/D140088 to get some background on why this has been rejected in this past.

@JinjinLi868
Copy link
Author

Can you add some tests for invalid constants in the IR parser? In particular, I'd like to make sure that the float8e4m3fn is handled properly in the double-hex-constant mode, given that it lacks infinities and has but one NaN value.

done

@JinjinLi868
Copy link
Author

You're going to need a new RFC for this. Please check https://discourse.llvm.org/t/add-llvm-type-support-for-fp8-data-types-f8e4m3-and-f8e5m2/67598 and https://reviews.llvm.org/D140088 to get some background on why this has been rejected in this past.

so, we look to accept more stable data types rather than the unstable data types that come from machine learning? @nikic

@arsenm
Copy link
Contributor

arsenm commented Apr 30, 2024

You're going to need a new RFC for this. Please check https://discourse.llvm.org/t/add-llvm-type-support-for-fp8-data-types-f8e4m3-and-f8e5m2/67598 and https://reviews.llvm.org/D140088 to get some background on why this has been rejected in this past.

so, we look to accept more stable data types rather than the unstable data types that come from machine learning? @nikic

The support surface area for something like this is really quite large. There's a lot of legalization work to support fundamental operations, which is kind of pointless if there is no underlying support on any target for them. Would first class, generic intrinsics for converting between each of these formats to other float types cover the real use cases?

@RKSimon
Copy link
Collaborator

RKSimon commented Jun 12, 2024

@JinjinLi868 reverse-ping?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants