
[X86][BF16] Try to use f16 for lowering #76901

Merged: 5 commits into llvm:main on Jan 5, 2024

Conversation

phoebewang (Contributor) commented on Jan 4, 2024

This patch fixes the BF16 32-bit ABI problem: https://godbolt.org/z/6dMnh8jGG
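For context (inferred from the test updates in this diff rather than stated in the description): on 32-bit targets the bf16 result of the __truncsfbf2 libcall was previously modeled as an f32 return and read back through the x87 stack (the fstps/vmovd sequences removed from bfloat.ll), whereas with this change it is lowered through f16 and taken directly from %xmm0. A minimal sketch of the affected pattern, using the add2 function already present in llvm/test/CodeGen/X86/bfloat.ll:

define bfloat @add2(bfloat %a, bfloat %b) nounwind {
  ; The bfloat fadd is legalized to an f32 add followed by a
  ; __truncsfbf2 libcall; with this patch the call's bf16 result
  ; is lowered as f16 instead of being widened to f32.
  %add = fadd bfloat %a, %b
  ret bfloat %add
}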

@phoebewang phoebewang marked this pull request as ready for review January 4, 2024 07:53
@phoebewang phoebewang changed the title from "[X86][BF16][WIP] Try to use f16 for lowering" to "[X86][BF16] Try to use f16 for lowering" on Jan 4, 2024
llvmbot (Collaborator) commented on Jan 4, 2024

@llvm/pr-subscribers-backend-x86

Author: Phoebe Wang (phoebewang)

Changes

Patch is 102.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76901.diff

4 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+6-5)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.h (-10)
  • (modified) llvm/lib/Target/X86/X86ISelLoweringCall.cpp (+3-34)
  • (modified) llvm/test/CodeGen/X86/bfloat.ll (+600-724)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index cd56529bfa0fd8..a9ac6aa6558441 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -7475,10 +7475,12 @@ static SDValue buildFromShuffleMostly(SDValue Op, SelectionDAG &DAG) {
 static SDValue LowerBUILD_VECTORvXbf16(SDValue Op, SelectionDAG &DAG,
                                        const X86Subtarget &Subtarget) {
   MVT VT = Op.getSimpleValueType();
-  MVT IVT = VT.changeVectorElementTypeToInteger();
+  MVT IVT =
+      VT.changeVectorElementType(Subtarget.hasFP16() ? MVT::f16 : MVT::i16);
   SmallVector<SDValue, 16> NewOps;
   for (unsigned I = 0, E = Op.getNumOperands(); I != E; ++I)
-    NewOps.push_back(DAG.getBitcast(MVT::i16, Op.getOperand(I)));
+    NewOps.push_back(DAG.getBitcast(Subtarget.hasFP16() ? MVT::f16 : MVT::i16,
+                                    Op.getOperand(I)));
   SDValue Res = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(), IVT, NewOps);
   return DAG.getBitcast(VT, Res);
 }
@@ -21515,9 +21517,8 @@ SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
   RTLIB::Libcall LC =
       RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
   SDValue Res =
-      makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
-  return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16,
-                     DAG.getBitcast(MVT::i32, Res));
+      makeLibCall(DAG, LC, MVT::f16, Op.getOperand(0), CallOptions, DL).first;
+  return DAG.getBitcast(MVT::i16, Res);
 }
 
 /// Depending on uarch and/or optimizing for size, we might prefer to use a
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 9bd1622cb0d3a6..32745400a38b7e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1714,16 +1714,6 @@ namespace llvm {
       MachineBasicBlock *Entry,
       const SmallVectorImpl<MachineBasicBlock *> &Exits) const override;
 
-    bool splitValueIntoRegisterParts(
-        SelectionDAG & DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
-        unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC)
-        const override;
-
-    SDValue joinRegisterPartsIntoValue(
-        SelectionDAG & DAG, const SDLoc &DL, const SDValue *Parts,
-        unsigned NumParts, MVT PartVT, EVT ValueVT,
-        std::optional<CallingConv::ID> CC) const override;
-
     bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
 
     bool mayBeEmittedAsTailCall(const CallInst *CI) const override;
diff --git a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
index b8b5421b900501..d75bd4171fde9d 100644
--- a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
+++ b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
@@ -127,6 +127,9 @@ MVT X86TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
     return getRegisterTypeForCallingConv(Context, CC,
                                          VT.changeVectorElementType(MVT::f16));
 
+  if (VT == MVT::bf16)
+    return MVT::f16;
+
   return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
 }
 
@@ -421,40 +424,6 @@ unsigned X86TargetLowering::getJumpTableEncoding() const {
   return TargetLowering::getJumpTableEncoding();
 }
 
-bool X86TargetLowering::splitValueIntoRegisterParts(
-    SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
-    unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
-  bool IsABIRegCopy = CC.has_value();
-  EVT ValueVT = Val.getValueType();
-  if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) {
-    unsigned ValueBits = ValueVT.getSizeInBits();
-    unsigned PartBits = PartVT.getSizeInBits();
-    Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(ValueBits), Val);
-    Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::getIntegerVT(PartBits), Val);
-    Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val);
-    Parts[0] = Val;
-    return true;
-  }
-  return false;
-}
-
-SDValue X86TargetLowering::joinRegisterPartsIntoValue(
-    SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
-    MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
-  bool IsABIRegCopy = CC.has_value();
-  if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) {
-    unsigned ValueBits = ValueVT.getSizeInBits();
-    unsigned PartBits = PartVT.getSizeInBits();
-    SDValue Val = Parts[0];
-
-    Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(PartBits), Val);
-    Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::getIntegerVT(ValueBits), Val);
-    Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
-    return Val;
-  }
-  return SDValue();
-}
-
 bool X86TargetLowering::useSoftFloat() const {
   return Subtarget.useSoftFloat();
 }
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
index 7ef362619d5fd0..b309f47e4b7190 100644
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -3,7 +3,7 @@
 ; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s --check-prefixes=CHECK,SSE2
 ; RUN: llc < %s -mtriple=x86_64-linux-gnu -mattr=avx512bf16,avx512vl | FileCheck %s --check-prefixes=CHECK,AVX,F16,BF16
 ; RUN: llc < %s -mtriple=x86_64-linux-gnu -mattr=avx512bf16,avx512fp16,avx512vl | FileCheck %s --check-prefixes=CHECK,AVX,F16,FP16
-; RUN: llc < %s -mtriple=x86_64-linux-gnu -mattr=avxneconvert,f16c | FileCheck %s --check-prefixes=CHECK,AVX,AVXNC
+; RUN: llc < %s -mtriple=x86_64-linux-gnu -mattr=avxneconvert,f16c | FileCheck %s --check-prefixes=CHECK,AVX,BF16,AVXNC
 
 define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; X86-LABEL: add:
@@ -22,10 +22,7 @@ define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; X86-NEXT:    vaddss %xmm0, %xmm1, %xmm0
 ; X86-NEXT:    vmovss %xmm0, (%esp)
 ; X86-NEXT:    calll __truncsfbf2
-; X86-NEXT:    fstps {{[0-9]+}}(%esp)
-; X86-NEXT:    vmovd {{.*#+}} xmm0 = mem[0],zero,zero,zero
-; X86-NEXT:    vmovd %xmm0, %eax
-; X86-NEXT:    movw %ax, (%esi)
+; X86-NEXT:    vmovsh %xmm0, (%esi)
 ; X86-NEXT:    addl $8, %esp
 ; X86-NEXT:    popl %esi
 ; X86-NEXT:    retl
@@ -42,27 +39,42 @@ define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; SSE2-NEXT:    movd %eax, %xmm0
 ; SSE2-NEXT:    addss %xmm1, %xmm0
 ; SSE2-NEXT:    callq __truncsfbf2@PLT
-; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    pextrw $0, %xmm0, %eax
 ; SSE2-NEXT:    movw %ax, (%rbx)
 ; SSE2-NEXT:    popq %rbx
 ; SSE2-NEXT:    retq
 ;
-; AVX-LABEL: add:
-; AVX:       # %bb.0:
-; AVX-NEXT:    pushq %rbx
-; AVX-NEXT:    movq %rdx, %rbx
-; AVX-NEXT:    movzwl (%rsi), %eax
-; AVX-NEXT:    shll $16, %eax
-; AVX-NEXT:    vmovd %eax, %xmm0
-; AVX-NEXT:    movzwl (%rdi), %eax
-; AVX-NEXT:    shll $16, %eax
-; AVX-NEXT:    vmovd %eax, %xmm1
-; AVX-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; AVX-NEXT:    callq __truncsfbf2@PLT
-; AVX-NEXT:    vmovd %xmm0, %eax
-; AVX-NEXT:    movw %ax, (%rbx)
-; AVX-NEXT:    popq %rbx
-; AVX-NEXT:    retq
+; BF16-LABEL: add:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rbx
+; BF16-NEXT:    movq %rdx, %rbx
+; BF16-NEXT:    movzwl (%rsi), %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    movzwl (%rdi), %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vpextrw $0, %xmm0, (%rbx)
+; BF16-NEXT:    popq %rbx
+; BF16-NEXT:    retq
+;
+; FP16-LABEL: add:
+; FP16:       # %bb.0:
+; FP16-NEXT:    pushq %rbx
+; FP16-NEXT:    movq %rdx, %rbx
+; FP16-NEXT:    movzwl (%rsi), %eax
+; FP16-NEXT:    shll $16, %eax
+; FP16-NEXT:    vmovd %eax, %xmm0
+; FP16-NEXT:    movzwl (%rdi), %eax
+; FP16-NEXT:    shll $16, %eax
+; FP16-NEXT:    vmovd %eax, %xmm1
+; FP16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; FP16-NEXT:    callq __truncsfbf2@PLT
+; FP16-NEXT:    vmovsh %xmm0, (%rbx)
+; FP16-NEXT:    popq %rbx
+; FP16-NEXT:    retq
   %a = load bfloat, ptr %pa
   %b = load bfloat, ptr %pb
   %add = fadd bfloat %a, %b
@@ -89,8 +101,8 @@ define bfloat @add2(bfloat %a, bfloat %b) nounwind {
 ; SSE2-LABEL: add2:
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    pushq %rax
-; SSE2-NEXT:    movd %xmm0, %eax
-; SSE2-NEXT:    movd %xmm1, %ecx
+; SSE2-NEXT:    pextrw $0, %xmm0, %eax
+; SSE2-NEXT:    pextrw $0, %xmm1, %ecx
 ; SSE2-NEXT:    shll $16, %ecx
 ; SSE2-NEXT:    movd %ecx, %xmm1
 ; SSE2-NEXT:    shll $16, %eax
@@ -100,19 +112,33 @@ define bfloat @add2(bfloat %a, bfloat %b) nounwind {
 ; SSE2-NEXT:    popq %rax
 ; SSE2-NEXT:    retq
 ;
-; AVX-LABEL: add2:
-; AVX:       # %bb.0:
-; AVX-NEXT:    pushq %rax
-; AVX-NEXT:    vmovd %xmm0, %eax
-; AVX-NEXT:    vmovd %xmm1, %ecx
-; AVX-NEXT:    shll $16, %ecx
-; AVX-NEXT:    vmovd %ecx, %xmm0
-; AVX-NEXT:    shll $16, %eax
-; AVX-NEXT:    vmovd %eax, %xmm1
-; AVX-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; AVX-NEXT:    callq __truncsfbf2@PLT
-; AVX-NEXT:    popq %rax
-; AVX-NEXT:    retq
+; BF16-LABEL: add2:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rax
+; BF16-NEXT:    vpextrw $0, %xmm0, %eax
+; BF16-NEXT:    vpextrw $0, %xmm1, %ecx
+; BF16-NEXT:    shll $16, %ecx
+; BF16-NEXT:    vmovd %ecx, %xmm0
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    popq %rax
+; BF16-NEXT:    retq
+;
+; FP16-LABEL: add2:
+; FP16:       # %bb.0:
+; FP16-NEXT:    pushq %rax
+; FP16-NEXT:    vmovw %xmm0, %eax
+; FP16-NEXT:    vmovw %xmm1, %ecx
+; FP16-NEXT:    shll $16, %ecx
+; FP16-NEXT:    vmovd %ecx, %xmm0
+; FP16-NEXT:    shll $16, %eax
+; FP16-NEXT:    vmovd %eax, %xmm1
+; FP16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; FP16-NEXT:    callq __truncsfbf2@PLT
+; FP16-NEXT:    popq %rax
+; FP16-NEXT:    retq
   %add = fadd bfloat %a, %b
   ret bfloat %add
 }
@@ -123,22 +149,18 @@ define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; X86-NEXT:    pushl %ebx
 ; X86-NEXT:    pushl %edi
 ; X86-NEXT:    pushl %esi
-; X86-NEXT:    subl $32, %esp
+; X86-NEXT:    subl $16, %esp
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %esi
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %ebx
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
 ; X86-NEXT:    vmovsd %xmm0, (%esp)
 ; X86-NEXT:    calll __truncdfbf2
-; X86-NEXT:    fstps {{[0-9]+}}(%esp)
-; X86-NEXT:    vmovd {{.*#+}} xmm0 = mem[0],zero,zero,zero
-; X86-NEXT:    vmovd %xmm0, %edi
+; X86-NEXT:    vmovw %xmm0, %edi
 ; X86-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
 ; X86-NEXT:    vmovsd %xmm0, (%esp)
 ; X86-NEXT:    calll __truncdfbf2
-; X86-NEXT:    fstps {{[0-9]+}}(%esp)
-; X86-NEXT:    vmovd {{.*#+}} xmm0 = mem[0],zero,zero,zero
-; X86-NEXT:    vmovd %xmm0, %eax
+; X86-NEXT:    vmovw %xmm0, %eax
 ; X86-NEXT:    shll $16, %eax
 ; X86-NEXT:    vmovd %eax, %xmm0
 ; X86-NEXT:    shll $16, %edi
@@ -146,14 +168,12 @@ define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; X86-NEXT:    vaddss %xmm0, %xmm1, %xmm0
 ; X86-NEXT:    vmovss %xmm0, (%esp)
 ; X86-NEXT:    calll __truncsfbf2
-; X86-NEXT:    fstps {{[0-9]+}}(%esp)
-; X86-NEXT:    vmovd {{.*#+}} xmm0 = mem[0],zero,zero,zero
-; X86-NEXT:    vmovd %xmm0, %eax
+; X86-NEXT:    vmovw %xmm0, %eax
 ; X86-NEXT:    shll $16, %eax
 ; X86-NEXT:    vmovd %eax, %xmm0
 ; X86-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
 ; X86-NEXT:    vmovsd %xmm0, (%esi)
-; X86-NEXT:    addl $32, %esp
+; X86-NEXT:    addl $16, %esp
 ; X86-NEXT:    popl %esi
 ; X86-NEXT:    popl %edi
 ; X86-NEXT:    popl %ebx
@@ -168,17 +188,17 @@ define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; SSE2-NEXT:    movq %rsi, %r14
 ; SSE2-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
 ; SSE2-NEXT:    callq __truncdfbf2@PLT
-; SSE2-NEXT:    movd %xmm0, %ebp
+; SSE2-NEXT:    pextrw $0, %xmm0, %ebp
 ; SSE2-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
 ; SSE2-NEXT:    callq __truncdfbf2@PLT
-; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    pextrw $0, %xmm0, %eax
 ; SSE2-NEXT:    shll $16, %eax
 ; SSE2-NEXT:    movd %eax, %xmm1
 ; SSE2-NEXT:    shll $16, %ebp
 ; SSE2-NEXT:    movd %ebp, %xmm0
 ; SSE2-NEXT:    addss %xmm1, %xmm0
 ; SSE2-NEXT:    callq __truncsfbf2@PLT
-; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    pextrw $0, %xmm0, %eax
 ; SSE2-NEXT:    shll $16, %eax
 ; SSE2-NEXT:    movd %eax, %xmm0
 ; SSE2-NEXT:    cvtss2sd %xmm0, %xmm0
@@ -188,34 +208,63 @@ define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; SSE2-NEXT:    popq %rbp
 ; SSE2-NEXT:    retq
 ;
-; AVX-LABEL: add_double:
-; AVX:       # %bb.0:
-; AVX-NEXT:    pushq %rbp
-; AVX-NEXT:    pushq %r14
-; AVX-NEXT:    pushq %rbx
-; AVX-NEXT:    movq %rdx, %rbx
-; AVX-NEXT:    movq %rsi, %r14
-; AVX-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
-; AVX-NEXT:    callq __truncdfbf2@PLT
-; AVX-NEXT:    vmovd %xmm0, %ebp
-; AVX-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
-; AVX-NEXT:    callq __truncdfbf2@PLT
-; AVX-NEXT:    vmovd %xmm0, %eax
-; AVX-NEXT:    shll $16, %eax
-; AVX-NEXT:    vmovd %eax, %xmm0
-; AVX-NEXT:    shll $16, %ebp
-; AVX-NEXT:    vmovd %ebp, %xmm1
-; AVX-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; AVX-NEXT:    callq __truncsfbf2@PLT
-; AVX-NEXT:    vmovd %xmm0, %eax
-; AVX-NEXT:    shll $16, %eax
-; AVX-NEXT:    vmovd %eax, %xmm0
-; AVX-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
-; AVX-NEXT:    vmovsd %xmm0, (%rbx)
-; AVX-NEXT:    popq %rbx
-; AVX-NEXT:    popq %r14
-; AVX-NEXT:    popq %rbp
-; AVX-NEXT:    retq
+; BF16-LABEL: add_double:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rbp
+; BF16-NEXT:    pushq %r14
+; BF16-NEXT:    pushq %rbx
+; BF16-NEXT:    movq %rdx, %rbx
+; BF16-NEXT:    movq %rsi, %r14
+; BF16-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; BF16-NEXT:    callq __truncdfbf2@PLT
+; BF16-NEXT:    vpextrw $0, %xmm0, %ebp
+; BF16-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; BF16-NEXT:    callq __truncdfbf2@PLT
+; BF16-NEXT:    vpextrw $0, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    shll $16, %ebp
+; BF16-NEXT:    vmovd %ebp, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vpextrw $0, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
+; BF16-NEXT:    vmovsd %xmm0, (%rbx)
+; BF16-NEXT:    popq %rbx
+; BF16-NEXT:    popq %r14
+; BF16-NEXT:    popq %rbp
+; BF16-NEXT:    retq
+;
+; FP16-LABEL: add_double:
+; FP16:       # %bb.0:
+; FP16-NEXT:    pushq %rbp
+; FP16-NEXT:    pushq %r14
+; FP16-NEXT:    pushq %rbx
+; FP16-NEXT:    movq %rdx, %rbx
+; FP16-NEXT:    movq %rsi, %r14
+; FP16-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
+; FP16-NEXT:    callq __truncdfbf2@PLT
+; FP16-NEXT:    vmovw %xmm0, %ebp
+; FP16-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
+; FP16-NEXT:    callq __truncdfbf2@PLT
+; FP16-NEXT:    vmovw %xmm0, %eax
+; FP16-NEXT:    shll $16, %eax
+; FP16-NEXT:    vmovd %eax, %xmm0
+; FP16-NEXT:    shll $16, %ebp
+; FP16-NEXT:    vmovd %ebp, %xmm1
+; FP16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; FP16-NEXT:    callq __truncsfbf2@PLT
+; FP16-NEXT:    vmovw %xmm0, %eax
+; FP16-NEXT:    shll $16, %eax
+; FP16-NEXT:    vmovd %eax, %xmm0
+; FP16-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
+; FP16-NEXT:    vmovsd %xmm0, (%rbx)
+; FP16-NEXT:    popq %rbx
+; FP16-NEXT:    popq %r14
+; FP16-NEXT:    popq %rbp
+; FP16-NEXT:    retq
   %la = load double, ptr %pa
   %a = fptrunc double %la to bfloat
   %lb = load double, ptr %pb
@@ -230,19 +279,15 @@ define double @add_double2(double %da, double %db) nounwind {
 ; X86-LABEL: add_double2:
 ; X86:       # %bb.0:
 ; X86-NEXT:    pushl %esi
-; X86-NEXT:    subl $40, %esp
+; X86-NEXT:    subl $24, %esp
 ; X86-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
 ; X86-NEXT:    vmovsd %xmm0, (%esp)
 ; X86-NEXT:    calll __truncdfbf2
-; X86-NEXT:    fstps {{[0-9]+}}(%esp)
-; X86-NEXT:    vmovd {{.*#+}} xmm0 = mem[0],zero,zero,zero
-; X86-NEXT:    vmovd %xmm0, %esi
+; X86-NEXT:    vmovw %xmm0, %esi
 ; X86-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
 ; X86-NEXT:    vmovsd %xmm0, (%esp)
 ; X86-NEXT:    calll __truncdfbf2
-; X86-NEXT:    fstps {{[0-9]+}}(%esp)
-; X86-NEXT:    vmovd {{.*#+}} xmm0 = mem[0],zero,zero,zero
-; X86-NEXT:    vmovd %xmm0, %eax
+; X86-NEXT:    vmovw %xmm0, %eax
 ; X86-NEXT:    shll $16, %eax
 ; X86-NEXT:    vmovd %eax, %xmm0
 ; X86-NEXT:    shll $16, %esi
@@ -250,15 +295,13 @@ define double @add_double2(double %da, double %db) nounwind {
 ; X86-NEXT:    vaddss %xmm0, %xmm1, %xmm0
 ; X86-NEXT:    vmovss %xmm0, (%esp)
 ; X86-NEXT:    calll __truncsfbf2
-; X86-NEXT:    fstps {{[0-9]+}}(%esp)
-; X86-NEXT:    vmovd {{.*#+}} xmm0 = mem[0],zero,zero,zero
-; X86-NEXT:    vmovd %xmm0, %eax
+; X86-NEXT:    vmovw %xmm0, %eax
 ; X86-NEXT:    shll $16, %eax
 ; X86-NEXT:    vmovd %eax, %xmm0
 ; X86-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
 ; X86-NEXT:    vmovsd %xmm0, {{[0-9]+}}(%esp)
 ; X86-NEXT:    fldl {{[0-9]+}}(%esp)
-; X86-NEXT:    addl $40, %esp
+; X86-NEXT:    addl $24, %esp
 ; X86-NEXT:    popl %esi
 ; X86-NEXT:    retl
 ;
@@ -268,18 +311,18 @@ define double @add_double2(double %da, double %db) nounwind {
 ; SSE2-NEXT:    subq $16, %rsp
 ; SSE2-NEXT:    movsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
 ; SSE2-NEXT:    callq __truncdfbf2@PLT
-; SSE2-NEXT:    movd %xmm0, %ebx
+; SSE2-NEXT:    pextrw $0, %xmm0, %ebx
 ; SSE2-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload
 ; SSE2-NEXT:    # xmm0 = mem[0],zero
 ; SSE2-NEXT:    callq __truncdfbf2@PLT
-; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    pextrw $0, %xmm0, %eax
 ; SSE2-NEXT:    shll $16, %eax
 ; SSE2-NEXT:    movd %eax, %xmm1
 ; SSE2-NEXT:    shll $16, %ebx
 ; SSE2-NEXT:    movd %ebx, %xmm0
 ; SSE2-NEXT:    addss %xmm1, %xmm0
 ; SSE2-NEXT:    callq __truncsfbf2@PLT
-; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    pextrw $0, %xmm0, %eax
 ; SSE2-NEXT:    shll $16, %eax
 ; SSE2-NEXT:    movd %eax, %xmm0
 ; SSE2-NEXT:    cvtss2sd %xmm0, %xmm0
@@ -287,30 +330,55 @@ define double @add_double2(double %da, double %db) nounwind {
 ; SSE2-NEXT:    popq %rbx
 ; SSE2-NEXT:    retq
 ;
-; AVX-LABEL: add_double2:
-; AVX:       # %bb.0:
-; AVX-NEXT:    pushq %rbx
-; AVX-NEXT:    subq $16, %rsp
-; AVX-NEXT:    vmovsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; AVX-NEXT:    callq __truncdfbf2@PLT
-; AVX-NEXT:    vmovd %xmm0, %ebx
-; AVX-NEXT:    vmovq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload
-; AVX-NEXT:    # xmm0 = mem[0],zero
-; AVX-NEXT:    callq __truncdfbf2@PLT
-; AVX-NEXT:    vmovd %xmm0, %eax
-; AVX-NEXT:    shll $16, %eax
-; AVX-NEXT:    vmovd %eax, %xmm0
-; AVX-NEXT:    shll $16, %ebx
-; AVX-NEXT:    vmovd %ebx, %xmm1
-; AVX-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; AVX-NEXT:    callq __truncsfbf2@PLT
-; AVX-NEXT:    vmovd %xmm0, %eax
-; AVX-NEXT:    shll $16, %eax
-; AVX-NEXT:    vmovd %eax, %xmm0
-; AVX-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
-; AVX-NEXT:    addq $16, %rsp
-; AVX-NEXT:    popq %rbx
-; AVX-NEXT:    retq
+; BF16-LABEL: add_double2:
+; BF16:       # %bb.0:
+; BF16-NEXT:    pushq %rbx
+; BF16-NEXT:    subq $16, %rsp
+; BF16-NEXT:    vmovsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; BF16-NEXT:    callq __truncdfbf2@PLT
+; BF16-NEXT:    vpextrw $0, %xmm0, %ebx
+; BF16-NEXT:    vmovq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload
+; BF16-NEXT:    # xmm0 = mem[0],zero
+; BF16-NEXT:    callq __truncdfbf2@PLT
+; BF16-NEXT:    vpextrw $0, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    shll $16, %ebx
+; BF16-NEXT:    vmovd %ebx, %xmm1
+; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; BF16-NEXT:    callq __truncsfbf2@PLT
+; BF16-NEXT:    vpextrw $0, %xmm0, %eax
+; BF16-NEXT:    shll $16, %eax
+; BF16-NEXT:    vmovd %eax, %xmm0
+; BF16-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
+; BF16-NEXT:    addq $16, %rsp
+; BF16-NEXT:    popq %rbx
+; BF16-NEXT:    retq
+;
+; FP16-LABEL: add_double2:
+; FP16:       # %bb.0:
+; FP16-NEXT:    pushq %rbx
+; FP16-NEXT:    subq $16, %rsp
+; FP16-NEXT:    vmovsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; FP16-NEXT:    callq __truncdfbf2@PLT
+; FP16-NEXT:    vmovw %xmm0, %ebx
+; FP16-NEXT:    vmovsd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Reload
+; FP16-NEXT:    # xmm0 = mem[0],zero
+; FP16-NEXT:    callq __truncdfbf2@PLT
+; FP16-NEXT:    vmovw %xmm0, %eax
+; FP16-NEXT:    shll $16, %eax
+; FP16-NEXT:    vmovd %eax, %xmm0
+; FP16-NEXT:    shll ...
[truncated]

FreddyLeaf (Contributor) left a comment

LGTM

phoebewang (Contributor, Author)

Thanks @FreddyLeaf !

@phoebewang phoebewang merged commit 59af659 into llvm:main Jan 5, 2024
6 of 7 checks passed
@phoebewang phoebewang deleted the bf16 branch January 5, 2024 07:25