Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[X86][BF16] Improve float -> bfloat lowering under AVX512BF16 and AVXNECONVERT #78042

Merged
merged 1 commit into from
Jan 17, 2024

Conversation

phoebewang
Copy link
Contributor

@phoebewang phoebewang commented Jan 13, 2024

No description provided.

@phoebewang phoebewang changed the title [X86][BF16] Improve float -> bfloat lowering under AVX512BF16 and AVX… [X86][BF16] Improve float -> bfloat lowering under AVX512BF16 and AVXNECONVERT Jan 13, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 13, 2024

@llvm/pr-subscribers-backend-x86

Author: Phoebe Wang (phoebewang)

Changes

Patch is 43.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78042.diff

3 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+12-2)
  • (modified) llvm/lib/Target/X86/X86InstrSSE.td (+4)
  • (modified) llvm/test/CodeGen/X86/bfloat.ll (+372-485)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 700ab797b2f69f..e19128ec775651 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21523,9 +21523,19 @@ static SDValue LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) {
 SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
                                            SelectionDAG &DAG) const {
   SDLoc DL(Op);
+
+  MVT SVT = Op.getOperand(0).getSimpleValueType();
+  if (SVT == MVT::f32 && (Subtarget.hasBF16() || Subtarget.hasAVXNECONVERT())) {
+    SDValue Res;
+    Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, Op.getOperand(0));
+    Res = DAG.getNode(X86ISD::CVTNEPS2BF16, DL, MVT::v8bf16, Res);
+    Res = DAG.getBitcast(MVT::v8i16, Res);
+    return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i16, Res,
+                       DAG.getIntPtrConstant(0, DL));
+  }
+
   MakeLibCallOptions CallOptions;
-  RTLIB::Libcall LC =
-      RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
+  RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, MVT::bf16);
   SDValue Res =
       makeLibCall(DAG, LC, MVT::f16, Op.getOperand(0), CallOptions, DL).first;
   return DAG.getBitcast(MVT::i16, Res);
