Skip to content

Commit

Permalink
[X86][BF16] Fix 2 crashes with vector broadcast
Browse files Browse the repository at this point in the history
Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D151808

(cherry picked from commit 801dd88)
  • Loading branch information
phoebewang authored and tstellar committed Jun 1, 2023
1 parent 4fd1b86 commit 726af32
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 5 deletions.
12 changes: 7 additions & 5 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2195,6 +2195,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::FMUL, VT, Expand);
setOperationAction(ISD::FDIV, VT, Expand);
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
}
addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
}
Expand All @@ -2207,6 +2208,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::FMUL, MVT::v32bf16, Expand);
setOperationAction(ISD::FDIV, MVT::v32bf16, Expand);
setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
}

if (!Subtarget.useSoftFloat() && Subtarget.hasVLX()) {
Expand Down Expand Up @@ -18773,11 +18775,11 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
return DAG.getBitcast(VT, DAG.getVectorShuffle(FpVT, DL, V1, V2, Mask));
}

if (VT == MVT::v16f16) {
V1 = DAG.getBitcast(MVT::v16i16, V1);
V2 = DAG.getBitcast(MVT::v16i16, V2);
return DAG.getBitcast(MVT::v16f16,
DAG.getVectorShuffle(MVT::v16i16, DL, V1, V2, Mask));
if (VT == MVT::v16f16 || VT.getVectorElementType() == MVT::bf16) {
MVT IVT = VT.changeVectorElementTypeToInteger();
V1 = DAG.getBitcast(IVT, V1);
V2 = DAG.getBitcast(IVT, V2);
return DAG.getBitcast(VT, DAG.getVectorShuffle(IVT, DL, V1, V2, Mask));
}

switch (VT.SimpleTy) {
Expand Down
21 changes: 21 additions & 0 deletions llvm/lib/Target/X86/X86InstrAVX512.td
Original file line number Diff line number Diff line change
Expand Up @@ -12969,6 +12969,27 @@ let Predicates = [HasBF16, HasVLX] in {
(VCVTNEPS2BF16Z256rr VR256X:$src)>;
def : Pat<(v8bf16 (int_x86_vcvtneps2bf16256 (loadv8f32 addr:$src))),
(VCVTNEPS2BF16Z256rm addr:$src)>;

def : Pat<(v8bf16 (X86VBroadcastld16 addr:$src)),
(VPBROADCASTWZ128rm addr:$src)>;
def : Pat<(v16bf16 (X86VBroadcastld16 addr:$src)),
(VPBROADCASTWZ256rm addr:$src)>;

def : Pat<(v8bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
(VPBROADCASTWZ128rr VR128X:$src)>;
def : Pat<(v16bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
(VPBROADCASTWZ256rr VR128X:$src)>;

// TODO: No scalar broadcast pattern yet, because scalar bf16 is not a legal type so far.
}

// v32bf16 broadcast patterns, guarded by HasBF16 only. Both patterns select
// the existing VPBROADCASTWZ* instructions instead of a dedicated bf16 one.
let Predicates = [HasBF16] in {
// Broadcast of a 16-bit element loaded from memory.
def : Pat<(v32bf16 (X86VBroadcastld16 addr:$src)),
(VPBROADCASTWZrm addr:$src)>;

// Broadcast from a v8bf16 source held in a VR128X register.
def : Pat<(v32bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
(VPBROADCASTWZrr VR128X:$src)>;
// TODO: No scalar broadcast pattern yet, because scalar bf16 is not a legal type so far.
}

let Constraints = "$src1 = $dst" in {
Expand Down
46 changes: 46 additions & 0 deletions llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,49 @@ entry:
%2 = select <4 x i1> %1, <4 x float> %0, <4 x float> %E
ret <4 x float> %2
}

;; Splat element 0 of a cvtneps2bf16 result after bitcasting it to <8 x i16>.
;; The all-zero shuffle mask widens the splat to <16 x i16>; the expected
;; codegen is a single register-to-register vpbroadcastw.
define <16 x i16> @test_no_vbroadcast1() {
; CHECK-LABEL: test_no_vbroadcast1:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
; CHECK-NEXT: vpbroadcastw %xmm0, %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0xc0]
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
entry:
%0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> poison, <8 x bfloat> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
%1 = bitcast <8 x bfloat> %0 to <8 x i16>
%2 = shufflevector <8 x i16> %1, <8 x i16> undef, <16 x i32> zeroinitializer
ret <16 x i16> %2
}

;; Same zero-index splat as test_no_vbroadcast1, but the shuffle operates
;; directly on <8 x bfloat> and the result is returned as <16 x bfloat>.
;; The checked output currently stores the converted vector to an aligned
;; stack slot and broadcasts from memory instead of from the register.
;; FIXME: This should generate the same output as above, but let's fix the crash first.
define <16 x bfloat> @test_no_vbroadcast2() nounwind {
; X86-LABEL: test_no_vbroadcast2:
; X86: # %bb.0: # %entry
; X86-NEXT: pushl %ebp # encoding: [0x55]
; X86-NEXT: movl %esp, %ebp # encoding: [0x89,0xe5]
; X86-NEXT: andl $-32, %esp # encoding: [0x83,0xe4,0xe0]
; X86-NEXT: subl $64, %esp # encoding: [0x83,0xec,0x40]
; X86-NEXT: vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
; X86-NEXT: vmovaps %xmm0, (%esp) # EVEX TO VEX Compression encoding: [0xc5,0xf8,0x29,0x04,0x24]
; X86-NEXT: vpbroadcastw (%esp), %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0x04,0x24]
; X86-NEXT: movl %ebp, %esp # encoding: [0x89,0xec]
; X86-NEXT: popl %ebp # encoding: [0x5d]
; X86-NEXT: retl # encoding: [0xc3]
;
; X64-LABEL: test_no_vbroadcast2:
; X64: # %bb.0: # %entry
; X64-NEXT: pushq %rbp # encoding: [0x55]
; X64-NEXT: movq %rsp, %rbp # encoding: [0x48,0x89,0xe5]
; X64-NEXT: andq $-32, %rsp # encoding: [0x48,0x83,0xe4,0xe0]
; X64-NEXT: subq $64, %rsp # encoding: [0x48,0x83,0xec,0x40]
; X64-NEXT: vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
; X64-NEXT: vmovaps %xmm0, (%rsp) # EVEX TO VEX Compression encoding: [0xc5,0xf8,0x29,0x04,0x24]
; X64-NEXT: vpbroadcastw (%rsp), %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0x04,0x24]
; X64-NEXT: movq %rbp, %rsp # encoding: [0x48,0x89,0xec]
; X64-NEXT: popq %rbp # encoding: [0x5d]
; X64-NEXT: retq # encoding: [0xc3]
entry:
%0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> poison, <8 x bfloat> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
%1 = shufflevector <8 x bfloat> %0, <8 x bfloat> undef, <16 x i32> zeroinitializer
ret <16 x bfloat> %1
}

0 comments on commit 726af32

Please sign in to comment.