Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR][Float8] Add Float8 IR intrinsics support #89902

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ namespace Intrinsic {
MMX,
Token,
Metadata,
Float8E4M3FN,
FLoat8E5M2,
Half,
BFloat,
Float,
Expand Down
13 changes: 13 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ def IIT_I2 : IIT_Int<2, 57>;
def IIT_I4 : IIT_Int<4, 58>;
def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
def IIT_V6 : IIT_Vec<6, 60>;
def IIT_F8E4M3FN : IIT_VT<f8e4m3fn, 61>;
def IIT_F8E5M2 : IIT_VT<f8e5m2, 62>;
}

defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
Expand Down Expand Up @@ -553,6 +555,17 @@ def llvm_v32i64_ty : LLVMType<v32i64>; // 32 x i64

def llvm_v1i128_ty : LLVMType<v1i128>; // 1 x i128

def llvm_v2f8e4m3fn_ty : LLVMType<v2f8e4m3fn>; // 2 x f8e4m3 (__f8e4m3fn)
def llvm_v4f8e4m3fn_ty : LLVMType<v4f8e4m3fn>; // 4 x f8e4m3 (__f8e4m3fn)
def llvm_v8f8e4m3fn_ty : LLVMType<v8f8e4m3fn>; // 8 x f8e4m3 (__f8e4m3fn)
def llvm_v16f8e4m3fn_ty : LLVMType<v16f8e4m3fn>; // 16 x f8e4m3 (__f8e4m3fn)
def llvm_v32f8e4m3fn_ty : LLVMType<v32f8e4m3fn>; // 32 x f8e4m3 (__f8e4m3fn)
def llvm_v2f8e5m2_ty : LLVMType<v2f8e5m2>; // 2 x f8e5m2 (__f8e5m2)
def llvm_v4f8e5m2_ty : LLVMType<v4f8e5m2>; // 4 x f8e5m2 (__f8e5m2)
def llvm_v8f8e5m2_ty : LLVMType<v8f8e5m2>; // 8 x f8e5m2 (__f8e5m2)
def llvm_v16f8e5m2_ty : LLVMType<v16f8e5m2>; // 16 x f8e5m2 (__f8e5m2)
def llvm_v32f8e5m2_ty : LLVMType<v32f8e5m2>; // 32 x f8e5m2 (__f8e5m2)

def llvm_v2f16_ty : LLVMType<v2f16>; // 2 x half (__fp16)
def llvm_v4f16_ty : LLVMType<v4f16>; // 4 x half (__fp16)
def llvm_v8f16_ty : LLVMType<v8f16>; // 8 x half (__fp16)
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/IR/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
case IIT_METADATA:
OutputTable.push_back(IITDescriptor::get(IITDescriptor::Metadata, 0));
return;
case IIT_F8E4M3FN:
OutputTable.push_back(IITDescriptor::get(IITDescriptor::Float8E4M3FN, 0));
return;
case IIT_F8E5M2:
OutputTable.push_back(IITDescriptor::get(IITDescriptor::FLoat8E5M2, 0));
return;
case IIT_F16:
OutputTable.push_back(IITDescriptor::get(IITDescriptor::Half, 0));
return;
Expand Down Expand Up @@ -1357,6 +1363,10 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
case IITDescriptor::AMX: return Type::getX86_AMXTy(Context);
case IITDescriptor::Token: return Type::getTokenTy(Context);
case IITDescriptor::Metadata: return Type::getMetadataTy(Context);
case IITDescriptor::Float8E4M3FN:
return Type::getFloat8E4M3FNTy(Context);
case IITDescriptor::FLoat8E5M2:
return Type::getFloat8E5M2Ty(Context);
case IITDescriptor::Half: return Type::getHalfTy(Context);
case IITDescriptor::BFloat: return Type::getBFloatTy(Context);
case IITDescriptor::Float: return Type::getFloatTy(Context);
Expand Down Expand Up @@ -1516,6 +1526,10 @@ static bool matchIntrinsicType(
case IITDescriptor::AMX: return !Ty->isX86_AMXTy();
case IITDescriptor::Token: return !Ty->isTokenTy();
case IITDescriptor::Metadata: return !Ty->isMetadataTy();
case IITDescriptor::Float8E4M3FN:
return !Ty->isFloat8E4M3FNTy();
case IITDescriptor::FLoat8E5M2:
return !Ty->isFloat8E5M2Ty();
case IITDescriptor::Half: return !Ty->isHalfTy();
case IITDescriptor::BFloat: return !Ty->isBFloatTy();
case IITDescriptor::Float: return !Ty->isFloatTy();
Expand Down
Loading