diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index fee2b19794dd59..385cb754731bcf 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -44661,13 +44661,33 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG, return SDValue(); } +static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS, + SDValue Index, SDValue Base, SDValue Scale, + SelectionDAG &DAG) { + SDLoc DL(GorS); + + if (auto *Gather = dyn_cast(GorS)) { + SDValue Ops[] = { Gather->getChain(), Gather->getPassThru(), + Gather->getMask(), Base, Index, Scale } ; + return DAG.getMaskedGather(Gather->getVTList(), + Gather->getMemoryVT(), DL, Ops, + Gather->getMemOperand(), + Gather->getIndexType()); + } + auto *Scatter = cast(GorS); + SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(), + Scatter->getMask(), Base, Index, Scale }; + return DAG.getMaskedScatter(Scatter->getVTList(), + Scatter->getMemoryVT(), DL, + Ops, Scatter->getMemOperand(), + Scatter->getIndexType()); +} + static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI) { SDLoc DL(N); auto *GorS = cast(N); - SDValue Chain = GorS->getChain(); SDValue Index = GorS->getIndex(); - SDValue Mask = GorS->getMask(); SDValue Base = GorS->getBasePtr(); SDValue Scale = GorS->getScale(); @@ -44687,21 +44707,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, unsigned NumElts = Index.getValueType().getVectorNumElements(); EVT NewVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts); Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index); - if (auto *Gather = dyn_cast(GorS)) { - SDValue Ops[] = { Chain, Gather->getPassThru(), - Mask, Base, Index, Scale } ; - return DAG.getMaskedGather(Gather->getVTList(), - Gather->getMemoryVT(), DL, Ops, - Gather->getMemOperand(), - Gather->getIndexType()); - } - auto *Scatter = cast(GorS); - SDValue Ops[] = { Chain, Scatter->getValue(), - Mask, Base, Index, Scale }; - return DAG.getMaskedScatter(Scatter->getVTList(), - Scatter->getMemoryVT(), DL, - Ops, Scatter->getMemOperand(), - Scatter->getIndexType()); + return rebuildGatherScatter(GorS, Index, Base, Scale, DAG); } } @@ -44716,21 +44722,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, unsigned NumElts = Index.getValueType().getVectorNumElements(); EVT NewVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts); Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index); - if (auto *Gather = dyn_cast(GorS)) { - SDValue Ops[] = { Chain, Gather->getPassThru(), - Mask, Base, Index, Scale } ; - return DAG.getMaskedGather(Gather->getVTList(), - Gather->getMemoryVT(), DL, Ops, - Gather->getMemOperand(), - Gather->getIndexType()); - } - auto *Scatter = cast(GorS); - SDValue Ops[] = { Chain, Scatter->getValue(), - Mask, Base, Index, Scale }; - return DAG.getMaskedScatter(Scatter->getVTList(), - Scatter->getMemoryVT(), DL, - Ops, Scatter->getMemOperand(), - Scatter->getIndexType()); + return rebuildGatherScatter(GorS, Index, Base, Scale, DAG); } } @@ -44743,25 +44735,12 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT, Index.getValueType().getVectorNumElements()); Index = DAG.getSExtOrTrunc(Index, DL, IndexVT); - if (auto *Gather = dyn_cast(GorS)) { - SDValue Ops[] = { Chain, Gather->getPassThru(), - Mask, Base, Index, Scale } ; - return DAG.getMaskedGather(Gather->getVTList(), - Gather->getMemoryVT(), DL, Ops, - Gather->getMemOperand(), - Gather->getIndexType()); - } - auto *Scatter = cast(GorS); - SDValue Ops[] = { Chain, Scatter->getValue(), - Mask, Base, Index, Scale }; - return DAG.getMaskedScatter(Scatter->getVTList(), - Scatter->getMemoryVT(), DL, - Ops, Scatter->getMemOperand(), - Scatter->getIndexType()); + return rebuildGatherScatter(GorS, Index, Base, Scale, DAG); } } // With vector masks we only demand the upper bit of the mask. + SDValue Mask = GorS->getMask(); if (Mask.getScalarValueSizeInBits() != 1) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));