[X86][BF16] Improve vectorization of BF16 (#88486)
1. Move expansion to combineFP_EXTEND to help with small vectors;
2. Combine FP_ROUND to reduce fptrunc then fpextend after promotion;
phoebewang committed Apr 13, 2024
1 parent 37ebf2a commit 3cf8535
Showing 3 changed files with 66 additions and 229 deletions.
53 changes: 29 additions & 24 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21433,25 +21433,9 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
return Res;
}

if (!SVT.isVector())
if (!SVT.isVector() || SVT.getVectorElementType() == MVT::bf16)
return Op;

if (SVT.getVectorElementType() == MVT::bf16) {
// FIXME: Do we need to support strict FP?
assert(!IsStrict && "Strict FP doesn't support BF16");
if (VT.getVectorElementType() == MVT::f64) {
MVT TmpVT = VT.changeVectorElementType(MVT::f32);
return DAG.getNode(ISD::FP_EXTEND, DL, VT,
DAG.getNode(ISD::FP_EXTEND, DL, TmpVT, In));
}
assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
MVT NVT = SVT.changeVectorElementType(MVT::i32);
In = DAG.getBitcast(SVT.changeTypeToInteger(), In);
In = DAG.getNode(ISD::ZERO_EXTEND, DL, NVT, In);
In = DAG.getNode(ISD::SHL, DL, NVT, In, DAG.getConstant(16, DL, NVT));
return DAG.getBitcast(VT, In);
}

if (SVT.getVectorElementType() == MVT::f16) {
if (Subtarget.hasFP16() && isTypeLegal(SVT))
return Op;
@@ -56517,17 +56501,40 @@ static SDValue combineFP16_TO_FP(SDNode *N, SelectionDAG &DAG,

static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
bool IsStrict = N->isStrictFPOpcode();
SDValue Src = N->getOperand(IsStrict ? 1 : 0);
EVT SrcVT = Src.getValueType();

SDLoc dl(N);
if (SrcVT.getScalarType() == MVT::bf16) {
if (!IsStrict && Src.getOpcode() == ISD::FP_ROUND &&
Src.getOperand(0).getValueType() == VT)
return Src.getOperand(0);

if (!SrcVT.isVector())
return SDValue();

assert(!IsStrict && "Strict FP doesn't support BF16");
if (VT.getVectorElementType() == MVT::f64) {
MVT TmpVT = VT.getSimpleVT().changeVectorElementType(MVT::f32);
return DAG.getNode(ISD::FP_EXTEND, dl, VT,
DAG.getNode(ISD::FP_EXTEND, dl, TmpVT, Src));
}
assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
MVT NVT = SrcVT.getSimpleVT().changeVectorElementType(MVT::i32);
Src = DAG.getBitcast(SrcVT.changeTypeToInteger(), Src);
Src = DAG.getNode(ISD::ZERO_EXTEND, dl, NVT, Src);
Src = DAG.getNode(ISD::SHL, dl, NVT, Src, DAG.getConstant(16, dl, NVT));
return DAG.getBitcast(VT, Src);
}

if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
return SDValue();

if (Subtarget.hasFP16())
return SDValue();

bool IsStrict = N->isStrictFPOpcode();
EVT VT = N->getValueType(0);
SDValue Src = N->getOperand(IsStrict ? 1 : 0);
EVT SrcVT = Src.getValueType();

if (!SrcVT.isVector() || SrcVT.getVectorElementType() != MVT::f16)
return SDValue();

@@ -56539,8 +56546,6 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
if (NumElts == 1 || !isPowerOf2_32(NumElts))
return SDValue();

SDLoc dl(N);

// Convert the input to vXi16.
EVT IntVT = SrcVT.changeVectorElementTypeToInteger();
Src = DAG.getBitcast(IntVT, Src);
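As a side note for readers less familiar with the DAG nodes above: the vector bf16-to-f32 path now in combineFP_EXTEND (bitcast to integer, ZERO_EXTEND, SHL by 16, bitcast back) is the usual "payload in the high half" trick, and the f64 case simply goes through f32 first. Below is a minimal scalar C++ sketch of the same idea; the helper name is purely illustrative and not part of the patch.

#include <cstdint>
#include <cstring>

// Scalar model of the vector lowering: a bf16 value holds the top 16 bits of
// the corresponding f32, so extending it is a zero-extend of the raw bits
// followed by a left shift of 16, then a bitcast back to float.
float bf16_bits_to_f32(uint16_t bits) {
  uint32_t wide = static_cast<uint32_t>(bits) << 16; // ZERO_EXTEND + SHL 16
  float f;
  std::memcpy(&f, &wide, sizeof(f));                 // bitcast i32 -> f32
  return f;
}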
234 changes: 33 additions & 201 deletions llvm/test/CodeGen/X86/bfloat.ll
@@ -1629,22 +1629,8 @@ define <4 x float> @pr64460_1(<4 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_1:
; SSE2: # %bb.0:
; SSE2-NEXT: pextrw $1, %xmm0, %eax
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm2
; SSE2-NEXT: movd %xmm0, %eax
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm1
; SSE2-NEXT: pextrw $3, %xmm0, %eax
; SSE2-NEXT: shufps {{.*#+}} xmm0 = xmm0[1,1,1,1]
; SSE2-NEXT: punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm2
; SSE2-NEXT: movd %xmm0, %eax
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm0
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
; SSE2-NEXT: pxor %xmm1, %xmm1
; SSE2-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
; SSE2-NEXT: movdqa %xmm1, %xmm0
; SSE2-NEXT: retq
;
@@ -1666,41 +1652,11 @@ define <8 x float> @pr64460_2(<8 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_2:
; SSE2: # %bb.0:
; SSE2-NEXT: movq %xmm0, %rdx
; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
; SSE2-NEXT: movq %xmm0, %rcx
; SSE2-NEXT: movq %rcx, %rax
; SSE2-NEXT: shrq $32, %rax
; SSE2-NEXT: movq %rdx, %rsi
; SSE2-NEXT: shrq $32, %rsi
; SSE2-NEXT: movl %edx, %edi
; SSE2-NEXT: andl $-65536, %edi # imm = 0xFFFF0000
; SSE2-NEXT: movd %edi, %xmm1
; SSE2-NEXT: movl %edx, %edi
; SSE2-NEXT: shll $16, %edi
; SSE2-NEXT: movd %edi, %xmm0
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; SSE2-NEXT: shrq $48, %rdx
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: movd %edx, %xmm1
; SSE2-NEXT: shll $16, %esi
; SSE2-NEXT: movd %esi, %xmm2
; SSE2-NEXT: punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
; SSE2-NEXT: movl %ecx, %edx
; SSE2-NEXT: andl $-65536, %edx # imm = 0xFFFF0000
; SSE2-NEXT: movd %edx, %xmm2
; SSE2-NEXT: movl %ecx, %edx
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: movd %edx, %xmm1
; SSE2-NEXT: punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
; SSE2-NEXT: shrq $48, %rcx
; SSE2-NEXT: shll $16, %ecx
; SSE2-NEXT: movd %ecx, %xmm2
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm3
; SSE2-NEXT: punpckldq {{.*#+}} xmm3 = xmm3[0],xmm2[0],xmm3[1],xmm2[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm3[0]
; SSE2-NEXT: pxor %xmm1, %xmm1
; SSE2-NEXT: pxor %xmm2, %xmm2
; SSE2-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3]
; SSE2-NEXT: punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7]
; SSE2-NEXT: movdqa %xmm2, %xmm0
; SSE2-NEXT: retq
;
; AVX-LABEL: pr64460_2:
@@ -1721,76 +1677,16 @@ define <16 x float> @pr64460_3(<16 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_3:
; SSE2: # %bb.0:
; SSE2-NEXT: movq %xmm1, %rdi
; SSE2-NEXT: punpckhqdq {{.*#+}} xmm1 = xmm1[1,1]
; SSE2-NEXT: movq %xmm1, %rcx
; SSE2-NEXT: movq %rcx, %rax
; SSE2-NEXT: shrq $32, %rax
; SSE2-NEXT: movq %xmm0, %r9
; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
; SSE2-NEXT: movq %xmm0, %rsi
; SSE2-NEXT: movq %rsi, %rdx
; SSE2-NEXT: shrq $32, %rdx
; SSE2-NEXT: movq %rdi, %r8
; SSE2-NEXT: shrq $32, %r8
; SSE2-NEXT: movq %r9, %r10
; SSE2-NEXT: shrq $32, %r10
; SSE2-NEXT: movl %r9d, %r11d
; SSE2-NEXT: andl $-65536, %r11d # imm = 0xFFFF0000
; SSE2-NEXT: movd %r11d, %xmm1
; SSE2-NEXT: movl %r9d, %r11d
; SSE2-NEXT: shll $16, %r11d
; SSE2-NEXT: movd %r11d, %xmm0
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; SSE2-NEXT: shrq $48, %r9
; SSE2-NEXT: shll $16, %r9d
; SSE2-NEXT: movd %r9d, %xmm1
; SSE2-NEXT: shll $16, %r10d
; SSE2-NEXT: movd %r10d, %xmm2
; SSE2-NEXT: punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
; SSE2-NEXT: movl %edi, %r9d
; SSE2-NEXT: andl $-65536, %r9d # imm = 0xFFFF0000
; SSE2-NEXT: movd %r9d, %xmm1
; SSE2-NEXT: movl %edi, %r9d
; SSE2-NEXT: shll $16, %r9d
; SSE2-NEXT: movd %r9d, %xmm2
; SSE2-NEXT: punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
; SSE2-NEXT: shrq $48, %rdi
; SSE2-NEXT: shll $16, %edi
; SSE2-NEXT: movd %edi, %xmm1
; SSE2-NEXT: shll $16, %r8d
; SSE2-NEXT: movd %r8d, %xmm3
; SSE2-NEXT: punpckldq {{.*#+}} xmm3 = xmm3[0],xmm1[0],xmm3[1],xmm1[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
; SSE2-NEXT: movl %esi, %edi
; SSE2-NEXT: andl $-65536, %edi # imm = 0xFFFF0000
; SSE2-NEXT: movd %edi, %xmm3
; SSE2-NEXT: movl %esi, %edi
; SSE2-NEXT: shll $16, %edi
; SSE2-NEXT: movd %edi, %xmm1
; SSE2-NEXT: punpckldq {{.*#+}} xmm1 = xmm1[0],xmm3[0],xmm1[1],xmm3[1]
; SSE2-NEXT: shrq $48, %rsi
; SSE2-NEXT: shll $16, %esi
; SSE2-NEXT: movd %esi, %xmm3
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: movd %edx, %xmm4
; SSE2-NEXT: punpckldq {{.*#+}} xmm4 = xmm4[0],xmm3[0],xmm4[1],xmm3[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm4[0]
; SSE2-NEXT: movl %ecx, %edx
; SSE2-NEXT: andl $-65536, %edx # imm = 0xFFFF0000
; SSE2-NEXT: movd %edx, %xmm4
; SSE2-NEXT: movl %ecx, %edx
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: movd %edx, %xmm3
; SSE2-NEXT: punpckldq {{.*#+}} xmm3 = xmm3[0],xmm4[0],xmm3[1],xmm4[1]
; SSE2-NEXT: shrq $48, %rcx
; SSE2-NEXT: shll $16, %ecx
; SSE2-NEXT: movd %ecx, %xmm4
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm5
; SSE2-NEXT: punpckldq {{.*#+}} xmm5 = xmm5[0],xmm4[0],xmm5[1],xmm4[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm5[0]
; SSE2-NEXT: pxor %xmm3, %xmm3
; SSE2-NEXT: pxor %xmm5, %xmm5
; SSE2-NEXT: punpcklwd {{.*#+}} xmm5 = xmm5[0],xmm0[0],xmm5[1],xmm0[1],xmm5[2],xmm0[2],xmm5[3],xmm0[3]
; SSE2-NEXT: pxor %xmm4, %xmm4
; SSE2-NEXT: punpckhwd {{.*#+}} xmm4 = xmm4[4],xmm0[4],xmm4[5],xmm0[5],xmm4[6],xmm0[6],xmm4[7],xmm0[7]
; SSE2-NEXT: pxor %xmm2, %xmm2
; SSE2-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3]
; SSE2-NEXT: punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm1[4],xmm3[5],xmm1[5],xmm3[6],xmm1[6],xmm3[7],xmm1[7]
; SSE2-NEXT: movdqa %xmm5, %xmm0
; SSE2-NEXT: movdqa %xmm4, %xmm1
; SSE2-NEXT: retq
;
; F16-LABEL: pr64460_3:
Expand Down Expand Up @@ -1822,47 +1718,17 @@ define <8 x double> @pr64460_4(<8 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_4:
; SSE2: # %bb.0:
; SSE2-NEXT: movq %xmm0, %rsi
; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
; SSE2-NEXT: movq %xmm0, %rdx
; SSE2-NEXT: movq %rdx, %rax
; SSE2-NEXT: shrq $32, %rax
; SSE2-NEXT: movq %rdx, %rcx
; SSE2-NEXT: shrq $48, %rcx
; SSE2-NEXT: movq %rsi, %rdi
; SSE2-NEXT: shrq $32, %rdi
; SSE2-NEXT: movq %rsi, %r8
; SSE2-NEXT: shrq $48, %r8
; SSE2-NEXT: movl %esi, %r9d
; SSE2-NEXT: andl $-65536, %r9d # imm = 0xFFFF0000
; SSE2-NEXT: movd %r9d, %xmm0
; SSE2-NEXT: cvtss2sd %xmm0, %xmm1
; SSE2-NEXT: shll $16, %esi
; SSE2-NEXT: movd %esi, %xmm0
; SSE2-NEXT: cvtss2sd %xmm0, %xmm0
; SSE2-NEXT: movlhps {{.*#+}} xmm0 = xmm0[0],xmm1[0]
; SSE2-NEXT: shll $16, %r8d
; SSE2-NEXT: movd %r8d, %xmm1
; SSE2-NEXT: cvtss2sd %xmm1, %xmm2
; SSE2-NEXT: shll $16, %edi
; SSE2-NEXT: movd %edi, %xmm1
; SSE2-NEXT: cvtss2sd %xmm1, %xmm1
; SSE2-NEXT: movlhps {{.*#+}} xmm1 = xmm1[0],xmm2[0]
; SSE2-NEXT: movl %edx, %esi
; SSE2-NEXT: andl $-65536, %esi # imm = 0xFFFF0000
; SSE2-NEXT: movd %esi, %xmm2
; SSE2-NEXT: cvtss2sd %xmm2, %xmm3
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: movd %edx, %xmm2
; SSE2-NEXT: cvtss2sd %xmm2, %xmm2
; SSE2-NEXT: movlhps {{.*#+}} xmm2 = xmm2[0],xmm3[0]
; SSE2-NEXT: shll $16, %ecx
; SSE2-NEXT: movd %ecx, %xmm3
; SSE2-NEXT: cvtss2sd %xmm3, %xmm4
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm3
; SSE2-NEXT: cvtss2sd %xmm3, %xmm3
; SSE2-NEXT: movlhps {{.*#+}} xmm3 = xmm3[0],xmm4[0]
; SSE2-NEXT: pxor %xmm3, %xmm3
; SSE2-NEXT: pxor %xmm1, %xmm1
; SSE2-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
; SSE2-NEXT: cvtps2pd %xmm1, %xmm4
; SSE2-NEXT: punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm0[4],xmm3[5],xmm0[5],xmm3[6],xmm0[6],xmm3[7],xmm0[7]
; SSE2-NEXT: cvtps2pd %xmm3, %xmm2
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
; SSE2-NEXT: cvtps2pd %xmm0, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm3[2,3,2,3]
; SSE2-NEXT: cvtps2pd %xmm0, %xmm3
; SSE2-NEXT: movaps %xmm4, %xmm0
; SSE2-NEXT: retq
;
; F16-LABEL: pr64460_4:
@@ -1874,45 +1740,11 @@ define <8 x double> @pr64460_4(<8 x bfloat> %a) {
;
; AVXNC-LABEL: pr64460_4:
; AVXNC: # %bb.0:
; AVXNC-NEXT: vpextrw $3, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm1
; AVXNC-NEXT: vcvtss2sd %xmm1, %xmm1, %xmm1
; AVXNC-NEXT: vpextrw $2, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm2
; AVXNC-NEXT: vcvtss2sd %xmm2, %xmm2, %xmm2
; AVXNC-NEXT: vmovlhps {{.*#+}} xmm1 = xmm2[0],xmm1[0]
; AVXNC-NEXT: vpextrw $1, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm2
; AVXNC-NEXT: vcvtss2sd %xmm2, %xmm2, %xmm2
; AVXNC-NEXT: vmovd %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm3
; AVXNC-NEXT: vcvtss2sd %xmm3, %xmm3, %xmm3
; AVXNC-NEXT: vmovlhps {{.*#+}} xmm2 = xmm3[0],xmm2[0]
; AVXNC-NEXT: vinsertf128 $1, %xmm1, %ymm2, %ymm2
; AVXNC-NEXT: vpextrw $7, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm1
; AVXNC-NEXT: vcvtss2sd %xmm1, %xmm1, %xmm1
; AVXNC-NEXT: vpextrw $6, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm3
; AVXNC-NEXT: vcvtss2sd %xmm3, %xmm3, %xmm3
; AVXNC-NEXT: vmovlhps {{.*#+}} xmm1 = xmm3[0],xmm1[0]
; AVXNC-NEXT: vpextrw $5, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm3
; AVXNC-NEXT: vcvtss2sd %xmm3, %xmm3, %xmm3
; AVXNC-NEXT: vpextrw $4, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm0
; AVXNC-NEXT: vcvtss2sd %xmm0, %xmm0, %xmm0
; AVXNC-NEXT: vmovlhps {{.*#+}} xmm0 = xmm0[0],xmm3[0]
; AVXNC-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm1
; AVXNC-NEXT: vmovaps %ymm2, %ymm0
; AVXNC-NEXT: vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
; AVXNC-NEXT: vpslld $16, %ymm0, %ymm1
; AVXNC-NEXT: vcvtps2pd %xmm1, %ymm0
; AVXNC-NEXT: vextracti128 $1, %ymm1, %xmm1
; AVXNC-NEXT: vcvtps2pd %xmm1, %ymm1
; AVXNC-NEXT: retq
%b = fpext <8 x bfloat> %a to <8 x double>
ret <8 x double> %b
8 changes: 4 additions & 4 deletions llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
@@ -10,11 +10,11 @@ define void @test(<2 x ptr> %ptr) {
; CHECK-NEXT: # %bb.2: # %loop.127.preheader
; CHECK-NEXT: retq
; CHECK-NEXT: .LBB0_1: # %ifmerge.89
; CHECK-NEXT: movzwl (%rax), %eax
; CHECK-NEXT: shll $16, %eax
; CHECK-NEXT: vmovd %eax, %xmm0
; CHECK-NEXT: vmulss %xmm0, %xmm0, %xmm0
; CHECK-NEXT: vbroadcastss %xmm0, %xmm0
; CHECK-NEXT: vpxor %xmm1, %xmm1, %xmm1
; CHECK-NEXT: vpbroadcastw (%rax), %xmm2
; CHECK-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1],xmm1[2],xmm2[2],xmm1[3],xmm2[3]
; CHECK-NEXT: vmulps %xmm1, %xmm0, %xmm0
; CHECK-NEXT: vmovlps %xmm0, (%rax)
entry:
br label %then.13

15 comments on commit 3cf8535

@amagnigoogle

Just a heads-up: we are seeing internal tests fail in bf16 due to numerical differences after this commit.
The differences are likely due to the change in the function combineFP_EXTEND.
I am not an x86/floating-point expert myself, but I wanted to raise the problem early while we work on a reproducer.

@Sterling-Augustine
Contributor

Forgive my ignorance here, but we are having some trouble reducing the differences we see with this change into a small test case.

Should we expect any precision differences after this change? If so, what are they?

Someone internally suggested that this change may simply truncate from f32 to bf16 instead of preserving the "round-to-nearest-even" semantics during the conversion.

The algorithm is described here: https://github.com/DIPlib/diplib/blob/master/dependencies/eigen3/Eigen/src/Core/arch/Default/BFloat16.h, starting on line 283.

Any help would be greatly appreciated.
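For reference, here is a sketch of the widely used round-to-nearest-even float-to-bf16 truncation (a bias-and-carry formulation of the algorithm referenced above). It is written from memory as an illustration, not copied from that file or from LLVM; the helper name is made up and FTZ/DAZ behavior is ignored.

#include <cstdint>
#include <cstring>

uint16_t f32_to_bf16_rtne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if ((bits & 0x7fffffffu) > 0x7f800000u)                  // NaN: keep sign, force quiet bit
    return static_cast<uint16_t>((bits >> 16) | 0x0040u);
  uint32_t lsb = (bits >> 16) & 1u;                        // last bit that survives truncation
  bits += 0x7fffu + lsb;                                   // round to nearest, ties to even
  return static_cast<uint16_t>(bits >> 16);
}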

@phoebewang
Contributor Author

The patch improves both performance and intermediate precision. Differences in the final numerical results are expected, and it is arguable which behavior is preferable. BF16 is not an IEEE type, and given its few fraction bits, improving the intermediate precision while allowing final numerical differences seems like the better direction to me.

The BF16 conversion instruction doesn't support any rounding mode other than "round-to-nearest-even". In fact, if we compare it with the similar type FP16 from an instruction-design perspective, we can see the same thing: VCVTPS2PH(X) supports exceptions, different rounding modes, and DAZ (though no FTZ, unlike FP32/64), while VCVTNEPS2BF16 supports none of them. This implies exact numerical results may not be a concern for BF16.

@rmlarsen
Contributor

@phoebewang What do you mean by improving the intermediate precision? The lack of standardization of rounding modes does not imply that numerical accuracy is not a concern - that would be a dangerous conclusion. In any case, if the new code implements RTNE that is desirable. Does it support conversion to and from subnormal bfloat16 values?

@rmlarsen
Contributor

@phoebewang could you please specify precisely how this change altered the conversion semantics?

@phoebewang
Contributor Author

@rmlarsen the change isn't intended to alter the conversion semantics, only the intermediate calculation. For example, if we have

__bf16 d = a + b + c; // where a, b and c are all __bf16 type.

before this change, it's

__bf16 d = (__bf16)((float)(__bf16)((float)a + (float)b) + (float)c)

now it becomes

__bf16 d = (__bf16)((float)a + (float)b + (float)c)

The instruction VCVTNEPS2BF16 always treats subnormal values as zero. If accuracy is the concern, we cannot leverage the native instruction in this case either.

@rmlarsen
Contributor

@phoebewang thank you for the clarification. That is indeed a strict improvement.

@krzysz00
Contributor

Hold on, isn't this sort of thing only legal in the presence of contract?

That is, I'd argue that if we start with

%t1= fmul bfloat %x, %y
%w = fadd bfloat %t1, %z

which, per promotion, becomes

%x.ext = fpext bfloat %x to float
%y.ext = fpext bfloat %y to float
%t1.long = fmul float %x.ext, %y.ext
%t1 = fptrunc float %t1.long to bfloat
%t1.ext = fpext bfloat %t1 to float
%z.ext = fpext bfloat %z to float
%w.long = fadd float %t1.ext, %z.ext
%w = fptrunc float %w.long to bfloat

and then this change rewrites this to

%x.ext = fpext bfloat %x to float
%y.ext = fpext bfloat %y to float
%t1.long = fmul float %x.ext, %y.ext
%z.ext = fpext bfloat %z to float
%w.long = fadd float %t1.long, %z.ext
%w = fptrunc float %w.long to bfloat

That is equivalent to the transformation

%w = fma bfloat %x, %y, %z

which is only legal in the presence of contract fastmath.

Therefore, I'm still of the opinion that this merger violates LLVM's semantics.

@phoebewang
Contributor Author

No, it's more similar to excess-precision, which is independent of contract.

@krzysz00
Contributor

Right, and per the documentation on that: "assignments and explicit casts force the operand to be converted to its formal type, discarding any excess precision".

So, I'm arguing that the LLVM IR

%x1 = fptrunc float %x to bfloat
%x2 = fpext bfloat %x1 to float

, which is equivalent to float x2 = (float)(bfloat)(x); must be protected from this optimization.

That is, there's a difference between the trunc/ext pairs introduced by promotion and conversions written explicitly in the input IR. For example, if you're doing

%1 = sin(bfloat %x);
%2 = fadd bfloat %1, %y

and you keep %1 as a float because you cancelled the trunc/ext pairs that promotion inserted, that's entirely fine, but the semantics of the input IR should be preserved.

Or, alternatively, the documentation should be updated to reflect that you need an arithmetic fence to prevent extended precision ... which is probably too strict.

@phoebewang
Contributor Author

Notice the documentation emphasizes it applies only to _Float16.

I think it's reasonable for __bf16 to have a looser rule under the same concept.

@rmlarsen
Contributor

@phoebewang I disagree. Also notice that in the original Google implementation of bfloat16, which now resides in the Eigen library, we explicitly disallow implicit casting from float to bfloat16, as it is extremely lossy.

@phoebewang
Contributor Author

@rmlarsen @krzysz00 I think you are right, it's too aggressive to combine explicit fptrunc/fpext.
Created #91420 for it, PTAL.

@rmlarsen
Contributor

commented on 3cf8535, May 8, 2024

@phoebewang I think that looks good. I'm not familiar enough with LLVM internals to give you a formal review. Thank you!

@phoebewang
Contributor Author

@rmlarsen You are welcome :)
