diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h index 0dfe9f029f9b1a..87923356a67685 100644 --- a/llvm/include/llvm/IR/Intrinsics.h +++ b/llvm/include/llvm/IR/Intrinsics.h @@ -114,6 +114,8 @@ namespace Intrinsic { MMX, Token, Metadata, + Float8E4M3FN, + FLoat8E5M2, Half, BFloat, Float, diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td index bdd8465883fcff..7dce3431cb1646 100644 --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -320,6 +320,8 @@ def IIT_I2 : IIT_Int<2, 57>; def IIT_I4 : IIT_Int<4, 58>; def IIT_AARCH64_SVCOUNT : IIT_VT; def IIT_V6 : IIT_Vec<6, 60>; +def IIT_F8E4M3FN : IIT_VT; +def IIT_F8E5M2 : IIT_VT; } defvar IIT_all_FixedTypes = !filter(iit, IIT_all, @@ -553,6 +555,17 @@ def llvm_v32i64_ty : LLVMType; // 32 x i64 def llvm_v1i128_ty : LLVMType; // 1 x i128 +def llvm_v2f8e4m3fn_ty : LLVMType; // 2 x f8e4m3 (__f8e4m3fn) +def llvm_v4f8e4m3fn_ty : LLVMType; // 4 x f8e4m3 (__f8e4m3fn) +def llvm_v8f8e4m3fn_ty : LLVMType; // 8 x f8e4m3 (__f8e4m3fn) +def llvm_v16f8e4m3fn_ty : LLVMType; // 16 x f8e4m3 (__f8e4m3fn) +def llvm_v32f8e4m3fn_ty : LLVMType; // 32 x f8e4m3 (__f8e4m3fn) +def llvm_v2f8e5m2_ty : LLVMType; // 2 x f8e5m2 (__f8e5m2) +def llvm_v4f8e5m2_ty : LLVMType; // 4 x f8e5m2 (__f8e5m2) +def llvm_v8f8e5m2_ty : LLVMType; // 8 x f8e5m2 (__f8e5m2) +def llvm_v16f8e5m2_ty : LLVMType; // 16 x f8e5m2 (__f8e5m2) +def llvm_v32f8e5m2_ty : LLVMType; // 32 x f8e5m2 (__f8e5m2) + def llvm_v2f16_ty : LLVMType; // 2 x half (__fp16) def llvm_v4f16_ty : LLVMType; // 4 x half (__fp16) def llvm_v8f16_ty : LLVMType; // 8 x half (__fp16) diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp index 96953ac49c19b4..c88e6f7eaceed7 100644 --- a/llvm/lib/IR/Function.cpp +++ b/llvm/lib/IR/Function.cpp @@ -1111,6 +1111,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef Infos, case IIT_METADATA: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Metadata, 0)); return; + case IIT_F8E4M3FN: + OutputTable.push_back(IITDescriptor::get(IITDescriptor::Float8E4M3FN, 0)); + return; + case IIT_F8E5M2: + OutputTable.push_back(IITDescriptor::get(IITDescriptor::FLoat8E5M2, 0)); + return; case IIT_F16: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Half, 0)); return; @@ -1357,6 +1363,10 @@ static Type *DecodeFixedType(ArrayRef &Infos, case IITDescriptor::AMX: return Type::getX86_AMXTy(Context); case IITDescriptor::Token: return Type::getTokenTy(Context); case IITDescriptor::Metadata: return Type::getMetadataTy(Context); + case IITDescriptor::Float8E4M3FN: + return Type::getFloat8E4M3FNTy(Context); + case IITDescriptor::FLoat8E5M2: + return Type::getFloat8E5M2Ty(Context); case IITDescriptor::Half: return Type::getHalfTy(Context); case IITDescriptor::BFloat: return Type::getBFloatTy(Context); case IITDescriptor::Float: return Type::getFloatTy(Context); @@ -1516,6 +1526,10 @@ static bool matchIntrinsicType( case IITDescriptor::AMX: return !Ty->isX86_AMXTy(); case IITDescriptor::Token: return !Ty->isTokenTy(); case IITDescriptor::Metadata: return !Ty->isMetadataTy(); + case IITDescriptor::Float8E4M3FN: + return !Ty->isFloat8E4M3FNTy(); + case IITDescriptor::FLoat8E5M2: + return !Ty->isFloat8E5M2Ty(); case IITDescriptor::Half: return !Ty->isHalfTy(); case IITDescriptor::BFloat: return !Ty->isBFloatTy(); case IITDescriptor::Float: return !Ty->isFloatTy();