Skip to content

Commit

Permalink
[X86][BF16] Share FP16 vector ABI with BF16
Browse files Browse the repository at this point in the history
The ABI of BF16 is identical to FP16 rather than i16.

Fixes #62997

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D151710
  • Loading branch information
phoebewang committed Jun 9, 2023
1 parent 6ebf7cd commit 7634905
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 9 deletions.
22 changes: 18 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Expand Up @@ -417,6 +417,10 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
return Val;
if (PartEVT.isInteger() && ValueVT.isFloatingPoint())
return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);

// Vector/Vector bitcast (e.g. <2 x bfloat> -> <2 x half>).
if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
}

// Promoted vector extract
Expand Down Expand Up @@ -622,29 +626,39 @@ static SDValue widenVectorToPartType(SelectionDAG &DAG, SDValue Val,
return SDValue();

EVT ValueVT = Val.getValueType();
EVT PartEVT = PartVT.getVectorElementType();
EVT ValueEVT = ValueVT.getVectorElementType();
ElementCount PartNumElts = PartVT.getVectorElementCount();
ElementCount ValueNumElts = ValueVT.getVectorElementCount();

// We only support widening vectors with equivalent element types and
// fixed/scalable properties. If a target needs to widen a fixed-length type
// to a scalable one, it should be possible to use INSERT_SUBVECTOR below.
if (ElementCount::isKnownLE(PartNumElts, ValueNumElts) ||
PartNumElts.isScalable() != ValueNumElts.isScalable() ||
PartVT.getVectorElementType() != ValueVT.getVectorElementType())
PartNumElts.isScalable() != ValueNumElts.isScalable())
return SDValue();

// Have a try for bf16 because some targets share its ABI with fp16.
if (ValueEVT == MVT::bf16 && PartEVT == MVT::f16) {
assert(DAG.getTargetLoweringInfo().isTypeLegal(PartVT) &&
"Cannot widen to illegal type");
Val = DAG.getNode(ISD::BITCAST, DL,
ValueVT.changeVectorElementType(MVT::f16), Val);
} else if (PartEVT != ValueEVT) {
return SDValue();
}

// Widening a scalable vector to another scalable vector is done by inserting
// the vector into a larger undef one.
if (PartNumElts.isScalable())
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PartVT, DAG.getUNDEF(PartVT),
Val, DAG.getVectorIdxConstant(0, DL));

EVT ElementVT = PartVT.getVectorElementType();
// Vector widening case, e.g. <2 x float> -> <4 x float>. Shuffle in
// undef elements.
SmallVector<SDValue, 16> Ops;
DAG.ExtractVectorElements(Val, Ops);
SDValue EltUndef = DAG.getUNDEF(ElementVT);
SDValue EltUndef = DAG.getUNDEF(PartEVT);
Ops.append((PartNumElts - ValueNumElts).getFixedValue(), EltUndef);

// FIXME: Use CONCAT for 2x -> 4x.
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Expand Up @@ -2608,7 +2608,7 @@ MVT X86TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,

if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
return getRegisterTypeForCallingConv(Context, CC,
VT.changeVectorElementTypeToInteger());
VT.changeVectorElementType(MVT::f16));

return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
}
Expand Down Expand Up @@ -2643,7 +2643,7 @@ unsigned X86TargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,

if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
return getNumRegistersForCallingConv(Context, CC,
VT.changeVectorElementTypeToInteger());
VT.changeVectorElementType(MVT::f16));

return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
}
Expand Down
28 changes: 25 additions & 3 deletions llvm/test/CodeGen/X86/bfloat.ll
Expand Up @@ -317,13 +317,13 @@ define <8 x bfloat> @addv(<8 x bfloat> %a, <8 x bfloat> %b) nounwind {
; SSE2-NEXT: movq %rdx, %rax
; SSE2-NEXT: shrq $48, %rax
; SSE2-NEXT: movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
; SSE2-NEXT: movq %xmm0, %r12
; SSE2-NEXT: movq %r12, %rax
; SSE2-NEXT: shrq $32, %rax
; SSE2-NEXT: movq %rax, (%rsp) # 8-byte Spill
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
; SSE2-NEXT: movq %xmm0, %r14
; SSE2-NEXT: punpckhqdq {{.*#+}} xmm1 = xmm1[1,1]
; SSE2-NEXT: movq %xmm1, %r14
; SSE2-NEXT: movq %r14, %rbp
; SSE2-NEXT: shrq $32, %rbp
; SSE2-NEXT: movq %r12, %r15
Expand Down Expand Up @@ -543,3 +543,25 @@ define <8 x bfloat> @addv(<8 x bfloat> %a, <8 x bfloat> %b) nounwind {
%add = fadd <8 x bfloat> %a, %b
ret <8 x bfloat> %add
}

define <2 x bfloat> @pr62997(bfloat %a, bfloat %b) {
; SSE2-LABEL: pr62997:
; SSE2: # %bb.0:
; SSE2-NEXT: movd %xmm0, %eax
; SSE2-NEXT: movd %xmm1, %ecx
; SSE2-NEXT: pinsrw $0, %ecx, %xmm1
; SSE2-NEXT: pinsrw $0, %eax, %xmm0
; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
; SSE2-NEXT: retq
;
; BF16-LABEL: pr62997:
; BF16: # %bb.0:
; BF16-NEXT: vmovd %xmm1, %eax
; BF16-NEXT: vmovd %xmm0, %ecx
; BF16-NEXT: vmovd %ecx, %xmm0
; BF16-NEXT: vpinsrw $1, %eax, %xmm0, %xmm0
; BF16-NEXT: retq
%1 = insertelement <2 x bfloat> undef, bfloat %a, i64 0
%2 = insertelement <2 x bfloat> %1, bfloat %b, i64 1
ret <2 x bfloat> %2
}

0 comments on commit 7634905

Please sign in to comment.