diff --git a/llvm/lib/Target/X86/X86InstrSSE.td b/llvm/lib/Target/X86/X86InstrSSE.td
index e8a1a2b83886f8..a8cd1996eeb356 100644
--- a/llvm/lib/Target/X86/X86InstrSSE.td
+++ b/llvm/lib/Target/X86/X86InstrSSE.td
@@ -8331,6 +8331,10 @@ let Predicates = [HasAVXNECONVERT] in {
        f256mem>, T8;
   defm VCVTNEPS2BF16 : VCVTNEPS2BF16_BASE, VEX, T8, XS, ExplicitVEXPrefix;
 
+  def : Pat<(v8bf16 (X86cvtneps2bf16 (v4f32 VR128X:$src))),
+            (VCVTNEPS2BF16rr VR128:$src)>;
+  def : Pat<(v8bf16 (X86cvtneps2bf16 (loadv4f32 addr:$src))),
+            (VCVTNEPS2BF16rm addr:$src)>;
   def : Pat<(v8bf16 (X86vfpround (v8f32 VR256:$src))),
             (VCVTNEPS2BF16Yrr VR256:$src)>;
   def : Pat<(v8bf16 (X86vfpround (loadv8f32 addr:$src))),
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
index b309f47e4b7190..9d2ef51b0a8fbe 100644
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -8,23 +8,18 @@
 define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; X86-LABEL: add:
 ; X86:       # %bb.0:
-; X86-NEXT:    pushl %esi
-; X86-NEXT:    subl $8, %esp
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %esi
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %edx
+; X86-NEXT:    movzwl (%edx), %edx
+; X86-NEXT:    shll $16, %edx
+; X86-NEXT:    vmovd %edx, %xmm0
 ; X86-NEXT:    movzwl (%ecx), %ecx
 ; X86-NEXT:    shll $16, %ecx
-; X86-NEXT:    vmovd %ecx, %xmm0
-; X86-NEXT:    movzwl (%eax), %eax
-; X86-NEXT:    shll $16, %eax
-; X86-NEXT:    vmovd %eax, %xmm1
+; X86-NEXT:    vmovd %ecx, %xmm1
 ; X86-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; X86-NEXT:    vmovss %xmm0, (%esp)
-; X86-NEXT:    calll __truncsfbf2
-; X86-NEXT:    vmovsh %xmm0, (%esi)
-; X86-NEXT:    addl $8, %esp
-; X86-NEXT:    popl %esi
+; X86-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
+; X86-NEXT:    vpextrw $0, %xmm0, (%eax)
 ; X86-NEXT:    retl
 ;
 ; SSE2-LABEL: add:
@@ -44,37 +39,31 @@ define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; SSE2-NEXT:    popq %rbx
 ; SSE2-NEXT:    retq
 ;
-; BF16-LABEL: add:
-; BF16:       # %bb.0:
-; BF16-NEXT:    pushq %rbx
-; BF16-NEXT:    movq %rdx, %rbx
-; BF16-NEXT:    movzwl (%rsi), %eax
-; BF16-NEXT:    shll $16, %eax
-; BF16-NEXT:    vmovd %eax, %xmm0
-; BF16-NEXT:    movzwl (%rdi), %eax
-; BF16-NEXT:    shll $16, %eax
-; BF16-NEXT:    vmovd %eax, %xmm1
-; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; BF16-NEXT:    callq __truncsfbf2@PLT
-; BF16-NEXT:    vpextrw $0, %xmm0, (%rbx)
-; BF16-NEXT:    popq %rbx
-; BF16-NEXT:    retq
+; F16-LABEL: add:
+; F16:       # %bb.0:
+; F16-NEXT:    movzwl (%rsi), %eax
+; F16-NEXT:    shll $16, %eax
+; F16-NEXT:    vmovd %eax, %xmm0
+; F16-NEXT:    movzwl (%rdi), %eax
+; F16-NEXT:    shll $16, %eax
+; F16-NEXT:    vmovd %eax, %xmm1
+; F16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; F16-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
+; F16-NEXT:    vpextrw $0, %xmm0, (%rdx)
+; F16-NEXT:    retq
 ;
-; FP16-LABEL: add:
-; FP16:       # %bb.0:
-; FP16-NEXT:    pushq %rbx
-; FP16-NEXT:    movq %rdx, %rbx
-; FP16-NEXT:    movzwl (%rsi), %eax
-; FP16-NEXT:    shll $16, %eax
-; FP16-NEXT:    vmovd %eax, %xmm0
-; FP16-NEXT:    movzwl (%rdi), %eax
-; FP16-NEXT:    shll $16, %eax
-; FP16-NEXT:    vmovd %eax, %xmm1
-; FP16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; FP16-NEXT:    callq __truncsfbf2@PLT
-; FP16-NEXT:    vmovsh %xmm0, (%rbx)
-; FP16-NEXT:    popq %rbx
-; FP16-NEXT:    retq
+; AVXNC-LABEL: add:
+; AVXNC:       # %bb.0:
+; AVXNC-NEXT:    movzwl (%rsi), %eax
+; AVXNC-NEXT:    shll $16, %eax
+; AVXNC-NEXT:    vmovd %eax, %xmm0
+; AVXNC-NEXT:    movzwl (%rdi), %eax
+; AVXNC-NEXT:    shll $16, %eax
+; AVXNC-NEXT:    vmovd %eax, %xmm1
+; AVXNC-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; AVXNC-NEXT:    {vex} vcvtneps2bf16 %xmm0, %xmm0
+; AVXNC-NEXT:    vpextrw $0, %xmm0, (%rdx)
+; AVXNC-NEXT:    retq
   %a = load bfloat, ptr %pa
   %b = load bfloat, ptr %pb
   %add = fadd bfloat %a, %b
@@ -85,7 +74,6 @@ define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
 define bfloat @add2(bfloat %a, bfloat %b) nounwind {
 ; X86-LABEL: add2:
 ; X86:       # %bb.0:
-; X86-NEXT:    subl $12, %esp
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    shll $16, %eax
 ; X86-NEXT:    vmovd %eax, %xmm0
@@ -93,9 +81,9 @@ define bfloat @add2(bfloat %a, bfloat %b) nounwind {
 ; X86-NEXT:    shll $16, %eax
 ; X86-NEXT:    vmovd %eax, %xmm1
 ; X86-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; X86-NEXT:    vmovss %xmm0, (%esp)
-; X86-NEXT:    calll __truncsfbf2
-; X86-NEXT:    addl $12, %esp
+; X86-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
+; X86-NEXT:    vmovw %xmm0, %eax
+; X86-NEXT:    vmovw %eax, %xmm0
 ; X86-NEXT:    retl
 ;
 ; SSE2-LABEL: add2:
@@ -112,23 +100,8 @@ define bfloat @add2(bfloat %a, bfloat %b) nounwind {
 ; SSE2-NEXT:    popq %rax
 ; SSE2-NEXT:    retq
 ;
-; BF16-LABEL: add2:
-; BF16:       # %bb.0:
-; BF16-NEXT:    pushq %rax
-; BF16-NEXT:    vpextrw $0, %xmm0, %eax
-; BF16-NEXT:    vpextrw $0, %xmm1, %ecx
-; BF16-NEXT:    shll $16, %ecx
-; BF16-NEXT:    vmovd %ecx, %xmm0
-; BF16-NEXT:    shll $16, %eax
-; BF16-NEXT:    vmovd %eax, %xmm1
-; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; BF16-NEXT:    callq __truncsfbf2@PLT
-; BF16-NEXT:    popq %rax
-; BF16-NEXT:    retq
-;
 ; FP16-LABEL: add2:
 ; FP16:       # %bb.0:
-; FP16-NEXT:    pushq %rax
 ; FP16-NEXT:    vmovw %xmm0, %eax
 ; FP16-NEXT:    vmovw %xmm1, %ecx
 ; FP16-NEXT:    shll $16, %ecx
@@ -136,9 +109,24 @@ define bfloat @add2(bfloat %a, bfloat %b) nounwind {
 ; FP16-NEXT:    shll $16, %eax
 ; FP16-NEXT:    vmovd %eax, %xmm1
 ; FP16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; FP16-NEXT:    callq __truncsfbf2@PLT
-; FP16-NEXT:    popq %rax
+; FP16-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
+; FP16-NEXT:    vmovw %xmm0, %eax
+; FP16-NEXT:    vmovw %eax, %xmm0
 ; FP16-NEXT:    retq
+;
+; AVXNC-LABEL: add2:
+; AVXNC:       # %bb.0:
+; AVXNC-NEXT:    vpextrw $0, %xmm0, %eax
+; AVXNC-NEXT:    vpextrw $0, %xmm1, %ecx
+; AVXNC-NEXT:    shll $16, %ecx
+; AVXNC-NEXT:    vmovd %ecx, %xmm0
+; AVXNC-NEXT:    shll $16, %eax
+; AVXNC-NEXT:    vmovd %eax, %xmm1
+; AVXNC-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; AVXNC-NEXT:    {vex} vcvtneps2bf16 %xmm0, %xmm0
+; AVXNC-NEXT:    vmovd %xmm0, %eax
+; AVXNC-NEXT:    vpinsrw $0, %eax, %xmm0, %xmm0
+; AVXNC-NEXT:    retq
   %add = fadd bfloat %a, %b
   ret bfloat %add
 }
@@ -166,8 +154,7 @@ define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; X86-NEXT:    shll $16, %edi
 ; X86-NEXT:    vmovd %edi, %xmm1
 ; X86-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; X86-NEXT:    vmovss %xmm0, (%esp)
-; X86-NEXT:    calll __truncsfbf2
+; X86-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
 ; X86-NEXT:    vmovw %xmm0, %eax
 ; X86-NEXT:    shll $16, %eax
 ; X86-NEXT:    vmovd %eax, %xmm0
@@ -208,35 +195,6 @@ define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; SSE2-NEXT:    popq %rbp
 ; SSE2-NEXT:    retq
 ;
-; BF16-LABEL: add_double:
-; BF16:       # %bb.0:
-; BF16-NEXT:    pushq %rbp
-; BF16-NEXT:    pushq %r14
-; BF16-NEXT:    pushq %rbx
-; BF16-NEXT:    movq %rdx, %rbx
-; BF16-NEXT:    movq %rsi, %r14
-; BF16-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
-; BF16-NEXT:    callq __truncdfbf2@PLT
-; BF16-NEXT:    vpextrw $0, %xmm0, %ebp
-; BF16-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
-; BF16-NEXT:    callq __truncdfbf2@PLT
-; BF16-NEXT:    vpextrw $0, %xmm0, %eax
-; BF16-NEXT:    shll $16, %eax
-; BF16-NEXT:    vmovd %eax, %xmm0
-; BF16-NEXT:    shll $16, %ebp
-; BF16-NEXT:    vmovd %ebp, %xmm1
-; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; BF16-NEXT:    callq __truncsfbf2@PLT
-; BF16-NEXT:    vpextrw $0, %xmm0, %eax
-; BF16-NEXT:    shll $16, %eax
-; BF16-NEXT:    vmovd %eax, %xmm0
-; BF16-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
-; BF16-NEXT:    vmovsd %xmm0, (%rbx)
-; BF16-NEXT:    popq %rbx
-; BF16-NEXT:    popq %r14
-; BF16-NEXT:    popq %rbp
-; BF16-NEXT:    retq
-;
 ; FP16-LABEL: add_double:
 ; FP16:       # %bb.0:
 ; FP16-NEXT:    pushq %rbp
@@ -255,7 +213,7 @@ define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; FP16-NEXT:    shll $16, %ebp
 ; FP16-NEXT:    vmovd %ebp, %xmm1
 ; FP16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; FP16-NEXT:    callq __truncsfbf2@PLT
+; FP16-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
 ; FP16-NEXT:    vmovw %xmm0, %eax
 ; FP16-NEXT:    shll $16, %eax
 ; FP16-NEXT:    vmovd %eax, %xmm0
@@ -265,6 +223,35 @@ define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; FP16-NEXT:    popq %r14
 ; FP16-NEXT:    popq %rbp
 ; FP16-NEXT:    retq
+;
+; AVXNC-LABEL: add_double:
+; AVXNC:       # %bb.0:
+; AVXNC-NEXT:    pushq %rbp
+; AVXNC-NEXT:    pushq %r14
+; AVXNC-NEXT:    pushq %rbx
+; AVXNC-NEXT:    movq %rdx, %rbx
+; AVXNC-NEXT:    movq %rsi, %r14
+; AVXNC-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVXNC-NEXT:    callq __truncdfbf2@PLT
+; AVXNC-NEXT:    vpextrw $0, %xmm0, %ebp
+; AVXNC-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVXNC-NEXT:    callq __truncdfbf2@PLT
+; AVXNC-NEXT:    vpextrw $0, %xmm0, %eax
+; AVXNC-NEXT:    shll $16, %eax
+; AVXNC-NEXT:    vmovd %eax, %xmm0
+; AVXNC-NEXT:    shll $16, %ebp
+; AVXNC-NEXT:    vmovd %ebp, %xmm1
+; AVXNC-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; AVXNC-NEXT:    {vex} vcvtneps2bf16 %xmm0, %xmm0
+; AVXNC-NEXT:    vmovd %xmm0, %eax
+; AVXNC-NEXT:    shll $16, %eax
+; AVXNC-NEXT:    vmovd %eax, %xmm0
+; AVXNC-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
+; AVXNC-NEXT:    vmovsd %xmm0, (%rbx)
+; AVXNC-NEXT:    popq %rbx
+; AVXNC-NEXT:    popq %r14
+; AVXNC-NEXT:    popq %rbp
+; AVXNC-NEXT:    retq
   %la = load double, ptr %pa
   %a = fptrunc double %la to bfloat
   %lb = load double, ptr %pb
@@ -293,8 +280,7 @@ define double @add_double2(double %da, double %db) nounwind {
 ; X86-NEXT:    shll $16, %esi
 ; X86-NEXT:    vmovd %esi, %xmm1
 ; X86-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; X86-NEXT:    vmovss %xmm0, (%esp)
-; X86-NEXT:    calll __truncsfbf2
+; X86-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
 ; X86-NEXT:    vmovw %xmm0, %eax
 ; X86-NEXT:    shll $16, %eax
 ; X86-NEXT:    vmovd %eax, %xmm0
@@ -330,31 +316,6 @@ define double @add_double2(double %da, double %db) nounwind {
 ; SSE2-NEXT:    popq %rbx
 ; SSE2-NEXT:    retq
 ;
-; BF16-LABEL: add_double2:
-; BF16:       # %bb.0:
-; BF16-NEXT:    pushq %rbx
-; BF16-NEXT:    subq $16, %rsp
-; BF16-NEXT:    vmovsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; BF16-NEXT:    callq __truncdfbf2@PLT
-; BF16-NEXT:    vpextrw $0, %xmm0, %ebx
-; BF16-NEXT:    vmovq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload
-; BF16-NEXT:    # xmm0 = mem[0],zero
-; BF16-NEXT:    callq __truncdfbf2@PLT
-; BF16-NEXT:    vpextrw $0, %xmm0, %eax
-; BF16-NEXT:    shll $16, %eax
-; BF16-NEXT:    vmovd %eax, %xmm0
-; BF16-NEXT:    shll $16, %ebx
-; BF16-NEXT:    vmovd %ebx, %xmm1
-; BF16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; BF16-NEXT:    callq __truncsfbf2@PLT
-; BF16-NEXT:    vpextrw $0, %xmm0, %eax
-; BF16-NEXT:    shll $16, %eax
-; BF16-NEXT:    vmovd %eax, %xmm0
-; BF16-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
-; BF16-NEXT:    addq $16, %rsp
-; BF16-NEXT:    popq %rbx
-; BF16-NEXT:    retq
-;
 ; FP16-LABEL: add_double2:
 ; FP16:       # %bb.0:
 ; FP16-NEXT:    pushq %rbx
@@ -371,7 +332,7 @@ define double @add_double2(double %da, double %db) nounwind {
 ; FP16-NEXT:    shll $16, %ebx
 ; FP16-NEXT:    vmovd %ebx, %xmm1
 ; FP16-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; FP16-NEXT:    callq __truncsfbf2@PLT
+; FP16-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
 ; FP16-NEXT:    vmovw %xmm0, %eax
 ; FP16-NEXT:    shll $16, %eax
 ; FP16-NEXT:    vmovd %eax, %xmm0
@@ -379,6 +340,31 @@ define double @add_double2(double %da, double %db) nounwind {
 ; FP16-NEXT:    addq $16, %rsp
 ; FP16-NEXT:    popq %rbx
 ; FP16-NEXT:    retq
+;
+; AVXNC-LABEL: add_double2:
+; AVXNC:       # %bb.0:
+; AVXNC-NEXT:    pushq %rbx
+; AVXNC-NEXT:    subq $16, %rsp
+; AVXNC-NEXT:    vmovsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; AVXNC-NEXT:    callq __truncdfbf2@PLT
+; AVXNC-NEXT:    vpextrw $0, %xmm0, %ebx
+; AVXNC-NEXT:    vmovq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload
+; AVXNC-NEXT:    # xmm0 = mem[0],zero
+; AVXNC-NEXT:    callq __truncdfbf2@PLT
+; AVXNC-NEXT:    vpextrw $0, %xmm0, %eax
+; AVXNC-NEXT:    shll $16, %eax
+; AVXNC-NEXT:    vmovd %eax, %xmm0
+; AVXNC-NEXT:    shll $16, %ebx
+; AVXNC-NEXT:    vmovd %ebx, %xmm1
+; AVXNC-NEXT:    vaddss %xmm0, %xmm1, %xmm0
+; AVXNC-NEXT:    {vex} vcvtneps2bf16 %xmm0, %xmm0
+; AVXNC-NEXT:    vmovd %xmm0, %eax
+; AVXNC-NEXT:    shll $16, %eax
+; AVXNC-NEXT:    vmovd %eax, %xmm0
+; AVXNC-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
+; AVXNC-NEXT:    addq $16, %rsp
+; AVXNC-NEXT:    popq %rbx
+; AVXNC-NEXT:    retq
   %a = fptrunc double %da to bfloat
   %b = fptrunc double %db to bfloat
   %add = fadd bfloat %a, %b
@@ -389,19 +375,14 @@ define double @add_double2(double %da, double %db) nounwind {
 define void @add_constant(ptr %pa, ptr %pc) nounwind {
 ; X86-LABEL: add_constant:
 ; X86:       # %bb.0:
-; X86-NEXT:    pushl %esi
-; X86-NEXT:    subl $8, %esp
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %esi
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; X86-NEXT:    movzwl (%eax), %eax
-; X86-NEXT:    shll $16, %eax
-; X86-NEXT:    vmovd %eax, %xmm0
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    movzwl (%ecx), %ecx
+; X86-NEXT:    shll $16, %ecx
+; X86-NEXT:    vmovd %ecx, %xmm0
 ; X86-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
-; X86-NEXT:    vmovss %xmm0, (%esp)
-; X86-NEXT:    calll __truncsfbf2
-; X86-NEXT:    vmovsh %xmm0, (%esi)
-; X86-NEXT:    addl $8, %esp
-; X86-NEXT:    popl %esi
+; X86-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
+; X86-NEXT:    vpextrw $0, %xmm0, (%eax)
 ; X86-NEXT:    retl
 ;
 ; SSE2-LABEL: add_constant:
@@ -418,31 +399,25 @@ define void @add_constant(ptr %pa, ptr %pc) nounwind {
 ; SSE2-NEXT:    popq %rbx
 ; SSE2-NEXT:    retq
 ;
-; BF16-LABEL: add_constant:
-; BF16:       # %bb.0:
-; BF16-NEXT:    pushq %rbx
-; BF16-NEXT:    movq %rsi, %rbx
-; BF16-NEXT:    movzwl (%rdi), %eax
-; BF16-NEXT:    shll $16, %eax
-; BF16-NEXT:    vmovd %eax, %xmm0
-; BF16-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; BF16-NEXT:    callq __truncsfbf2@PLT
-; BF16-NEXT:    vpextrw $0, %xmm0, (%rbx)
-; BF16-NEXT:    popq %rbx
-; BF16-NEXT:    retq
+; F16-LABEL: add_constant:
+; F16:       # %bb.0:
+; F16-NEXT:    movzwl (%rdi), %eax
+; F16-NEXT:    shll $16, %eax
+; F16-NEXT:    vmovd %eax, %xmm0
+; F16-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; F16-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
+; F16-NEXT:    vpextrw $0, %xmm0, (%rsi)
+; F16-NEXT:    retq
 ;
-; FP16-LABEL: add_constant:
-; FP16:       # %bb.0:
-; FP16-NEXT:    pushq %rbx
-; FP16-NEXT:    movq %rsi, %rbx
-; FP16-NEXT:    movzwl (%rdi), %eax
-; FP16-NEXT:    shll $16, %eax
-; FP16-NEXT:    vmovd %eax, %xmm0
-; FP16-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; FP16-NEXT:    callq __truncsfbf2@PLT
-; FP16-NEXT:    vmovsh %xmm0, (%rbx)
-; FP16-NEXT:    popq %rbx
-; FP16-NEXT:    retq
+; AVXNC-LABEL: add_constant:
+; AVXNC:       # %bb.0:
+; AVXNC-NEXT:    movzwl (%rdi), %eax
+; AVXNC-NEXT:    shll $16, %eax
+; AVXNC-NEXT:    vmovd %eax, %xmm0
+; AVXNC-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVXNC-NEXT:    {vex} vcvtneps2bf16 %xmm0, %xmm0
+; AVXNC-NEXT:    vpextrw $0, %xmm0, (%rsi)
+; AVXNC-NEXT:    retq
   %a = load bfloat, ptr %pa
   %add = fadd bfloat %a, 1.0
   store bfloat %add, ptr %pc
@@ -452,14 +427,13 @@ define void @add_constant(ptr %pa, ptr %pc) nounwind {
 define bfloat @add_constant2(bfloat %a) nounwind {
 ; X86-LABEL: add_constant2:
 ; X86:       # %bb.0:
-; X86-NEXT:    subl $12, %esp
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    shll $16, %eax
 ; X86-NEXT:    vmovd %eax, %xmm0
 ; X86-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
-; X86-NEXT:    vmovss %xmm0, (%esp)
-; X86-NEXT:    calll __truncsfbf2
-; X86-NEXT:    addl $12, %esp
+; X86-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
+; X86-NEXT:    vmovw %xmm0, %eax
+; X86-NEXT:    vmovw %eax, %xmm0
 ; X86-NEXT:    retl
 ;
 ; SSE2-LABEL: add_constant2:
@@ -473,27 +447,27 @@ define bfloat @add_constant2(bfloat %a) nounwind {
 ; SSE2-NEXT:    popq %rax
 ; SSE2-NEXT:    retq
 ;
-; BF16-LABEL: add_constant2:
-; BF16:       # %bb.0:
-; BF16-NEXT:    pushq %rax
-; BF16-NEXT:    vpextrw $0, %xmm0, %eax
-; BF16-NEXT:    shll $16, %eax
-; BF16-NEXT:    vmovd %eax, %xmm0
-; BF16-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; BF16-NEXT:    callq __truncsfbf2@PLT
-; BF16-NEXT:    popq %rax
-; BF16-NEXT:    retq
-;
 ; FP16-LABEL: add_constant2:
 ; FP16:       # %bb.0:
-; FP16-NEXT:    pushq %rax
 ; FP16-NEXT:    vmovw %xmm0, %eax
 ; FP16-NEXT:    shll $16, %eax
 ; FP16-NEXT:    vmovd %eax, %xmm0
 ; FP16-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; FP16-NEXT:    callq __truncsfbf2@PLT
-; FP16-NEXT:    popq %rax
+; FP16-NEXT:    vcvtneps2bf16 %xmm0, %xmm0
+; FP16-NEXT:    vmovw %xmm0, %eax
+; FP16-NEXT:    vmovw %eax, %xmm0
 ; FP16-NEXT:    retq
+;
+; AVXNC-LABEL: add_constant2:
+; AVXNC:       # %bb.0:
+; AVXNC-NEXT:    vpextrw $0, %xmm0, %eax
+; AVXNC-NEXT:    shll $16, %eax
+; AVXNC-NEXT:    vmovd %eax, %xmm0
+; AVXNC-NEXT:    vaddss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVXNC-NEXT:    {vex} vcvtneps2bf16 %xmm0, %xmm0
+; AVXNC-NEXT:    vmovd %xmm0, %eax
+; AVXNC-NEXT:    vpinsrw $0, %eax, %xmm0, %xmm0
+; AVXNC-NEXT:    retq
   %add = fadd bfloat %a, 1.0
   ret bfloat %add
 }
@@ -551,138 +525,101 @@ define bfloat @fold_ext_trunc2(bfloat %a) nounwind {
 define <8 x bfloat> @addv(<8 x bfloat> %a, <8 x bfloat> %b) nounwind {
 ; X86-LABEL: addv:
 ; X86:       # %bb.0:
-; X86-NEXT:    subl $172, %esp
+; X86-NEXT:    pushl %ebp
+; X86-NEXT:    pushl %ebx
+; X86-NEXT:    pushl %edi
+; X86-NEXT:    pushl %esi
 ; X86-NEXT:    vmovw %xmm1, %eax
-; X86-NEXT:    vmovdqa %xmm1, %xmm3
-; X86-NEXT:    vmovdqa %xmm1, {{[-0-9]+}}(%e{{[sb]}}p) # 16-byte Spill
 ; X86-NEXT:    shll $16, %eax
 ; X86-NEXT:    vmovd %eax, %xmm2
 ; X86-NEXT:    vmovw %xmm0, %eax
-; X86-NEXT:    vmovdqa %xmm0, %xmm4
-; X86-NEXT:    vmovdqa %xmm0, {{[-0-9]+}}(%e{{[sb]}}p) # 16-byte Spill
-; X86-NEXT:    shll $16, %eax
-; X86-NEXT:    vmovd %eax, %xmm1
-; X86-NEXT:    vaddss %xmm2, %xmm1, %xmm0
-; X86-NEXT:    vmovss %xmm0, (%esp)
-; X86-NEXT:    vpextrw $1, %xmm3, %eax
-; X86-NEXT:    shll $16, %eax
-; X86-NEXT:    vmovd %eax, %xmm0
-; X86-NEXT:    vpextrw $1, %xmm4, %eax
 ; X86-NEXT:    shll $16, %eax
-; X86-NEXT:    vmovd %eax, %xmm1
-; X86-NEXT:    vaddss %xmm0, %xmm1, %xmm0
-; X86-NEXT:    vmovss %xmm0, {{[-0-9]+}}(%e{{[sb]}}p) # 4-byte Spill
-; X86-NEXT:    calll __truncsfbf2
-; X86-NEXT:    vmovaps %xmm0, {{[-0-9]+}}(%e{{[sb]}}p) # 16-byte Spill
-; X86-NEXT:    vmovss {{[-0-9]+}}(%e{{[sb]}}p), %xmm0 # 4-byte Reload
-; X86-NEXT:    # xmm0 = mem[0],zero,zero,zero
-; X86-NEXT:    vmovss %xmm0, (%esp)
-; X86-NEXT:    vmovdqa {{[-0-9]+}}(%e{{[sb]}}p), %xmm0 # 16-byte Reload
-; X86-NEXT:    vpextrw $2, %xmm0, %eax
+; X86-NEXT:    vmovd %eax, %xmm3
+; X86-NEXT:    vaddss %xmm2, %xmm3, %xmm2
+; X86-NEXT:    vcvtneps2bf16 %xmm2, %xmm2
+; X86-NEXT:    vmovw %xmm2, %ecx
+; X86-NEXT:    vpextrw $1, %xmm1, %eax
 ; X86-NEXT:    shll $16, %eax
-; X86-NEXT: ...
[truncated]

Copy link
Contributor

@FreddyLeaf FreddyLeaf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@phoebewang phoebewang merged commit 9745c13 into llvm:main Jan 17, 2024
4 of 5 checks passed
@phoebewang phoebewang deleted the bf16 branch January 17, 2024 02:09
@phoebewang
Copy link
Contributor Author

Thanks @FreddyLeaf !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants