Skip to content

Commit

Permalink
[X86][BF16] Lower FP_ROUND for vector types under AVX512BF16
Browse files Browse the repository at this point in the history
Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D158952
  • Loading branch information
phoebewang committed Aug 29, 2023
1 parent 23fef2c commit b667e9c
Show file tree
Hide file tree
Showing 5 changed files with 660 additions and 252 deletions.
14 changes: 12 additions & 2 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2237,8 +2237,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,

if (!Subtarget.useSoftFloat() &&
(Subtarget.hasAVXNECONVERT() || Subtarget.hasBF16())) {
addRegisterClass(MVT::v8bf16, &X86::VR128XRegClass);
addRegisterClass(MVT::v16bf16, &X86::VR256XRegClass);
addRegisterClass(MVT::v8bf16, Subtarget.hasAVX512() ? &X86::VR128XRegClass
: &X86::VR128RegClass);
addRegisterClass(MVT::v16bf16, Subtarget.hasAVX512() ? &X86::VR256XRegClass
: &X86::VR256RegClass);
// We set the type action of bf16 to TypeSoftPromoteHalf, but we don't
// provide the method to promote BUILD_VECTOR and INSERT_VECTOR_ELT.
// Set the operation action Custom to do the customization later.
Expand All @@ -2253,6 +2255,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
}
setOperationAction(ISD::FP_ROUND, MVT::v8bf16, Custom);
addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
}

Expand All @@ -2264,6 +2267,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::FMUL, MVT::v32bf16, Expand);
setOperationAction(ISD::FDIV, MVT::v32bf16, Expand);
setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
setOperationAction(ISD::FP_ROUND, MVT::v16bf16, Custom);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
}

Expand Down Expand Up @@ -21278,6 +21282,12 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
return Res;
}

// bf16 results: hand the node back unchanged only when the source element
// type is f32 and the bf16 result type is legal; every other bf16 case
// returns an empty SDValue.
// NOTE(review): presumably the unchanged node is picked up by the
// VCVTNEPS2BF16 selection patterns and the empty SDValue defers to the
// generic legalizer -- confirm against the rest of this function.
if (VT.getScalarType() == MVT::bf16) {
if (SVT.getScalarType() == MVT::f32 && isTypeLegal(VT))
return Op;
return SDValue();
}

if (VT.getScalarType() == MVT::f16 && !Subtarget.hasFP16()) {
if (!Subtarget.hasF16C() || SVT.getScalarType() != MVT::f32)
return SDValue();
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/X86/X86InstrAVX512.td
Original file line number Diff line number Diff line change
Expand Up @@ -12976,6 +12976,11 @@ let Predicates = [HasBF16, HasVLX] in {
def : Pat<(v16bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
(VPBROADCASTWZ256rr VR128X:$src)>;

// Select a v8f32 -> v8bf16 X86vfpround as VCVTNEPS2BF16Z256: the rr form
// for a VR256X register source, the rm form when the source is a loadv8f32.
def : Pat<(v8bf16 (X86vfpround (v8f32 VR256X:$src))),
(VCVTNEPS2BF16Z256rr VR256X:$src)>;
def : Pat<(v8bf16 (X86vfpround (loadv8f32 addr:$src))),
(VCVTNEPS2BF16Z256rm addr:$src)>;

// TODO: No scalar broadcast because legal scalar bf16 is not supported yet.
}

Expand All @@ -12985,6 +12990,11 @@ let Predicates = [HasBF16] in {

def : Pat<(v32bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
(VPBROADCASTWZrr VR128X:$src)>;

// Select a v16f32 -> v16bf16 X86vfpround as VCVTNEPS2BF16Z: the rr form
// for a VR512 register source, the rm form when the source is a loadv16f32.
def : Pat<(v16bf16 (X86vfpround (v16f32 VR512:$src))),
(VCVTNEPS2BF16Zrr VR512:$src)>;
def : Pat<(v16bf16 (X86vfpround (loadv16f32 addr:$src))),
(VCVTNEPS2BF16Zrm addr:$src)>;
// TODO: No scalar broadcast because legal scalar bf16 is not supported yet.
}

Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/X86/X86InstrSSE.td
Original file line number Diff line number Diff line change
Expand Up @@ -8289,6 +8289,11 @@ let Predicates = [HasAVXNECONVERT] in {
f256mem>, T8PS;
let checkVEXPredicate = 1 in
defm VCVTNEPS2BF16 : VCVTNEPS2BF16_BASE, VEX, T8XS, ExplicitVEXPrefix;

// Select a v8f32 -> v8bf16 X86vfpround as VCVTNEPS2BF16Y: the rr form
// for a VR256 register source, the rm form when the source is a loadv8f32.
def : Pat<(v8bf16 (X86vfpround (v8f32 VR256:$src))),
(VCVTNEPS2BF16Yrr VR256:$src)>;
def : Pat<(v8bf16 (X86vfpround (loadv8f32 addr:$src))),
(VCVTNEPS2BF16Yrm addr:$src)>;
}

def : InstAlias<"vcvtneps2bf16x\t{$src, $dst|$dst, $src}",
Expand Down
2 changes: 0 additions & 2 deletions llvm/test/CodeGen/X86/avxneconvert-intrinsics.ll
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ define <8 x bfloat> @test_int_x86_vcvtneps2bf16128(<4 x float> %A) {
; CHECK-LABEL: test_int_x86_vcvtneps2bf16128:
; CHECK: # %bb.0:
; CHECK-NEXT: {vex} vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0xc4,0xe2,0x7a,0x72,0xc0]
; CHECK-NEXT: # kill: def $xmm1 killed $xmm0
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
%ret = call <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float> %A)
ret <8 x bfloat> %ret
Expand All @@ -209,7 +208,6 @@ define <8 x bfloat> @test_int_x86_vcvtneps2bf16256(<8 x float> %A) {
; CHECK-LABEL: test_int_x86_vcvtneps2bf16256:
; CHECK: # %bb.0:
; CHECK-NEXT: {vex} vcvtneps2bf16 %ymm0, %xmm0 # encoding: [0xc4,0xe2,0x7e,0x72,0xc0]
; CHECK-NEXT: # kill: def $xmm1 killed $xmm0
; CHECK-NEXT: vzeroupper # encoding: [0xc5,0xf8,0x77]
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
%ret = call <8 x bfloat> @llvm.x86.vcvtneps2bf16256(<8 x float> %A)
Expand Down
Loading

0 comments on commit b667e9c

Please sign in to comment.