Skip to content

Commit

Permalink
[SVE][CodeGen] Use splice instruction when lowering VECTOR_SPLICE
Browse files Browse the repository at this point in the history
For certain negative indices passed to the VECTOR_SPLICE operation
we can actually directly use the SVE splice instruction by creating
the appropriate predicate. The predicate needs to be constructed in
such a way that all but the last -idx elements are false. We can do
this efficiently using a combination of 'ptrue' (with the appropriate
fixed pattern, e.g. vl1, vl2, etc.) and 'rev'. The advantage of using
these instructions to generate the predicate is that they do not set any
flags, unlike the whilelo instruction. This is critical when the splice
operation is in a loop, since we want MachineLICM to hoist the
predicate generation out of the loop.

Differential Revision: https://reviews.llvm.org/D115863
  • Loading branch information
david-arm committed Jan 11, 2022
1 parent 0b5b35f commit 3a272d1
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 260 deletions.
45 changes: 36 additions & 9 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -7796,10 +7796,37 @@ SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op,
SelectionDAG &DAG) const {
EVT Ty = Op.getValueType();
auto Idx = Op.getConstantOperandAPInt(2);
int64_t IdxVal = Idx.getSExtValue();
assert(Ty.isScalableVector() &&
"Only expect scalable vectors for custom lowering of VECTOR_SPLICE");

// We can use the splice instruction for certain index values where we are
// able to efficiently generate the correct predicate. The index will be
// inverted and used directly as the input to the ptrue instruction, i.e.
// -1 -> vl1, -2 -> vl2, etc. The predicate will then be reversed to get the
// splice predicate. However, we can only do this if we can guarantee that
// there are enough elements in the vector, hence we check the index <= min
// number of elements.
Optional<unsigned> PredPattern;
if (Ty.isScalableVector() && IdxVal < 0 &&
(PredPattern = getSVEPredPatternFromNumElements(std::abs(IdxVal))) !=
None) {
SDLoc DL(Op);

// Create a predicate where all but the last -IdxVal elements are false.
EVT PredVT = Ty.changeVectorElementType(MVT::i1);
SDValue Pred = getPTrue(DAG, DL, PredVT, *PredPattern);
Pred = DAG.getNode(ISD::VECTOR_REVERSE, DL, PredVT, Pred);

// Now splice the two inputs together using the predicate.
return DAG.getNode(AArch64ISD::SPLICE, DL, Ty, Pred, Op.getOperand(0),
Op.getOperand(1));
}

// This will select to an EXT instruction, which has a maximum immediate
// value of 255, hence 2048-bits is the maximum value we can lower.
if (Idx.sge(-1) && Idx.slt(2048 / Ty.getVectorElementType().getSizeInBits()))
if (IdxVal >= 0 &&
IdxVal < int64_t(2048 / Ty.getVectorElementType().getSizeInBits()))
return Op;

return SDValue();
Expand Down Expand Up @@ -11011,10 +11038,10 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
if (Vec0.isUndef())
return Op;

unsigned int PredPattern =
Optional<unsigned> PredPattern =
getSVEPredPatternFromNumElements(InVT.getVectorNumElements());
auto PredTy = VT.changeVectorElementType(MVT::i1);
SDValue PTrue = getPTrue(DAG, DL, PredTy, PredPattern);
SDValue PTrue = getPTrue(DAG, DL, PredTy, *PredPattern);
SDValue ScalableVec1 = convertToScalableVector(DAG, VT, Vec1);
return DAG.getNode(ISD::VSELECT, DL, VT, PTrue, ScalableVec1, Vec0);
}
Expand Down Expand Up @@ -12319,15 +12346,15 @@ bool AArch64TargetLowering::lowerInterleavedLoad(

Value *PTrue = nullptr;
if (UseScalable) {
unsigned PgPattern =
Optional<unsigned> PgPattern =
getSVEPredPatternFromNumElements(FVTy->getNumElements());
if (Subtarget->getMinSVEVectorSizeInBits() ==
Subtarget->getMaxSVEVectorSizeInBits() &&
Subtarget->getMinSVEVectorSizeInBits() == DL.getTypeSizeInBits(FVTy))
PgPattern = AArch64SVEPredPattern::all;

auto *PTruePat =
ConstantInt::get(Type::getInt32Ty(LDVTy->getContext()), PgPattern);
ConstantInt::get(Type::getInt32Ty(LDVTy->getContext()), *PgPattern);
PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy},
{PTruePat});
}
Expand Down Expand Up @@ -12499,7 +12526,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,

Value *PTrue = nullptr;
if (UseScalable) {
unsigned PgPattern =
Optional<unsigned> PgPattern =
getSVEPredPatternFromNumElements(SubVecTy->getNumElements());
if (Subtarget->getMinSVEVectorSizeInBits() ==
Subtarget->getMaxSVEVectorSizeInBits() &&
Expand All @@ -12508,7 +12535,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
PgPattern = AArch64SVEPredPattern::all;

auto *PTruePat =
ConstantInt::get(Type::getInt32Ty(STVTy->getContext()), PgPattern);
ConstantInt::get(Type::getInt32Ty(STVTy->getContext()), *PgPattern);
PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy},
{PTruePat});
}
Expand Down Expand Up @@ -18752,7 +18779,7 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
"Expected legal fixed length vector!");

unsigned PgPattern =
Optional<unsigned> PgPattern =
getSVEPredPatternFromNumElements(VT.getVectorNumElements());
assert(PgPattern && "Unexpected element count for SVE predicate");

Expand Down Expand Up @@ -18788,7 +18815,7 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
break;
}

return getPTrue(DAG, DL, MaskVT, PgPattern);
return getPTrue(DAG, DL, MaskVT, *PgPattern);
}

static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
Expand Up @@ -483,18 +483,20 @@ inline unsigned getNumElementsFromSVEPredPattern(unsigned Pattern) {
}

/// Return specific VL predicate pattern based on the number of elements.
inline unsigned getSVEPredPatternFromNumElements(unsigned MinNumElts) {
inline Optional<unsigned>
getSVEPredPatternFromNumElements(unsigned MinNumElts) {
switch (MinNumElts) {
default:
llvm_unreachable("unexpected element count for SVE predicate");
return None;
case 1:
return AArch64SVEPredPattern::vl1;
case 2:
return AArch64SVEPredPattern::vl2;
case 3:
case 4:
return AArch64SVEPredPattern::vl4;
case 5:
case 6:
case 7:
case 8:
return AArch64SVEPredPattern::vl8;
return MinNumElts;
case 16:
return AArch64SVEPredPattern::vl16;
case 32:
Expand Down

0 comments on commit 3a272d1

Please sign in to comment.