8 changes: 8 additions & 0 deletions compiler-rt/cmake/builtin-config-ix.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ _Float16 foo(_Float16 x) {
"
)

# Detect whether the C compiler accepts the __bf16 type; the result gates
# building the bfloat16 truncation builtins (truncdfbf2.c / truncsfbf2.c).
builtin_check_c_compiler_source(COMPILER_RT_HAS_BFLOAT16
"
__bf16 foo(__bf16 x) {
return x;
}
"
)

builtin_check_c_compiler_source(COMPILER_RT_HAS_ASM_LSE
"
asm(\".arch armv8-a+lse\");
Expand Down
11 changes: 9 additions & 2 deletions compiler-rt/lib/builtins/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,8 @@ set(GENERIC_SOURCES
subvsi3.c
subvti3.c
trampoline_setup.c
truncdfbf2.c
truncdfhf2.c
truncdfsf2.c
truncsfbf2.c
truncsfhf2.c
ucmpdi2.c
ucmpti2.c
Expand All @@ -183,6 +181,15 @@ set(GENERIC_SOURCES
umodti3.c
)

# Build the BF16 truncation builtins only when the compiler supports the
# "__bf16" type (see COMPILER_RT_HAS_BFLOAT16 in builtin-config-ix.cmake).
# NOTE(review): Apple platforms are excluded here — presumably to keep the
# Darwin builtin slices unchanged; confirm against the Apple-specific lists.
if(COMPILER_RT_HAS_BFLOAT16 AND NOT APPLE)
  # list(APPEND) is the idiomatic form of set(VAR ${VAR} ...).
  list(APPEND GENERIC_SOURCES
    truncdfbf2.c
    truncsfbf2.c
  )
endif()

# TODO: Several "tf" files (and divtc3.c, but not multc3.c) are in
# GENERIC_SOURCES instead of here.
set(GENERIC_TF_SOURCES
Expand Down
2 changes: 1 addition & 1 deletion compiler-rt/lib/builtins/fp_trunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ typedef uint16_t dst_rep_t;
static const int dstSigBits = 10;

#elif defined DST_BFLOAT
typedef uint16_t dst_t;
typedef __bf16 dst_t;
typedef uint16_t dst_rep_t;
#define DST_REP_C UINT16_C
static const int dstSigBits = 7;
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/IR/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ class Type {
/// Return true if this is 'bfloat', a 16-bit bfloat type.
bool isBFloatTy() const { return getTypeID() == BFloatTyID; }

/// Return true if this is a 16-bit floating point type: either 'half'
/// (IEEE fp16, HalfTyID) or 'bfloat' (BFloatTyID). Note these are two
/// incompatible representations that merely share a bit width.
bool is16bitFPTy() const {
return getTypeID() == BFloatTyID || getTypeID() == HalfTyID;
}

/// Return true if this is 'float', a 32-bit IEEE fp type.
bool isFloatTy() const { return getTypeID() == FloatTyID; }

Expand Down
57 changes: 56 additions & 1 deletion llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setTruncStoreAction(VT, MVT::bf16, Expand);

setOperationAction(ISD::BF16_TO_FP, VT, Expand);
setOperationAction(ISD::FP_TO_BF16, VT, Expand);
setOperationAction(ISD::FP_TO_BF16, VT, Custom);
}

setOperationAction(ISD::PARITY, MVT::i8, Custom);
Expand Down Expand Up @@ -2494,6 +2494,10 @@ MVT X86TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
!Subtarget.hasX87())
return MVT::i32;

if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
return getRegisterTypeForCallingConv(Context, CC,
VT.changeVectorElementTypeToInteger());

return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
}

Expand Down Expand Up @@ -2525,6 +2529,10 @@ unsigned X86TargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
return 3;
}

if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
return getNumRegistersForCallingConv(Context, CC,
VT.changeVectorElementTypeToInteger());

return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
}

Expand Down Expand Up @@ -2733,6 +2741,40 @@ unsigned X86TargetLowering::getJumpTableEncoding() const {
return TargetLowering::getJumpTableEncoding();
}

