Skip to content

Commit

Permalink
[X86][BF16] Customize INSERT_VECTOR_ELT for bf16 when feature BF16 is on
Browse files Browse the repository at this point in the history
Fixes the root cause of #63017.
The reason is similar to the one for BUILD_VECTOR: the vector type is legal,
but the scalar type is still soft-promoted, so we need to custom-lower these
scalar-to-vector nodes.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D155961
  • Loading branch information
phoebewang committed Jul 22, 2023
1 parent 9b2dfff commit 04527f1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
13 changes: 11 additions & 2 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2276,9 +2276,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
addRegisterClass(MVT::v8bf16, &X86::VR128XRegClass);
addRegisterClass(MVT::v16bf16, &X86::VR256XRegClass);
// We set the type action of bf16 to TypeSoftPromoteHalf, but we don't
// provide the method to promote BUILD_VECTOR. Set the operation action
// Custom to do the customization later.
// provide the method to promote BUILD_VECTOR and INSERT_VECTOR_ELT.
// Set the operation action Custom to do the customization later.
setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::bf16, Custom);
for (auto VT : {MVT::v8bf16, MVT::v16bf16}) {
setF16Action(VT, Expand);
setOperationAction(ISD::FADD, VT, Expand);
Expand Down Expand Up @@ -20751,6 +20752,14 @@ SDValue X86TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
SDValue N2 = Op.getOperand(2);
auto *N2C = dyn_cast<ConstantSDNode>(N2);

if (EltVT == MVT::bf16) {
MVT IVT = VT.changeVectorElementTypeToInteger();
SDValue Res = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, IVT,
DAG.getBitcast(IVT, N0),
DAG.getBitcast(MVT::i16, N1), N2);
return DAG.getBitcast(VT, Res);
}

if (!N2C) {
// Variable insertion indices, usually we're better off spilling to stack,
// but AVX512 can use a variable compare+select by comparing against all
Expand Down
25 changes: 25 additions & 0 deletions llvm/test/CodeGen/X86/bfloat.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1158,4 +1158,29 @@ define <32 x bfloat> @pr63017_2() nounwind {
ret <32 x bfloat> %1
}

; Codegen test: insert the scalar bfloat %1 into element 1 of the
; <32 x bfloat> vector %0 and return the updated vector. The check lines
; below give the expected assembly under the SSE2 and BF16 check prefixes.
define <32 x bfloat> @pr62997_3(<32 x bfloat> %0, bfloat %1) {
; SSE2-LABEL: pr62997_3:
; SSE2: # %bb.0:
; SSE2-NEXT: movq %xmm0, %rax
; SSE2-NEXT: movabsq $-4294967296, %rcx # imm = 0xFFFFFFFF00000000
; SSE2-NEXT: andq %rax, %rcx
; SSE2-NEXT: movzwl %ax, %eax
; SSE2-NEXT: movd %xmm4, %edx
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: orl %eax, %edx
; SSE2-NEXT: orq %rcx, %rdx
; SSE2-NEXT: movq %rdx, %xmm4
; SSE2-NEXT: movsd {{.*#+}} xmm0 = xmm4[0],xmm0[1]
; SSE2-NEXT: retq
;
; BF16-LABEL: pr62997_3:
; BF16: # %bb.0:
; BF16-NEXT: vmovd %xmm1, %eax
; BF16-NEXT: vpinsrw $1, %eax, %xmm0, %xmm1
; BF16-NEXT: vinserti32x4 $0, %xmm1, %zmm0, %zmm0
; BF16-NEXT: retq
%3 = insertelement <32 x bfloat> %0, bfloat %1, i64 1
ret <32 x bfloat> %3
}

declare <32 x bfloat> @llvm.masked.load.v32bf16.p0(ptr, i32, <32 x i1>, <32 x bfloat>)

0 comments on commit 04527f1

Please sign in to comment.