[AArch64] NFC: Move safe predicate casting to a separate function.
This patch puts the code that safely bitcasts a predicate, possibly zeroing
any undefined lanes when doing a widening cast, into one place and merges
that functionality with lowerConvertToSVBool.

This is some cleanup inspired by D128665.

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D128926
sdesmalen-arm committed Jul 4, 2022
1 parent f90f0e8 commit bf89d24
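For context, here is a minimal sketch (not part of the patch) of the kind of predicate widening this lowering handles, written with the ACLE intrinsics from arm_sve.h; the function name widen_cmp is purely illustrative. svcmpeq_s32 produces a predicate with one lane per 32-bit element (nxv4i1 in the IR), and returning it as the generic svbool_t (nxv16i1) requires a widening cast whose newly introduced lanes must read as zero. A compare is one of the cases the lowering treats as zeroing those lanes by construction, so no extra masking is expected here.

// Illustrative sketch only; assumes a compiler with SVE enabled,
// e.g. -march=armv8-a+sve.
#include <arm_sve.h>

// The compare defines a predicate over the 32-bit lanes; handing it back as
// the generic svbool_t is the widening predicate-to-predicate cast that
// getSVEPredicateBitCast now implements, with the extra lanes defined as zero.
svbool_t widen_cmp(svbool_t pg, svint32_t a, svint32_t b) {
  return svcmpeq_s32(pg, a, b);
}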
Showing 2 changed files with 57 additions and 36 deletions.
88 changes: 52 additions & 36 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1082,6 +1082,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
   }

+  // FIXME: Move lowering for more nodes here if those are common between
+  // SVE and SME.
+  if (Subtarget->hasSVE() || Subtarget->hasSME()) {
+    for (auto VT :
+         {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1})
+      setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+  }
+
   if (Subtarget->hasSVE()) {
     for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) {
       setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -1162,7 +1170,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::SELECT, VT, Custom);
       setOperationAction(ISD::SETCC, VT, Custom);
-      setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
       setOperationAction(ISD::TRUNCATE, VT, Custom);
       setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
       setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
@@ -4333,27 +4340,47 @@ static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
                      DAG.getTargetConstant(Pattern, DL, MVT::i32));
 }

-static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
+SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
+                                                      SelectionDAG &DAG) const {
   SDLoc DL(Op);
-  EVT OutVT = Op.getValueType();
-  SDValue InOp = Op.getOperand(1);
-  EVT InVT = InOp.getValueType();
+  EVT InVT = Op.getValueType();
+
+  assert(InVT.getVectorElementType() == MVT::i1 &&
+         VT.getVectorElementType() == MVT::i1 &&
+         "Expected a predicate-to-predicate bitcast");
+  assert(VT.isScalableVector() && isTypeLegal(VT) &&
+         InVT.isScalableVector() && isTypeLegal(InVT) &&
+         "Only expect to cast between legal scalable predicate types!");

   // Return the operand if the cast isn't changing type,
-  // i.e. <n x 16 x i1> -> <n x 16 x i1>
-  if (InVT == OutVT)
-    return InOp;
+  // e.g. <n x 16 x i1> -> <n x 16 x i1>
+  if (InVT == VT)
+    return Op;
+
+  SDValue Reinterpret = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);

-  SDValue Reinterpret =
-      DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, InOp);
+  // We only have to zero the lanes if new lanes are being defined, e.g. when
+  // casting from <vscale x 2 x i1> to <vscale x 16 x i1>. If this is not the
+  // case (e.g. when casting from <vscale x 16 x i1> -> <vscale x 2 x i1>) then
+  // we can return here.
+  if (InVT.bitsGT(VT))
+    return Reinterpret;

-  // If the argument converted to an svbool is a ptrue or a comparison, the
-  // lanes introduced by the widening are zero by construction.
-  switch (InOp.getOpcode()) {
+  // Check if the other lanes are already known to be zeroed by
+  // construction.
+  switch (Op.getOpcode()) {
+  default:
+    // We guarantee i1 splat_vectors to zero the other lanes by
+    // implementing it with ptrue and possibly a punpklo for nxv1i1.
+    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
+      return Reinterpret;
+    break;
   case AArch64ISD::SETCC_MERGE_ZERO:
     return Reinterpret;
   case ISD::INTRINSIC_WO_CHAIN:
-    switch (InOp.getConstantOperandVal(0)) {
+    switch (Op.getConstantOperandVal(0)) {
+    default:
+      break;
     case Intrinsic::aarch64_sve_ptrue:
     case Intrinsic::aarch64_sve_cmpeq_wide:
     case Intrinsic::aarch64_sve_cmpne_wide:
@@ -4369,15 +4396,10 @@ static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
     }
   }

-  // Splat vectors of one will generate ptrue instructions
-  if (ISD::isConstantSplatVectorAllOnes(InOp.getNode()))
-    return Reinterpret;
-
-  // Otherwise, zero the newly introduced lanes.
-  SDValue Mask = getPTrue(DAG, DL, InVT, AArch64SVEPredPattern::all);
-  SDValue MaskReinterpret =
-      DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, Mask);
-  return DAG.getNode(ISD::AND, DL, OutVT, Reinterpret, MaskReinterpret);
+  // Zero the newly introduced lanes.
+  SDValue Mask = DAG.getConstant(1, DL, InVT);
+  Mask = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Mask);
+  return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
 }

 SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
@@ -4546,10 +4568,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   case Intrinsic::aarch64_sve_dupq_lane:
     return LowerDUPQLane(Op, DAG);
   case Intrinsic::aarch64_sve_convert_from_svbool:
-    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(),
-                       Op.getOperand(1));
+    return getSVEPredicateBitCast(Op.getValueType(), Op.getOperand(1), DAG);
   case Intrinsic::aarch64_sve_convert_to_svbool:
-    return lowerConvertToSVBool(Op, DAG);
+    return getSVEPredicateBitCast(MVT::nxv16i1, Op.getOperand(1), DAG);
   case Intrinsic::aarch64_sve_fneg:
     return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
@@ -21464,22 +21485,17 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
                                                  SelectionDAG &DAG) const {
   SDLoc DL(Op);
   EVT InVT = Op.getValueType();
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  (void)TLI;

-  assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
-         InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
+  assert(VT.isScalableVector() && isTypeLegal(VT) &&
+         InVT.isScalableVector() && isTypeLegal(InVT) &&
          "Only expect to cast between legal scalable vector types!");
-  assert((VT.getVectorElementType() == MVT::i1) ==
-         (InVT.getVectorElementType() == MVT::i1) &&
-         "Cannot cast between data and predicate scalable vector types!");
+  assert(VT.getVectorElementType() != MVT::i1 &&
+         InVT.getVectorElementType() != MVT::i1 &&
+         "For predicate bitcasts, use getSVEPredicateBitCast");

   if (InVT == VT)
     return Op;

-  if (VT.getVectorElementType() == MVT::i1)
-    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
-
   EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
   EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());

5 changes: 5 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1148,8 +1148,13 @@ class AArch64TargetLowering : public TargetLowering {
   // These can make "bitcasting" a multiphase process. REINTERPRET_CAST is used
   // to transition between unpacked and packed types of the same element type,
   // with BITCAST used otherwise.
+  // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;

+  // Returns a safe bitcast between two scalable vector predicates, where
+  // any newly created lanes from a widening bitcast are defined as zero.
+  SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
+
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
 };