// Split a value into ABI register parts. Only the special case of a scalar
// bf16 handed over as a single f32 part is handled here (the bf16 bits
// travel in the low half of an SSE register); every other case returns
// false so the generic TargetLowering implementation applies.
bool X86TargetLowering::splitValueIntoRegisterParts(
SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
unsigned NumParts, MVT PartVT, Optional<CallingConv::ID> CC) const {
// A calling convention is only supplied for ABI argument/return copies.
bool IsABIRegCopy = CC.has_value();
EVT ValueVT = Val.getValueType();
if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) {
unsigned ValueBits = ValueVT.getSizeInBits();
unsigned PartBits = PartVT.getSizeInBits();
// bf16 -> i16 (reinterpret) -> i32 (any-extend) -> f32 (reinterpret):
// the raw bf16 payload lands in the low 16 bits of the f32 part.
Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(ValueBits), Val);
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::getIntegerVT(PartBits), Val);
Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val);
Parts[0] = Val;
return true;
}
return false;
}

// Inverse of splitValueIntoRegisterParts: reassemble a scalar bf16 from a
// single f32 ABI register part by reinterpreting the part as an integer,
// truncating to the low 16 bits, and reinterpreting those bits as bf16.
// Returns a null SDValue for all other cases so the generic logic is used.
SDValue X86TargetLowering::joinRegisterPartsIntoValue(
SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
MVT PartVT, EVT ValueVT, Optional<CallingConv::ID> CC) const {
// A calling convention is only supplied for ABI argument/return copies.
bool IsABIRegCopy = CC.has_value();
if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) {
unsigned ValueBits = ValueVT.getSizeInBits();
unsigned PartBits = PartVT.getSizeInBits();
SDValue Val = Parts[0];

// f32 -> i32 (reinterpret) -> i16 (truncate) -> bf16 (reinterpret).
Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(PartBits), Val);
Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::getIntegerVT(ValueBits), Val);
Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
return Val;
}
return SDValue();
}

bool X86TargetLowering::useSoftFloat() const {
return Subtarget.useSoftFloat();
}
Expand Down Expand Up @@ -23019,6 +23061,18 @@ static SDValue LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) {
return Res;
}

// Lower FP_TO_BF16 via the libcall selected by RTLIB::getFPROUND for the
// source type (e.g. __truncsfbf2 / __truncdfbf2). The call's return is
// modeled as f32 — the bf16 bits occupy the low half of the SSE register —
// so the result is bitcast to i32 and truncated to the i16 bf16 bits.
SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
MakeLibCallOptions CallOptions;
RTLIB::Libcall LC =
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
SDValue Res =
makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
// Extract the 16 bf16 payload bits from the f32-typed libcall result.
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16,
DAG.getBitcast(MVT::i32, Res));
}

