Skip to content

Commit

Permalink
[X86] Add 'getSplitVectorSrc' helper to determine if subvectors all c…
Browse files Browse the repository at this point in the history
…ome from the same source

Helps determine if the subvector ops come from the same larger vector and match the lower/upper extractions
  • Loading branch information
RKSimon committed Jan 26, 2022
1 parent de8867a commit 99ae5c1
Showing 1 changed file with 65 additions and 49 deletions.
114 changes: 65 additions & 49 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Expand Up @@ -6146,6 +6146,29 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget,
return DAG.getBitcast(VT, Vec);
}

// Helper to determine if the ops are all the extracted subvectors come from a
// single source. If we allow commute they don't have to be in order (Lo/Hi).
static SDValue getSplitVectorSrc(SDValue LHS, SDValue RHS, bool AllowCommute) {
if (LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
LHS.getValueType() != RHS.getValueType() ||
LHS.getOperand(0) != RHS.getOperand(0))
return SDValue();

SDValue Src = LHS.getOperand(0);
if (Src.getValueSizeInBits() != (LHS.getValueSizeInBits() * 2))
return SDValue();

unsigned NumElts = LHS.getValueType().getVectorNumElements();
if ((LHS.getConstantOperandAPInt(1) == 0 &&
RHS.getConstantOperandAPInt(1) == NumElts) ||
(AllowCommute && RHS.getConstantOperandAPInt(1) == 0 &&
LHS.getConstantOperandAPInt(1) == NumElts))
return Src;

return SDValue();
}

static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
const SDLoc &dl, unsigned vectorWidth) {
EVT VT = Vec.getValueType();
Expand Down Expand Up @@ -44512,30 +44535,28 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
// PMOVMSKB(PACKSSBW(LO(X), HI(X)))
// -> PMOVMSKB(BITCAST_v32i8(X)) & 0xAAAAAAAA.
if (CmpBits >= 16 && Subtarget.hasInt256() &&
VecOp0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
VecOp1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
VecOp0.getOperand(0) == VecOp1.getOperand(0) &&
VecOp0.getConstantOperandAPInt(1) == 0 &&
VecOp1.getConstantOperandAPInt(1) == 8 &&
(IsAnyOf || (SignExt0 && SignExt1))) {
SDLoc DL(EFLAGS);
SDValue Result = peekThroughBitcasts(VecOp0.getOperand(0));
if (IsAllOf && Result.getOpcode() == X86ISD::PCMPEQ) {
SDValue V = DAG.getNode(ISD::SUB, DL, Result.getValueType(),
Result.getOperand(0), Result.getOperand(1));
V = DAG.getBitcast(MVT::v4i64, V);
return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V);
}
Result = DAG.getBitcast(MVT::v32i8, Result);
Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
unsigned CmpMask = IsAnyOf ? 0 : 0xFFFFFFFF;
if (!SignExt0 || !SignExt1) {
assert(IsAnyOf && "Only perform v16i16 signmasks for any_of patterns");
Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result,
DAG.getConstant(0xAAAAAAAA, DL, MVT::i32));
if (SDValue Src = getSplitVectorSrc(VecOp0, VecOp1, true)) {
SDLoc DL(EFLAGS);
SDValue Result = peekThroughBitcasts(Src);
if (IsAllOf && Result.getOpcode() == X86ISD::PCMPEQ) {
SDValue V = DAG.getNode(ISD::SUB, DL, Result.getValueType(),
Result.getOperand(0), Result.getOperand(1));
V = DAG.getBitcast(MVT::v4i64, V);
return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V);
}
Result = DAG.getBitcast(MVT::v32i8, Result);
Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
unsigned CmpMask = IsAnyOf ? 0 : 0xFFFFFFFF;
if (!SignExt0 || !SignExt1) {
assert(IsAnyOf &&
"Only perform v16i16 signmasks for any_of patterns");
Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result,
DAG.getConstant(0xAAAAAAAA, DL, MVT::i32));
}
return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
DAG.getConstant(CmpMask, DL, MVT::i32));
}
return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
DAG.getConstant(CmpMask, DL, MVT::i32));
}
}

Expand Down Expand Up @@ -45582,33 +45603,28 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
// truncation trees that help us avoid lane crossing shuffles.
// TODO: There's a lot more we can do for PACK/HADD style shuffle combines.
// TODO: We don't handle vXf64 shuffles yet.
if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32 &&
BC0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
BC1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
BC0.getOperand(0) == BC1.getOperand(0) &&
BC0.getOperand(0).getValueType().is256BitVector() &&
BC0.getConstantOperandAPInt(1) == 0 &&
BC1.getConstantOperandAPInt(1) ==
BC0.getValueType().getVectorNumElements()) {
SmallVector<SDValue> ShuffleOps;
SmallVector<int> ShuffleMask, ScaledMask;
SDValue Vec = peekThroughBitcasts(BC0.getOperand(0));
if (getTargetShuffleInputs(Vec, ShuffleOps, ShuffleMask, DAG)) {
resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask);
// To keep the HOP LHS/RHS coherency, we must be able to scale the unary
// shuffle to a v4X64 width - we can probably relax this in the future.
if (!isAnyZero(ShuffleMask) && ShuffleOps.size() == 1 &&
ShuffleOps[0].getValueType().is256BitVector() &&
scaleShuffleElements(ShuffleMask, 4, ScaledMask)) {
SDValue Lo, Hi;
MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
std::tie(Lo, Hi) = DAG.SplitVector(ShuffleOps[0], DL);
Lo = DAG.getBitcast(SrcVT, Lo);
Hi = DAG.getBitcast(SrcVT, Hi);
SDValue Res = DAG.getNode(Opcode, DL, VT, Lo, Hi);
Res = DAG.getBitcast(ShufVT, Res);
Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ScaledMask);
return DAG.getBitcast(VT, Res);
if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32) {
if (SDValue BCSrc = getSplitVectorSrc(BC0, BC1, false)) {
SmallVector<SDValue> ShuffleOps;
SmallVector<int> ShuffleMask, ScaledMask;
SDValue Vec = peekThroughBitcasts(BCSrc);
if (getTargetShuffleInputs(Vec, ShuffleOps, ShuffleMask, DAG)) {
resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask);
// To keep the HOP LHS/RHS coherency, we must be able to scale the unary
// shuffle to a v4X64 width - we can probably relax this in the future.
if (!isAnyZero(ShuffleMask) && ShuffleOps.size() == 1 &&
ShuffleOps[0].getValueType().is256BitVector() &&
scaleShuffleElements(ShuffleMask, 4, ScaledMask)) {
SDValue Lo, Hi;
MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
std::tie(Lo, Hi) = DAG.SplitVector(ShuffleOps[0], DL);
Lo = DAG.getBitcast(SrcVT, Lo);
Hi = DAG.getBitcast(SrcVT, Hi);
SDValue Res = DAG.getNode(Opcode, DL, VT, Lo, Hi);
Res = DAG.getBitcast(ShufVT, Res);
Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ScaledMask);
return DAG.getBitcast(VT, Res);
}
}
}
}
Expand Down

0 comments on commit 99ae5c1

Please sign in to comment.