diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 804fe1018b0e3..f57ee475fcc3e 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -225,7 +225,7 @@ class LoopIdiomRecognize { bool IsNegStride, bool IsLoopMemset = false); bool processLoopStoreOfLoopLoad(StoreInst *SI, const SCEV *BECount); bool processLoopStoreOfLoopLoad(Value *DestPtr, Value *SourcePtr, - unsigned StoreSize, MaybeAlign StoreAlign, + const SCEV *StoreSize, MaybeAlign StoreAlign, MaybeAlign LoadAlign, Instruction *TheStore, Instruction *TheLoad, const SCEVAddRecExpr *StoreEv, @@ -858,15 +858,15 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI, // Check if the stride matches the size of the memcpy. If so, then we know // that every byte is touched in the loop. - const SCEVConstant *StoreStride = + const SCEVConstant *ConstStoreStride = dyn_cast(StoreEv->getOperand(1)); - const SCEVConstant *LoadStride = + const SCEVConstant *ConstLoadStride = dyn_cast(LoadEv->getOperand(1)); - if (!StoreStride || !LoadStride) + if (!ConstStoreStride || !ConstLoadStride) return false; - APInt StoreStrideValue = StoreStride->getAPInt(); - APInt LoadStrideValue = LoadStride->getAPInt(); + APInt StoreStrideValue = ConstStoreStride->getAPInt(); + APInt LoadStrideValue = ConstLoadStride->getAPInt(); // Huge stride value - give up if (StoreStrideValue.getBitWidth() > 64 || LoadStrideValue.getBitWidth() > 64) return false; @@ -888,9 +888,10 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI, if (StoreStrideInt != LoadStrideInt) return false; - return processLoopStoreOfLoopLoad(Dest, Source, (unsigned)SizeInBytes, - MCI->getDestAlign(), MCI->getSourceAlign(), - MCI, MCI, StoreEv, LoadEv, BECount); + return processLoopStoreOfLoopLoad( + Dest, Source, SE->getConstant(Dest->getType(), SizeInBytes), + MCI->getDestAlign(), MCI->getSourceAlign(), MCI, MCI, StoreEv, LoadEv, + BECount); } /// processLoopMemSet - See if this memset can be promoted to a large memset. @@ -1242,16 +1243,18 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, // random load we can't handle. Value *LoadPtr = LI->getPointerOperand(); const SCEVAddRecExpr *LoadEv = cast(SE->getSCEV(LoadPtr)); - return processLoopStoreOfLoopLoad(StorePtr, LoadPtr, StoreSize, + + const SCEV *StoreSizeSCEV = SE->getConstant(StorePtr->getType(), StoreSize); + return processLoopStoreOfLoopLoad(StorePtr, LoadPtr, StoreSizeSCEV, SI->getAlign(), LI->getAlign(), SI, LI, StoreEv, LoadEv, BECount); } bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( - Value *DestPtr, Value *SourcePtr, unsigned StoreSize, MaybeAlign StoreAlign, - MaybeAlign LoadAlign, Instruction *TheStore, Instruction *TheLoad, - const SCEVAddRecExpr *StoreEv, const SCEVAddRecExpr *LoadEv, - const SCEV *BECount) { + Value *DestPtr, Value *SourcePtr, const SCEV *StoreSizeSCEV, + MaybeAlign StoreAlign, MaybeAlign LoadAlign, Instruction *TheStore, + Instruction *TheLoad, const SCEVAddRecExpr *StoreEv, + const SCEVAddRecExpr *LoadEv, const SCEV *BECount) { // FIXME: until llvm.memcpy.inline supports dynamic sizes, we need to // conservatively bail here, since otherwise we may have to transform @@ -1274,9 +1277,14 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( Type *IntIdxTy = Builder.getIntNTy(DL->getIndexSizeInBits(StrAS)); APInt Stride = getStoreStride(StoreEv); + const SCEVConstant *ConstStoreSize = dyn_cast(StoreSizeSCEV); + + // TODO: Deal with non-constant size; Currently expect constant store size + assert(ConstStoreSize && "store size is expected to be a constant"); + + int64_t StoreSize = ConstStoreSize->getValue()->getZExtValue(); bool IsNegStride = StoreSize == -Stride; - const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize); // Handle negative strided loops. if (IsNegStride) StrStart = @@ -1376,8 +1384,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( // Okay, everything is safe, we can transform this! - const SCEV *NumBytesS = getNumBytes( - BECount, IntIdxTy, SE->getConstant(IntIdxTy, StoreSize), CurLoop, DL, SE); + const SCEV *NumBytesS = + getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE); Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());