diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index d5b7fe3aa6cb0..1784a8103a66b 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -2216,6 +2216,7 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return N;
 
+  // index: add(x,c) -> index: x, disp + c
   if (CurDAG->isBaseWithConstantOffset(N)) {
     auto *AddVal = cast<ConstantSDNode>(N.getOperand(1));
     uint64_t Offset = (uint64_t)AddVal->getSExtValue() * AM.Scale;
@@ -2223,6 +2224,24 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
       return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
   }
 
+  // index: add(x,x) -> index: x, scale * 2
+  if (N.getOpcode() == ISD::ADD && N.getOperand(0) == N.getOperand(1)) {
+    if (AM.Scale <= 4) {
+      AM.Scale *= 2;
+      return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+    }
+  }
+
+  // index: shl(x,i) -> index: x, scale * (1 << i)
+  if (N.getOpcode() == X86ISD::VSHLI) {
+    uint64_t ShiftAmt = N.getConstantOperandVal(1);
+    uint64_t ScaleAmt = 1ULL << ShiftAmt;
+    if ((AM.Scale * ScaleAmt) <= 8) {
+      AM.Scale *= ScaleAmt;
+      return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+    }
+  }
+
   // TODO: Handle extensions, shifted masks etc.
   return N;
 }
@@ -2672,9 +2691,15 @@ bool X86DAGToDAGISel::selectVectorAddr(MemSDNode *Parent, SDValue BasePtr,
                                        SDValue &Index, SDValue &Disp,
                                        SDValue &Segment) {
   X86ISelAddressMode AM;
-  AM.IndexReg = IndexOp;
   AM.Scale = cast<ConstantSDNode>(ScaleOp)->getZExtValue();
 
+  // Attempt to match index patterns, as long as we're not relying on implicit
+  // sign-extension, which is performed BEFORE scale.
+  if (IndexOp.getScalarValueSizeInBits() == BasePtr.getScalarValueSizeInBits())
+    AM.IndexReg = matchIndexRecursively(IndexOp, AM, 0);
+  else
+    AM.IndexReg = IndexOp;
+
   unsigned AddrSpace = Parent->getPointerInfo().getAddrSpace();
   if (AddrSpace == X86AS::GS)
     AM.Segment = CurDAG->getRegister(X86::GS, MVT::i16);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 0d2f3b00313a7..9f70d6cedb761 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -52947,43 +52947,10 @@ static SDValue combineTESTP(SDNode *N, SelectionDAG &DAG,
 }
 
 static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
-                                       TargetLowering::DAGCombinerInfo &DCI,
-                                       const X86Subtarget &Subtarget) {
+                                       TargetLowering::DAGCombinerInfo &DCI) {
   auto *MemOp = cast<X86MaskedGatherScatterSDNode>(N);
-  SDValue BasePtr = MemOp->getBasePtr();
-  SDValue Index = MemOp->getIndex();
-  SDValue Scale = MemOp->getScale();
   SDValue Mask = MemOp->getMask();
 
-  // Attempt to fold an index scale into the scale value directly.
-  // For smaller indices, implicit sext is performed BEFORE scale, preventing
-  // this fold under most circumstances.
-  // TODO: Move this into X86DAGToDAGISel::matchVectorAddressRecursively?
-  if ((Index.getOpcode() == X86ISD::VSHLI ||
-       (Index.getOpcode() == ISD::ADD &&
-        Index.getOperand(0) == Index.getOperand(1))) &&
-      isa<ConstantSDNode>(Scale) &&
-      BasePtr.getScalarValueSizeInBits() == Index.getScalarValueSizeInBits()) {
-    unsigned ShiftAmt =
-        Index.getOpcode() == ISD::ADD ? 1 : Index.getConstantOperandVal(1);
-    uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue();
-    uint64_t NewScaleAmt = ScaleAmt * (1ULL << ShiftAmt);
-    if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) {
-      SDValue NewIndex = Index.getOperand(0);
-      SDValue NewScale =
-          DAG.getTargetConstant(NewScaleAmt, SDLoc(N), Scale.getValueType());
-      if (N->getOpcode() == X86ISD::MGATHER)
-        return getAVX2GatherNode(N->getOpcode(), SDValue(N, 0), DAG,
-                                 MemOp->getOperand(1), Mask,
-                                 MemOp->getBasePtr(), NewIndex, NewScale,
-                                 MemOp->getChain(), Subtarget);
-      if (N->getOpcode() == X86ISD::MSCATTER)
-        return getScatterNode(N->getOpcode(), SDValue(N, 0), DAG,
-                              MemOp->getOperand(1), Mask, MemOp->getBasePtr(),
-                              NewIndex, NewScale, MemOp->getChain(), Subtarget);
-    }
-  }
-
   // With vector masks we only demand the upper bit of the mask.
   if (Mask.getScalarValueSizeInBits() != 1) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -55920,8 +55887,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::MOVMSK:      return combineMOVMSK(N, DAG, DCI, Subtarget);
   case X86ISD::TESTP:       return combineTESTP(N, DAG, DCI, Subtarget);
   case X86ISD::MGATHER:
-  case X86ISD::MSCATTER:
-    return combineX86GatherScatter(N, DAG, DCI, Subtarget);
+  case X86ISD::MSCATTER:    return combineX86GatherScatter(N, DAG, DCI);
   case ISD::MGATHER:
   case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
   case X86ISD::PCMPEQ:
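
Not part of the patch: a minimal standalone C++ sketch of the arithmetic behind the width check added to selectVectorAddr. Gather/scatter sign-extends a narrow index BEFORE applying the scale, so folding a shift of the index into the scale is only equivalent when the index element width already matches the pointer width. The value 0x40000000 is an arbitrary example chosen so the shift wraps a 32-bit index.

// Sketch only: shows why "index: shl(x,1) -> scale*2" is unsafe when the
// index is implicitly sign-extended from 32 to 64 bits before scaling.
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t x = 0x40000000u; // 2^30, so x << 1 wraps in 32 bits

  // Address contribution when the shift stays in the 32-bit index:
  // sext32(x << 1) * 4
  int64_t shiftInIndex = (int64_t)(int32_t)(x << 1) * 4;

  // Address contribution if the shift were folded into the scale:
  // sext32(x) * 8
  int64_t shiftInScale = (int64_t)(int32_t)x * 8;

  // Prints two different values (-8589934592 vs 8589934592), which is why
  // selectVectorAddr only calls matchIndexRecursively when the index and
  // pointer element widths match, i.e. no implicit sign-extension is involved.
  printf("%lld vs %lld\n", (long long)shiftInIndex, (long long)shiftInScale);
  return 0;
}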