Skip to content

Commit

Permalink
[X86] Recognize CVTPH2PS from STRICT_FP_EXTEND
Browse files Browse the repository at this point in the history
This should avoid scalarizing the cvtph2ps intrinsics when combined with D75162.

Differential Revision: https://reviews.llvm.org/D75304
  • Loading branch information
topperc committed Feb 28, 2020
1 parent a57f1a5 commit c0d0e6b
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 4 deletions.
53 changes: 49 additions & 4 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Expand Up @@ -2062,6 +2062,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setTargetDAGCombine(ISD::MGATHER);
setTargetDAGCombine(ISD::FP16_TO_FP);
setTargetDAGCombine(ISD::FP_EXTEND);
setTargetDAGCombine(ISD::STRICT_FP_EXTEND);
setTargetDAGCombine(ISD::FP_ROUND);

computeRegisterProperties(Subtarget.getRegisterInfo());
Expand Down Expand Up @@ -28985,6 +28986,23 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(Res);
return;
}
case X86ISD::STRICT_CVTPH2PS: {
EVT VT = N->getValueType(0);
SDValue Lo, Hi;
std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 1);
EVT LoVT, HiVT;
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
Lo = DAG.getNode(X86ISD::STRICT_CVTPH2PS, dl, {LoVT, MVT::Other},
{N->getOperand(0), Lo});
Hi = DAG.getNode(X86ISD::STRICT_CVTPH2PS, dl, {HiVT, MVT::Other},
{N->getOperand(0), Hi});
SDValue Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
Lo.getValue(1), Hi.getValue(1));
SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi);
Results.push_back(Res);
Results.push_back(Chain);
return;
}
case ISD::CTPOP: {
assert(N->getValueType(0) == MVT::i64 && "Unexpected VT!");
// Use a v2i64 if possible.
Expand Down Expand Up @@ -29555,7 +29573,8 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(V.getValue(1));
return;
}
case ISD::FP_EXTEND: {
case ISD::FP_EXTEND:
case ISD::STRICT_FP_EXTEND: {
// Right now, only MVT::v2f32 has OperationAction for FP_EXTEND.
// No other ValueType for FP_EXTEND should reach this point.
assert(N->getValueType(0) == MVT::v2f32 &&
Expand Down Expand Up @@ -43810,7 +43829,8 @@ static SDValue combineBT(SDNode *N, SelectionDAG &DAG,

static SDValue combineCVTPH2PS(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDValue Src = N->getOperand(0);
bool IsStrict = N->getOpcode() == X86ISD::STRICT_CVTPH2PS;
SDValue Src = N->getOperand(IsStrict ? 1 : 0);

if (N->getValueType(0) == MVT::v4f32 && Src.getValueType() == MVT::v8i16) {
APInt KnownUndef, KnownZero;
Expand All @@ -43822,6 +43842,11 @@ static SDValue combineCVTPH2PS(SDNode *N, SelectionDAG &DAG,
return SDValue(N, 0);
}

// FIXME: Shrink vector loads.
if (IsStrict)
return SDValue();

// Convert a full vector load into vzload when not all bits are needed.
if (ISD::isNormalLoad(Src.getNode()) && Src.hasOneUse()) {
LoadSDNode *LN = cast<LoadSDNode>(N->getOperand(0));
// Unless the load is volatile or atomic.
Expand Down Expand Up @@ -46721,8 +46746,9 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
return SDValue();

bool IsStrict = N->isStrictFPOpcode();
EVT VT = N->getValueType(0);
SDValue Src = N->getOperand(0);
SDValue Src = N->getOperand(IsStrict ? 1 : 0);
EVT SrcVT = Src.getValueType();

if (!SrcVT.isVector() || SrcVT.getVectorElementType() != MVT::f16)
Expand Down Expand Up @@ -46755,14 +46781,31 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
// Destination is vXf32 with at least 4 elements.
EVT CvtVT = EVT::getVectorVT(*DAG.getContext(), MVT::f32,
std::max(4U, NumElts));
SDValue Cvt = DAG.getNode(X86ISD::CVTPH2PS, dl, CvtVT, Src);
SDValue Cvt, Chain;
if (IsStrict) {
Cvt = DAG.getNode(X86ISD::STRICT_CVTPH2PS, dl, {CvtVT, MVT::Other},
{N->getOperand(0), Src});
Chain = Cvt.getValue(1);
} else {
Cvt = DAG.getNode(X86ISD::CVTPH2PS, dl, CvtVT, Src);
}

if (NumElts < 4) {
assert(NumElts == 2 && "Unexpected size");
Cvt = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2f32, Cvt,
DAG.getIntPtrConstant(0, dl));
}

if (IsStrict) {
// Extend to the original VT if necessary.
if (Cvt.getValueType() != VT) {
Cvt = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {VT, MVT::Other},
{Chain, Cvt});
Chain = Cvt.getValue(1);
}
return DAG.getMergeValues({Cvt, Chain}, dl);
}

// Extend to the original VT if necessary.
return DAG.getNode(ISD::FP_EXTEND, dl, VT, Cvt);
}
Expand Down Expand Up @@ -46876,6 +46919,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::CVTP2UI:
case X86ISD::CVTTP2SI:
case X86ISD::CVTTP2UI: return combineCVTP2I_CVTTP2I(N, DAG, DCI);
case X86ISD::STRICT_CVTPH2PS:
case X86ISD::CVTPH2PS: return combineCVTPH2PS(N, DAG, DCI);
case X86ISD::BT: return combineBT(N, DAG, DCI);
case ISD::ANY_EXTEND:
Expand Down Expand Up @@ -46962,6 +47006,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::KSHIFTL:
case X86ISD::KSHIFTR: return combineKSHIFT(N, DAG, DCI);
case ISD::FP16_TO_FP: return combineFP16_TO_FP(N, DAG, Subtarget);
case ISD::STRICT_FP_EXTEND:
case ISD::FP_EXTEND: return combineFP_EXTEND(N, DAG, Subtarget);
case ISD::FP_ROUND: return combineFP_ROUND(N, DAG, Subtarget);
}
Expand Down
135 changes: 135 additions & 0 deletions llvm/test/CodeGen/X86/vector-half-conversions.ll
Expand Up @@ -78,6 +78,65 @@ define <16 x float> @cvt_16i16_to_16f32(<16 x i16> %a0) nounwind {
ret <16 x float> %2
}

; Strict (constrained) fpext <2 x half> -> <2 x float>.  The CHECK lines
; expect the two halves to be widened (vpmovzxdq) and converted with a
; single vcvtph2ps instead of being scalarized.
define <2 x float> @cvt_2i16_to_2f32_constrained(<2 x i16> %a0) nounwind strictfp {
; ALL-LABEL: cvt_2i16_to_2f32_constrained:
; ALL: # %bb.0:
; ALL-NEXT: vpmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
; ALL-NEXT: vcvtph2ps %xmm0, %xmm0
; ALL-NEXT: retq
%1 = bitcast <2 x i16> %a0 to <2 x half>
%2 = call <2 x float> @llvm.experimental.constrained.fpext.v2f32.v2f16(<2 x half> %1, metadata !"fpexcept.strict") strictfp
ret <2 x float> %2
}
declare <2 x float> @llvm.experimental.constrained.fpext.v2f32.v2f16(<2 x half>, metadata) strictfp

; Strict fpext <4 x half> -> <4 x float>: expected to lower to exactly one
; vcvtph2ps on an xmm register (no scalarization, no widening code).
define <4 x float> @cvt_4i16_to_4f32_constrained(<4 x i16> %a0) nounwind strictfp {
; ALL-LABEL: cvt_4i16_to_4f32_constrained:
; ALL: # %bb.0:
; ALL-NEXT: vcvtph2ps %xmm0, %xmm0
; ALL-NEXT: retq
%1 = bitcast <4 x i16> %a0 to <4 x half>
%2 = call <4 x float> @llvm.experimental.constrained.fpext.v4f32.v4f16(<4 x half> %1, metadata !"fpexcept.strict") strictfp
ret <4 x float> %2
}
declare <4 x float> @llvm.experimental.constrained.fpext.v4f32.v4f16(<4 x half>, metadata) strictfp

; Strict fpext <8 x half> -> <8 x float>: expected to lower to a single
; xmm -> ymm vcvtph2ps.
define <8 x float> @cvt_8i16_to_8f32_constrained(<8 x i16> %a0) nounwind strictfp {
; ALL-LABEL: cvt_8i16_to_8f32_constrained:
; ALL: # %bb.0:
; ALL-NEXT: vcvtph2ps %xmm0, %ymm0
; ALL-NEXT: retq
%1 = bitcast <8 x i16> %a0 to <8 x half>
%2 = call <8 x float> @llvm.experimental.constrained.fpext.v8f32.v8f16(<8 x half> %1, metadata !"fpexcept.strict") strictfp
ret <8 x float> %2
}
declare <8 x float> @llvm.experimental.constrained.fpext.v8f32.v8f16(<8 x half>, metadata) strictfp

; Strict fpext <16 x half> -> <16 x float>.  On AVX1/AVX2 the CHECK lines
; expect the operation to be split into two halves, each converted with one
; vcvtph2ps (vector splitting, not scalarization); on AVX512 a single
; ymm -> zmm vcvtph2ps suffices.
define <16 x float> @cvt_16i16_to_16f32_constrained(<16 x i16> %a0) nounwind strictfp {
; AVX1-LABEL: cvt_16i16_to_16f32_constrained:
; AVX1: # %bb.0:
; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm1
; AVX1-NEXT: vcvtph2ps %xmm1, %ymm1
; AVX1-NEXT: vcvtph2ps %xmm0, %ymm0
; AVX1-NEXT: retq
;
; AVX2-LABEL: cvt_16i16_to_16f32_constrained:
; AVX2: # %bb.0:
; AVX2-NEXT: vextractf128 $1, %ymm0, %xmm1
; AVX2-NEXT: vcvtph2ps %xmm1, %ymm1
; AVX2-NEXT: vcvtph2ps %xmm0, %ymm0
; AVX2-NEXT: retq
;
; AVX512-LABEL: cvt_16i16_to_16f32_constrained:
; AVX512: # %bb.0:
; AVX512-NEXT: vcvtph2ps %ymm0, %zmm0
; AVX512-NEXT: retq
%1 = bitcast <16 x i16> %a0 to <16 x half>
%2 = call <16 x float> @llvm.experimental.constrained.fpext.v16f32.v16f16(<16 x half> %1, metadata !"fpexcept.strict") strictfp
ret <16 x float> %2
}
declare <16 x float> @llvm.experimental.constrained.fpext.v16f32.v16f16(<16 x half>, metadata) strictfp

;
; Half to Float (Load)
;
Expand Down Expand Up @@ -152,6 +211,29 @@ define <16 x float> @load_cvt_16i16_to_16f32(<16 x i16>* %a0) nounwind {
ret <16 x float> %3
}

; Strict fpext of a loaded <4 x half>: the load should fold into the
; conversion, giving a single vcvtph2ps with a memory operand.
define <4 x float> @load_cvt_4i16_to_4f32_constrained(<4 x i16>* %a0) nounwind strictfp {
; ALL-LABEL: load_cvt_4i16_to_4f32_constrained:
; ALL: # %bb.0:
; ALL-NEXT: vcvtph2ps (%rdi), %xmm0
; ALL-NEXT: retq
%1 = load <4 x i16>, <4 x i16>* %a0
%2 = bitcast <4 x i16> %1 to <4 x half>
%3 = call <4 x float> @llvm.experimental.constrained.fpext.v4f32.v4f16(<4 x half> %2, metadata !"fpexcept.strict") strictfp
ret <4 x float> %3
}

; Strict fpext of the low 4 halves of a loaded <8 x i16>: the shufflevector
; plus constrained fpext should still fold into a single memory-operand
; vcvtph2ps.
; Fix: the define now carries the strictfp attribute.  The body contains a
; strictfp call to a constrained intrinsic, and LLVM requires the enclosing
; function to be strictfp as well (every sibling test in this file has it).
define <4 x float> @load_cvt_8i16_to_4f32_constrained(<8 x i16>* %a0) nounwind strictfp {
; ALL-LABEL: load_cvt_8i16_to_4f32_constrained:
; ALL: # %bb.0:
; ALL-NEXT: vcvtph2ps (%rdi), %xmm0
; ALL-NEXT: retq
%1 = load <8 x i16>, <8 x i16>* %a0
%2 = shufflevector <8 x i16> %1, <8 x i16> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
%3 = bitcast <4 x i16> %2 to <4 x half>
%4 = call <4 x float> @llvm.experimental.constrained.fpext.v4f32.v4f16(<4 x half> %3, metadata !"fpexcept.strict") strictfp
ret <4 x float> %4
}

;
; Half to Double
;
Expand Down Expand Up @@ -244,6 +326,59 @@ define <8 x double> @cvt_8i16_to_8f64(<8 x i16> %a0) nounwind {
ret <8 x double> %2
}

; Strict fpext <2 x half> -> <2 x double>: expected to go half -> float via
; vcvtph2ps (after widening with vpmovzxdq) and then float -> double via
; vcvtps2pd, with no scalarization.
define <2 x double> @cvt_2i16_to_2f64_constrained(<2 x i16> %a0) nounwind strictfp {
; ALL-LABEL: cvt_2i16_to_2f64_constrained:
; ALL: # %bb.0:
; ALL-NEXT: vpmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
; ALL-NEXT: vcvtph2ps %xmm0, %xmm0
; ALL-NEXT: vcvtps2pd %xmm0, %xmm0
; ALL-NEXT: retq
%1 = bitcast <2 x i16> %a0 to <2 x half>
%2 = call <2 x double> @llvm.experimental.constrained.fpext.v2f64.v2f16(<2 x half> %1, metadata !"fpexcept.strict") strictfp
ret <2 x double> %2
}
declare <2 x double> @llvm.experimental.constrained.fpext.v2f64.v2f16(<2 x half>, metadata) strictfp

; Strict fpext <4 x half> -> <4 x double>: expected to lower to
; vcvtph2ps (xmm) followed by vcvtps2pd (xmm -> ymm).
define <4 x double> @cvt_4i16_to_4f64_constrained(<4 x i16> %a0) nounwind strictfp {
; ALL-LABEL: cvt_4i16_to_4f64_constrained:
; ALL: # %bb.0:
; ALL-NEXT: vcvtph2ps %xmm0, %xmm0
; ALL-NEXT: vcvtps2pd %xmm0, %ymm0
; ALL-NEXT: retq
%1 = bitcast <4 x i16> %a0 to <4 x half>
%2 = call <4 x double> @llvm.experimental.constrained.fpext.v4f64.v4f16(<4 x half> %1, metadata !"fpexcept.strict") strictfp
ret <4 x double> %2
}
declare <4 x double> @llvm.experimental.constrained.fpext.v4f64.v4f16(<4 x half>, metadata) strictfp

; Strict fpext <8 x half> -> <8 x double>.  All targets first do one
; vcvtph2ps to <8 x float>; AVX1/AVX2 then split the float -> double step
; into two vcvtps2pd halves, while AVX512 uses a single ymm -> zmm
; vcvtps2pd.
define <8 x double> @cvt_8i16_to_8f64_constrained(<8 x i16> %a0) nounwind strictfp {
; AVX1-LABEL: cvt_8i16_to_8f64_constrained:
; AVX1: # %bb.0:
; AVX1-NEXT: vcvtph2ps %xmm0, %ymm0
; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm1
; AVX1-NEXT: vcvtps2pd %xmm1, %ymm1
; AVX1-NEXT: vcvtps2pd %xmm0, %ymm0
; AVX1-NEXT: retq
;
; AVX2-LABEL: cvt_8i16_to_8f64_constrained:
; AVX2: # %bb.0:
; AVX2-NEXT: vcvtph2ps %xmm0, %ymm0
; AVX2-NEXT: vextractf128 $1, %ymm0, %xmm1
; AVX2-NEXT: vcvtps2pd %xmm1, %ymm1
; AVX2-NEXT: vcvtps2pd %xmm0, %ymm0
; AVX2-NEXT: retq
;
; AVX512-LABEL: cvt_8i16_to_8f64_constrained:
; AVX512: # %bb.0:
; AVX512-NEXT: vcvtph2ps %xmm0, %ymm0
; AVX512-NEXT: vcvtps2pd %ymm0, %zmm0
; AVX512-NEXT: retq
%1 = bitcast <8 x i16> %a0 to <8 x half>
%2 = call <8 x double> @llvm.experimental.constrained.fpext.v8f64.v8f16(<8 x half> %1, metadata !"fpexcept.strict") strictfp
ret <8 x double> %2
}
declare <8 x double> @llvm.experimental.constrained.fpext.v8f64.v8f16(<8 x half>, metadata) strictfp

;
; Half to Double (Load)
;
Expand Down

0 comments on commit c0d0e6b

Please sign in to comment.