Skip to content

Commit

Permalink
[AArch64] Add some missing strict FP vector lowering
Browse files Browse the repository at this point in the history
Also add a test for the codegen of strict FP vector operations so
these changes get tested.

Differential Revision: https://reviews.llvm.org/D117795
  • Loading branch information
john-brawn-arm committed Feb 17, 2022
1 parent 9071393 commit 8e17c96
Show file tree
Hide file tree
Showing 2 changed files with 947 additions and 2 deletions.
63 changes: 61 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -3404,7 +3404,8 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
// Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
// Any additional optimization in this function should be recorded
// in the cost tables.
EVT InVT = Op.getOperand(0).getValueType();
bool IsStrict = Op->isStrictFPOpcode();
EVT InVT = Op.getOperand(IsStrict ? 1 : 0).getValueType();
EVT VT = Op.getValueType();

if (VT.isScalableVector()) {
Expand All @@ -3424,6 +3425,12 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
!Subtarget->hasFullFP16()) {
MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
SDLoc dl(Op);
if (IsStrict) {
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NewVT, MVT::Other},
{Op.getOperand(0), Op.getOperand(1)});
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
{Ext.getValue(1), Ext.getValue(0)});
}
return DAG.getNode(
Op.getOpcode(), dl, Op.getValueType(),
DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
Expand All @@ -3433,6 +3440,13 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
uint64_t InVTSize = InVT.getFixedSizeInBits();
if (VTSize < InVTSize) {
SDLoc dl(Op);
if (IsStrict) {
InVT = InVT.changeVectorElementTypeToInteger();
SDValue Cv = DAG.getNode(Op.getOpcode(), dl, {InVT, MVT::Other},
{Op.getOperand(0), Op.getOperand(1)});
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, dl, VT, Cv);
return DAG.getMergeValues({Trunc, Cv.getValue(1)}, dl);
}
SDValue Cv =
DAG.getNode(Op.getOpcode(), dl, InVT.changeVectorElementTypeToInteger(),
Op.getOperand(0));
Expand All @@ -3444,10 +3458,32 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
MVT ExtVT =
MVT::getVectorVT(MVT::getFloatingPointVT(VT.getScalarSizeInBits()),
VT.getVectorNumElements());
if (IsStrict) {
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl,
{ExtVT, MVT::Other},
{Op.getOperand(0), Op.getOperand(1)});
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
{Ext.getValue(1), Ext.getValue(0)});
}
SDValue Ext = DAG.getNode(ISD::FP_EXTEND, dl, ExtVT, Op.getOperand(0));
return DAG.getNode(Op.getOpcode(), dl, VT, Ext);
}

// Use a scalar operation for conversions between single-element vectors of
// the same size.
if (NumElts == 1) {
SDLoc dl(Op);
SDValue Extract = DAG.getNode(
ISD::EXTRACT_VECTOR_ELT, dl, InVT.getScalarType(),
Op.getOperand(IsStrict ? 1 : 0), DAG.getConstant(0, dl, MVT::i64));
EVT ScalarVT = VT.getScalarType();
SDValue ScalarCvt;
if (IsStrict)
return DAG.getNode(Op.getOpcode(), dl, {ScalarVT, MVT::Other},
{Op.getOperand(0), Extract});
return DAG.getNode(Op.getOpcode(), dl, ScalarVT, Extract);
}

// Type changing conversions are illegal.
return Op;
}
Expand Down Expand Up @@ -3610,9 +3646,10 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
// Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
// Any additional optimization in this function should be recorded
// in the cost tables.
bool IsStrict = Op->isStrictFPOpcode();
EVT VT = Op.getValueType();
SDLoc dl(Op);
SDValue In = Op.getOperand(0);
SDValue In = Op.getOperand(IsStrict ? 1 : 0);
EVT InVT = In.getValueType();
unsigned Opc = Op.getOpcode();
bool IsSigned = Opc == ISD::SINT_TO_FP || Opc == ISD::STRICT_SINT_TO_FP;
Expand Down Expand Up @@ -3640,6 +3677,13 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
MVT CastVT =
MVT::getVectorVT(MVT::getFloatingPointVT(InVT.getScalarSizeInBits()),
InVT.getVectorNumElements());
if (IsStrict) {
In = DAG.getNode(Opc, dl, {CastVT, MVT::Other},
{Op.getOperand(0), In});
return DAG.getNode(
ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
{In.getValue(1), In.getValue(0), DAG.getIntPtrConstant(0, dl)});
}
In = DAG.getNode(Opc, dl, CastVT, In);
return DAG.getNode(ISD::FP_ROUND, dl, VT, In, DAG.getIntPtrConstant(0, dl));
}
Expand All @@ -3648,9 +3692,24 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
EVT CastVT = VT.changeVectorElementTypeToInteger();
In = DAG.getNode(CastOpc, dl, CastVT, In);
if (IsStrict)
return DAG.getNode(Opc, dl, {VT, MVT::Other}, {Op.getOperand(0), In});
return DAG.getNode(Opc, dl, VT, In);
}

// Use a scalar operation for conversions between single-element vectors of
// the same size.
if (VT.getVectorNumElements() == 1) {
SDValue Extract = DAG.getNode(
ISD::EXTRACT_VECTOR_ELT, dl, InVT.getScalarType(),
In, DAG.getConstant(0, dl, MVT::i64));
EVT ScalarVT = VT.getScalarType();
if (IsStrict)
return DAG.getNode(Op.getOpcode(), dl, {ScalarVT, MVT::Other},
{Op.getOperand(0), Extract});
return DAG.getNode(Op.getOpcode(), dl, ScalarVT, Extract);
}

return Op;
}

Expand Down

0 comments on commit 8e17c96

Please sign in to comment.