diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 6ef30663bf3ce..18a45c6799bac 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -319,20 +319,20 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) { return nullptr; } -/// Find elements of V demanded by UserInstr. -static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { +/// Find elements of V demanded by UserInstr. If returns false, we were not able +/// to determine all elements. +static bool findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr, + APInt &UnionUsedElts) { unsigned VWidth = cast(V->getType())->getNumElements(); - // Conservatively assume that all elements are needed. - APInt UsedElts(APInt::getAllOnes(VWidth)); - switch (UserInstr->getOpcode()) { case Instruction::ExtractElement: { ExtractElementInst *EEI = cast(UserInstr); assert(EEI->getVectorOperand() == V); ConstantInt *EEIIndexC = dyn_cast(EEI->getIndexOperand()); if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) { - UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue()); + UnionUsedElts.setBit(EEIIndexC->getZExtValue()); + return true; } break; } @@ -341,23 +341,23 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { unsigned MaskNumElts = cast(UserInstr->getType())->getNumElements(); - UsedElts = APInt(VWidth, 0); - for (unsigned i = 0; i < MaskNumElts; i++) { - unsigned MaskVal = Shuffle->getMaskValue(i); + for (auto I : llvm::seq(MaskNumElts)) { + unsigned MaskVal = Shuffle->getMaskValue(I); if (MaskVal == -1u || MaskVal >= 2 * VWidth) continue; if (Shuffle->getOperand(0) == V && (MaskVal < VWidth)) - UsedElts.setBit(MaskVal); + UnionUsedElts.setBit(MaskVal); if (Shuffle->getOperand(1) == V && ((MaskVal >= VWidth) && (MaskVal < 2 * VWidth))) - UsedElts.setBit(MaskVal - VWidth); + UnionUsedElts.setBit(MaskVal - VWidth); } - break; + return true; } default: break; } - return UsedElts; + + return false; } /// Find union of elements of V demanded by all its users. @@ -370,7 +370,8 @@ static APInt findDemandedEltsByAllUsers(Value *V) { APInt UnionUsedElts(VWidth, 0); for (const Use &U : V->uses()) { if (Instruction *I = dyn_cast(U.getUser())) { - UnionUsedElts |= findDemandedEltsBySingleUser(V, I); + if (!findDemandedEltsBySingleUser(V, I, UnionUsedElts)) + return APInt::getAllOnes(VWidth); } else { UnionUsedElts = APInt::getAllOnes(VWidth); break;