/// Depending on uarch and/or optimizing for size, we might prefer to use a
/// vector operation in place of the typical scalar operation.
static SDValue lowerAddSubToHorizontalOp(SDValue Op, SelectionDAG &DAG,
Expand Down Expand Up @@ -32211,6 +32265,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::STRICT_FP16_TO_FP: return LowerFP16_TO_FP(Op, DAG);
case ISD::FP_TO_FP16:
case ISD::STRICT_FP_TO_FP16: return LowerFP_TO_FP16(Op, DAG);
case ISD::FP_TO_BF16: return LowerFP_TO_BF16(Op, DAG);
case ISD::LOAD: return LowerLoad(Op, Subtarget, DAG);
case ISD::STORE: return LowerStore(Op, Subtarget, DAG);
case ISD::FADD:
Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1598,6 +1598,7 @@ namespace llvm {
SDValue lowerFaddFsub(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;

SDValue
LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
Expand All @@ -1621,6 +1622,17 @@ namespace llvm {
MachineBasicBlock *Entry,
const SmallVectorImpl<MachineBasicBlock *> &Exits) const override;

bool
splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
SDValue *Parts, unsigned NumParts, MVT PartVT,
Optional<CallingConv::ID> CC) const override;

SDValue
joinRegisterPartsIntoValue(SelectionDAG &DAG, const SDLoc &DL,
const SDValue *Parts, unsigned NumParts,
MVT PartVT, EVT ValueVT,
Optional<CallingConv::ID> CC) const override;

bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;

bool mayBeEmittedAsTailCall(const CallInst *CI) const override;
Expand Down
240 changes: 217 additions & 23 deletions llvm/test/CodeGen/X86/bfloat.ll
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s

define void @add(ptr %pa, ptr %pb, ptr %pc) {
define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
; CHECK-LABEL: add:
; CHECK: # %bb.0:
; CHECK-NEXT: pushq %rbx
; CHECK-NEXT: .cfi_def_cfa_offset 16
; CHECK-NEXT: .cfi_offset %rbx, -16
; CHECK-NEXT: movq %rdx, %rbx
; CHECK-NEXT: movzwl (%rdi), %eax
; CHECK-NEXT: shll $16, %eax
Expand All @@ -16,9 +14,9 @@ define void @add(ptr %pa, ptr %pb, ptr %pc) {
; CHECK-NEXT: movd %eax, %xmm0
; CHECK-NEXT: addss %xmm1, %xmm0
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: movw %ax, (%rbx)
; CHECK-NEXT: popq %rbx
; CHECK-NEXT: .cfi_def_cfa_offset 8
; CHECK-NEXT: retq
%a = load bfloat, ptr %pa
%b = load bfloat, ptr %pb
Expand All @@ -27,38 +25,48 @@ define void @add(ptr %pa, ptr %pb, ptr %pc) {
ret void
}

define void @add_double(ptr %pa, ptr %pb, ptr %pc) {
; Scalar bf16 args/return travel in XMM registers: each input's raw bits are
; moved to a GPR, shifted left by 16 to form an f32, added with addss, and
; the sum is rounded back to bf16 by the __truncsfbf2 libcall.
define bfloat @add2(bfloat %a, bfloat %b) nounwind {
; CHECK-LABEL: add2:
; CHECK: # %bb.0:
; CHECK-NEXT: pushq %rax
; CHECK-NEXT: movd %xmm1, %eax
; CHECK-NEXT: shll $16, %eax
; CHECK-NEXT: movd %eax, %xmm1
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: shll $16, %eax
; CHECK-NEXT: movd %eax, %xmm0
; CHECK-NEXT: addss %xmm1, %xmm0
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: popq %rax
; CHECK-NEXT: retq
%add = fadd bfloat %a, %b
ret bfloat %add
}

define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
; CHECK-LABEL: add_double:
; CHECK: # %bb.0:
; CHECK-NEXT: pushq %r14
; CHECK-NEXT: .cfi_def_cfa_offset 16
; CHECK-NEXT: pushq %rbx
; CHECK-NEXT: .cfi_def_cfa_offset 24
; CHECK-NEXT: pushq %rax
; CHECK-NEXT: .cfi_def_cfa_offset 32
; CHECK-NEXT: .cfi_offset %rbx, -24
; CHECK-NEXT: .cfi_offset %r14, -16
; CHECK-NEXT: movq %rdx, %r14
; CHECK-NEXT: movq %rsi, %rbx
; CHECK-NEXT: movsd {{.*#+}} xmm0 = mem[0],zero
; CHECK-NEXT: movq {{.*#+}} xmm0 = mem[0],zero
; CHECK-NEXT: callq __truncdfbf2@PLT
; CHECK-NEXT: # kill: def $ax killed $ax def $eax
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: shll $16, %eax
; CHECK-NEXT: movl %eax, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
; CHECK-NEXT: movsd {{.*#+}} xmm0 = mem[0],zero
; CHECK-NEXT: movq {{.*#+}} xmm0 = mem[0],zero
; CHECK-NEXT: callq __truncdfbf2@PLT
; CHECK-NEXT: # kill: def $ax killed $ax def $eax
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: shll $16, %eax
; CHECK-NEXT: movd %eax, %xmm0
; CHECK-NEXT: addss {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
; CHECK-NEXT: cvtss2sd %xmm0, %xmm0
; CHECK-NEXT: movsd %xmm0, (%r14)
; CHECK-NEXT: addq $8, %rsp
; CHECK-NEXT: .cfi_def_cfa_offset 24
; CHECK-NEXT: popq %rbx
; CHECK-NEXT: .cfi_def_cfa_offset 16
; CHECK-NEXT: popq %r14
; CHECK-NEXT: .cfi_def_cfa_offset 8
; CHECK-NEXT: retq
%la = load double, ptr %pa
%a = fptrunc double %la to bfloat
Expand All @@ -70,29 +78,68 @@ define void @add_double(ptr %pa, ptr %pb, ptr %pc) {
ret void
}

define void @add_constant(ptr %pa, ptr %pc) {
; Each double is rounded to bf16 via __truncdfbf2, the bf16 bits are widened
; (shll $16) to f32 for the addss, and the f32 sum is converted to double.
; NOTE(review): the sum is cvtss2sd'ed directly, without re-rounding through
; bf16 first — confirm this matches the intended `fadd bfloat` semantics.
define double @add_double2(double %da, double %db) nounwind {
; CHECK-LABEL: add_double2:
; CHECK: # %bb.0:
; CHECK-NEXT: subq $24, %rsp
; CHECK-NEXT: movsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
; CHECK-NEXT: callq __truncdfbf2@PLT
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: shll $16, %eax
; CHECK-NEXT: movl %eax, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
; CHECK-NEXT: movq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload
; CHECK-NEXT: # xmm0 = mem[0],zero
; CHECK-NEXT: callq __truncdfbf2@PLT
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: shll $16, %eax
; CHECK-NEXT: movd %eax, %xmm0
; CHECK-NEXT: addss {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
; CHECK-NEXT: cvtss2sd %xmm0, %xmm0
; CHECK-NEXT: addq $24, %rsp
; CHECK-NEXT: retq
%a = fptrunc double %da to bfloat
%b = fptrunc double %db to bfloat
%add = fadd bfloat %a, %b
%dadd = fpext bfloat %add to double
ret double %dadd
}

define void @add_constant(ptr %pa, ptr %pc) nounwind {
; CHECK-LABEL: add_constant:
; CHECK: # %bb.0:
; CHECK-NEXT: pushq %rbx
; CHECK-NEXT: .cfi_def_cfa_offset 16
; CHECK-NEXT: .cfi_offset %rbx, -16
; CHECK-NEXT: movq %rsi, %rbx
; CHECK-NEXT: movzwl (%rdi), %eax
; CHECK-NEXT: shll $16, %eax
; CHECK-NEXT: movd %eax, %xmm0
; CHECK-NEXT: addss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: movw %ax, (%rbx)
; CHECK-NEXT: popq %rbx
; CHECK-NEXT: .cfi_def_cfa_offset 8
; CHECK-NEXT: retq
%a = load bfloat, ptr %pa
%add = fadd bfloat %a, 1.0
store bfloat %add, ptr %pc
ret void
}

define void @store_constant(ptr %pc) {
; bf16 + constant: the input is widened to f32, added against an f32
; constant-pool value, and rounded back to bf16 via __truncsfbf2.
define bfloat @add_constant2(bfloat %a) nounwind {
; CHECK-LABEL: add_constant2:
; CHECK: # %bb.0:
; CHECK-NEXT: pushq %rax
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: shll $16, %eax
; CHECK-NEXT: movd %eax, %xmm0
; CHECK-NEXT: addss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: popq %rax
; CHECK-NEXT: retq
%add = fadd bfloat %a, 1.0
ret bfloat %add
}

define void @store_constant(ptr %pc) nounwind {
; CHECK-LABEL: store_constant:
; CHECK: # %bb.0:
; CHECK-NEXT: movw $16256, (%rdi) # imm = 0x3F80
Expand All @@ -101,7 +148,7 @@ define void @store_constant(ptr %pc) {
ret void
}

define void @fold_ext_trunc(ptr %pa, ptr %pc) {
define void @fold_ext_trunc(ptr %pa, ptr %pc) nounwind {
; CHECK-LABEL: fold_ext_trunc:
; CHECK: # %bb.0:
; CHECK-NEXT: movzwl (%rdi), %eax
Expand All @@ -113,3 +160,150 @@ define void @fold_ext_trunc(ptr %pa, ptr %pc) {
store bfloat %trunc, ptr %pc
ret void
}

; An fpext immediately re-truncated to bf16 folds away entirely: the
; function compiles to a bare ret, with no conversion code emitted.
define bfloat @fold_ext_trunc2(bfloat %a) nounwind {
; CHECK-LABEL: fold_ext_trunc2:
; CHECK: # %bb.0:
; CHECK-NEXT: retq
%ext = fpext bfloat %a to float
%trunc = fptrunc float %ext to bfloat
ret bfloat %trunc
}

define <8 x bfloat> @addv(<8 x bfloat> %a, <8 x bfloat> %b) nounwind {
; CHECK-LABEL: addv:
; CHECK: # %bb.0:
; CHECK-NEXT: pushq %rbp
; CHECK-NEXT: pushq %r14
; CHECK-NEXT: pushq %rbx
; CHECK-NEXT: subq $32, %rsp
; CHECK-NEXT: movq %xmm1, %rax
; CHECK-NEXT: movq %rax, %rcx
; CHECK-NEXT: shrq $32, %rcx
; CHECK-NEXT: shll $16, %ecx
; CHECK-NEXT: movd %ecx, %xmm2
; CHECK-NEXT: movq %xmm0, %rcx
; CHECK-NEXT: movq %rcx, %rdx
; CHECK-NEXT: shrq $32, %rdx
; CHECK-NEXT: shll $16, %edx
; CHECK-NEXT: movd %edx, %xmm3
; CHECK-NEXT: addss %xmm2, %xmm3
; CHECK-NEXT: movss %xmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
; CHECK-NEXT: movq %rax, %rdx
; CHECK-NEXT: shrq $48, %rdx
; CHECK-NEXT: shll $16, %edx
; CHECK-NEXT: movd %edx, %xmm2
; CHECK-NEXT: movq %rcx, %rdx
; CHECK-NEXT: shrq $48, %rdx
; CHECK-NEXT: shll $16, %edx
; CHECK-NEXT: movd %edx, %xmm3
; CHECK-NEXT: addss %xmm2, %xmm3
; CHECK-NEXT: movss %xmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
; CHECK-NEXT: movl %eax, %edx
; CHECK-NEXT: shll $16, %edx
; CHECK-NEXT: movd %edx, %xmm2
; CHECK-NEXT: movl %ecx, %edx
; CHECK-NEXT: shll $16, %edx
; CHECK-NEXT: movd %edx, %xmm3
; CHECK-NEXT: addss %xmm2, %xmm3
; CHECK-NEXT: movss %xmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
; CHECK-NEXT: andl $-65536, %eax # imm = 0xFFFF0000
; CHECK-NEXT: movd %eax, %xmm2
; CHECK-NEXT: andl $-65536, %ecx # imm = 0xFFFF0000
; CHECK-NEXT: movd %ecx, %xmm3
; CHECK-NEXT: addss %xmm2, %xmm3
; CHECK-NEXT: movss %xmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
; CHECK-NEXT: pshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
; CHECK-NEXT: movq %xmm1, %rax
; CHECK-NEXT: movq %rax, %rcx
; CHECK-NEXT: shrq $32, %rcx
; CHECK-NEXT: shll $16, %ecx
; CHECK-NEXT: movd %ecx, %xmm1
; CHECK-NEXT: pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
; CHECK-NEXT: movq %xmm0, %rcx
; CHECK-NEXT: movq %rcx, %rdx
; CHECK-NEXT: shrq $32, %rdx
; CHECK-NEXT: shll $16, %edx
; CHECK-NEXT: movd %edx, %xmm0
; CHECK-NEXT: addss %xmm1, %xmm0
; CHECK-NEXT: movss %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
; CHECK-NEXT: movq %rax, %rdx
; CHECK-NEXT: shrq $48, %rdx
; CHECK-NEXT: shll $16, %edx
; CHECK-NEXT: movd %edx, %xmm0
; CHECK-NEXT: movq %rcx, %rdx
; CHECK-NEXT: shrq $48, %rdx
; CHECK-NEXT: shll $16, %edx
; CHECK-NEXT: movd %edx, %xmm1
; CHECK-NEXT: addss %xmm0, %xmm1
; CHECK-NEXT: movss %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
; CHECK-NEXT: movl %eax, %edx
; CHECK-NEXT: shll $16, %edx
; CHECK-NEXT: movd %edx, %xmm0
; CHECK-NEXT: movl %ecx, %edx
; CHECK-NEXT: shll $16, %edx
; CHECK-NEXT: movd %edx, %xmm1
; CHECK-NEXT: addss %xmm0, %xmm1
; CHECK-NEXT: movss %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
; CHECK-NEXT: andl $-65536, %eax # imm = 0xFFFF0000
; CHECK-NEXT: movd %eax, %xmm1
; CHECK-NEXT: andl $-65536, %ecx # imm = 0xFFFF0000
; CHECK-NEXT: movd %ecx, %xmm0
; CHECK-NEXT: addss %xmm1, %xmm0
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: movd %xmm0, %ebx
; CHECK-NEXT: shll $16, %ebx
; CHECK-NEXT: movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
; CHECK-NEXT: # xmm0 = mem[0],zero,zero,zero
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: movzwl %ax, %r14d
; CHECK-NEXT: orl %ebx, %r14d
; CHECK-NEXT: movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
; CHECK-NEXT: # xmm0 = mem[0],zero,zero,zero
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: movd %xmm0, %ebp
; CHECK-NEXT: shll $16, %ebp
; CHECK-NEXT: movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
; CHECK-NEXT: # xmm0 = mem[0],zero,zero,zero
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: movzwl %ax, %ebx
; CHECK-NEXT: orl %ebp, %ebx
; CHECK-NEXT: shlq $32, %rbx
; CHECK-NEXT: orq %r14, %rbx
; CHECK-NEXT: movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
; CHECK-NEXT: # xmm0 = mem[0],zero,zero,zero
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: movd %xmm0, %ebp
; CHECK-NEXT: shll $16, %ebp
; CHECK-NEXT: movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
; CHECK-NEXT: # xmm0 = mem[0],zero,zero,zero
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: movzwl %ax, %r14d
; CHECK-NEXT: orl %ebp, %r14d
; CHECK-NEXT: movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
; CHECK-NEXT: # xmm0 = mem[0],zero,zero,zero
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: movd %xmm0, %ebp
; CHECK-NEXT: shll $16, %ebp
; CHECK-NEXT: movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
; CHECK-NEXT: # xmm0 = mem[0],zero,zero,zero
; CHECK-NEXT: callq __truncsfbf2@PLT
; CHECK-NEXT: movd %xmm0, %eax
; CHECK-NEXT: movzwl %ax, %eax
; CHECK-NEXT: orl %ebp, %eax
; CHECK-NEXT: shlq $32, %rax
; CHECK-NEXT: orq %r14, %rax
; CHECK-NEXT: movq %rax, %xmm0
; CHECK-NEXT: movq %rbx, %xmm1
; CHECK-NEXT: punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
; CHECK-NEXT: addq $32, %rsp
; CHECK-NEXT: popq %rbx
; CHECK-NEXT: popq %r14
; CHECK-NEXT: popq %rbp
; CHECK-NEXT: retq
%add = fadd <8 x bfloat> %a, %b
ret <8 x bfloat> %add
}
43 changes: 29 additions & 14 deletions mlir/lib/ExecutionEngine/Float16bits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "mlir/ExecutionEngine/Float16bits.h"
#include <cmath>
#include <cstring>

namespace {

Expand Down Expand Up @@ -146,29 +147,43 @@ std::ostream &operator<<(std::ostream &os, const bf16 &d) {
return os;
}

// Provide a float->bfloat conversion routine in case the runtime doesn't have
// one.
extern "C" uint16_t
// Mark these symbols as weak so they don't conflict when compiler-rt also
// defines them.
#define ATTR_WEAK
#ifdef __has_attribute
#if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \
!defined(_WIN32)
__attribute__((__weak__))
#undef ATTR_WEAK
#define ATTR_WEAK __attribute__((__weak__))
#endif
#endif

#if defined(__x86_64__)
// On x86 bfloat16 is passed in SSE registers. Since both float and __bf16
// are passed in the same register we can use the wider type and careful casting
// to conform to x86_64 psABI. This only works with the assumption that we're
// dealing with little-endian values passed in wider registers.
// Ideally this would directly use __bf16, but that type isn't supported by all
// compilers.
using BF16ABIType = float;
#else
// Default to uint16_t if we have nothing else.
using BF16ABIType = uint16_t;
#endif
__truncsfbf2(float f) {
return float2bfloat(f);

// Provide a weak float->bfloat conversion routine in case the runtime
// doesn't have one (a strong compiler-rt definition wins when present).
extern "C" BF16ABIType ATTR_WEAK __truncsfbf2(float f) {
uint16_t bf = float2bfloat(f);
// BF16ABIType may be a wider float type; zero-initialize it and memcpy the
// 16 raw bf16 bits into its low bytes, leaving the upper bytes zero.
// NOTE(review): relies on the little-endian layout assumption documented
// where BF16ABIType is defined — confirm if new targets are added.
BF16ABIType ret = 0;
std::memcpy(&ret, &bf, sizeof(bf));
return ret;
}

// Provide a double->bfloat conversion routine in case the runtime doesn't have
// one.
extern "C" uint16_t
#ifdef __has_attribute
#if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \
!defined(_WIN32)
__attribute__((__weak__))
#endif
#endif
__truncdfbf2(double d) {
extern "C" BF16ABIType ATTR_WEAK __truncdfbf2(double d) {
// This does a double rounding step, but it's precise enough for our use
// cases.
return __truncsfbf2(static_cast<float>(d));
Expand Down