[X86] Move gather/scatter index shl(x,c) -> index:x, scale:c fold into X86DAGToDAGISel::matchIndexRecursively
RKSimon committed Aug 22, 2023
1 parent 8c21544 commit e8900df
Showing 2 changed files with 28 additions and 37 deletions.
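
For context, the fold this commit moves relies on the x86 gather/scatter addressing form `base + index * scale`, where the scale must be 1, 2, 4 or 8. A minimal standalone sketch of the underlying equivalence, in plain C++ rather than LLVM code (all names here are illustrative):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // If the index operand is itself a left-shift by a constant,
  // shl(x, c), the shift can be absorbed into the scale instead:
  //   base + (x << c) * 1  ==  base + x * (1 << c)
  // as long as (1 << c) is a representable scale (1, 2, 4 or 8).
  uint64_t Base = 0x1000;
  int64_t X = 42;
  unsigned C = 2;
  uint64_t Before = Base + (uint64_t)(X << C) * 1;   // index: shl(x,2), scale 1
  uint64_t After = Base + (uint64_t)X * (1ULL << C); // index: x, scale 4
  assert(Before == After);
  return 0;
}
```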
27 changes: 26 additions & 1 deletion llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -2216,13 +2216,32 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return N;
 
   // index: add(x,c) -> index: x, disp + c
   if (CurDAG->isBaseWithConstantOffset(N)) {
     auto *AddVal = cast<ConstantSDNode>(N.getOperand(1));
     uint64_t Offset = (uint64_t)AddVal->getSExtValue() * AM.Scale;
     if (!foldOffsetIntoAddress(Offset, AM))
       return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
   }
+
+  // index: add(x,x) -> index: x, scale * 2
+  if (N.getOpcode() == ISD::ADD && N.getOperand(0) == N.getOperand(1)) {
+    if (AM.Scale <= 4) {
+      AM.Scale *= 2;
+      return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+    }
+  }
+
+  // index: shl(x,i) -> index: x, scale * (1 << i)
+  if (N.getOpcode() == X86ISD::VSHLI) {
+    uint64_t ShiftAmt = N.getConstantOperandVal(1);
+    uint64_t ScaleAmt = 1ULL << ShiftAmt;
+    if ((AM.Scale * ScaleAmt) <= 8) {
+      AM.Scale *= ScaleAmt;
+      return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+    }
+  }
+
   // TODO: Handle extensions, shifted masks etc.
   return N;
 }
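
To make the scale bookkeeping above concrete, here is a standalone sketch of the legality check, assuming (as the code does) that the running `AM.Scale` is already a power of two; `tryFoldShiftIntoScale` is a hypothetical helper for illustration, not an LLVM API:

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the shl(x,i) case above: the fold is only taken while the
// combined scale stays within the x86 SIB limit of 8. Returns the new
// scale, or 0 if the fold must be rejected.
static uint64_t tryFoldShiftIntoScale(uint64_t Scale, uint64_t ShiftAmt) {
  uint64_t ScaleAmt = 1ULL << ShiftAmt;
  uint64_t NewScale = Scale * ScaleAmt;
  return NewScale <= 8 ? NewScale : 0;
}

int main() {
  // add(x,x) behaves like shl(x,1), hence the AM.Scale <= 4 guard above:
  printf("%llu\n", (unsigned long long)tryFoldShiftIntoScale(4, 1)); // 8: taken
  printf("%llu\n", (unsigned long long)tryFoldShiftIntoScale(8, 1)); // 0: rejected
  // shl(x,2) with an existing scale of 4 would need scale 16:
  printf("%llu\n", (unsigned long long)tryFoldShiftIntoScale(4, 2)); // 0: rejected
  return 0;
}
```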
@@ -2672,9 +2691,15 @@ bool X86DAGToDAGISel::selectVectorAddr(MemSDNode *Parent, SDValue BasePtr,
                                        SDValue &Index, SDValue &Disp,
                                        SDValue &Segment) {
   X86ISelAddressMode AM;
-  AM.IndexReg = IndexOp;
   AM.Scale = cast<ConstantSDNode>(ScaleOp)->getZExtValue();
+
+  // Attempt to match index patterns, as long as we're not relying on implicit
+  // sign-extension, which is performed BEFORE scale.
+  if (IndexOp.getScalarValueSizeInBits() == BasePtr.getScalarValueSizeInBits())
+    AM.IndexReg = matchIndexRecursively(IndexOp, AM, 0);
+  else
+    AM.IndexReg = IndexOp;
 
   unsigned AddrSpace = Parent->getPointerInfo().getAddrSpace();
   if (AddrSpace == X86AS::GS)
     AM.Segment = CurDAG->getRegister(X86::GS, MVT::i16);
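
The width check in this hunk matters because, for an index narrower than the pointer, the hardware sign-extends the index BEFORE applying the scale; folding a shift into the scale would then change the address whenever the shift overflows the narrow index type. A small standalone illustration in plain C++ (values chosen to trigger the overflow):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  int32_t X = 0x40000000; // shl(x,2) overflows the 32-bit index type
  // Unfolded: the 32-bit shift wraps to 0, then sign-extend, scale 1.
  int32_t Shifted = (int32_t)((uint32_t)X << 2);
  uint64_t Unfolded = (uint64_t)(int64_t)Shifted * 1;
  // Folded: sign-extend first, then scale by 4.
  uint64_t Folded = (uint64_t)(int64_t)X * 4;
  assert(Unfolded != Folded); // 0x0 vs 0x100000000 -> the fold is unsound here
  return 0;
}
```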
38 changes: 2 additions & 36 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -52947,43 +52947,10 @@ static SDValue combineTESTP(SDNode *N, SelectionDAG &DAG,
 }
 
 static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
-                                       TargetLowering::DAGCombinerInfo &DCI,
-                                       const X86Subtarget &Subtarget) {
+                                       TargetLowering::DAGCombinerInfo &DCI) {
   auto *MemOp = cast<X86MaskedGatherScatterSDNode>(N);
-  SDValue BasePtr = MemOp->getBasePtr();
-  SDValue Index = MemOp->getIndex();
-  SDValue Scale = MemOp->getScale();
   SDValue Mask = MemOp->getMask();
 
-  // Attempt to fold an index scale into the scale value directly.
-  // For smaller indices, implicit sext is performed BEFORE scale, preventing
-  // this fold under most circumstances.
-  // TODO: Move this into X86DAGToDAGISel::matchVectorAddressRecursively?
-  if ((Index.getOpcode() == X86ISD::VSHLI ||
-       (Index.getOpcode() == ISD::ADD &&
-        Index.getOperand(0) == Index.getOperand(1))) &&
-      isa<ConstantSDNode>(Scale) &&
-      BasePtr.getScalarValueSizeInBits() == Index.getScalarValueSizeInBits()) {
-    unsigned ShiftAmt =
-        Index.getOpcode() == ISD::ADD ? 1 : Index.getConstantOperandVal(1);
-    uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue();
-    uint64_t NewScaleAmt = ScaleAmt * (1ULL << ShiftAmt);
-    if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) {
-      SDValue NewIndex = Index.getOperand(0);
-      SDValue NewScale =
-          DAG.getTargetConstant(NewScaleAmt, SDLoc(N), Scale.getValueType());
-      if (N->getOpcode() == X86ISD::MGATHER)
-        return getAVX2GatherNode(N->getOpcode(), SDValue(N, 0), DAG,
-                                 MemOp->getOperand(1), Mask,
-                                 MemOp->getBasePtr(), NewIndex, NewScale,
-                                 MemOp->getChain(), Subtarget);
-      if (N->getOpcode() == X86ISD::MSCATTER)
-        return getScatterNode(N->getOpcode(), SDValue(N, 0), DAG,
-                              MemOp->getOperand(1), Mask, MemOp->getBasePtr(),
-                              NewIndex, NewScale, MemOp->getChain(), Subtarget);
-    }
-  }
-
   // With vector masks we only demand the upper bit of the mask.
   if (Mask.getScalarValueSizeInBits() != 1) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -55920,8 +55887,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::MOVMSK: return combineMOVMSK(N, DAG, DCI, Subtarget);
   case X86ISD::TESTP: return combineTESTP(N, DAG, DCI, Subtarget);
   case X86ISD::MGATHER:
-  case X86ISD::MSCATTER:
-    return combineX86GatherScatter(N, DAG, DCI, Subtarget);
+  case X86ISD::MSCATTER: return combineX86GatherScatter(N, DAG, DCI);
   case ISD::MGATHER:
   case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
   case X86ISD::PCMPEQ:
