diff --git a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp index b69e9f79719b01..016aa4cbba8b58 100644 --- a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp +++ b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp @@ -269,20 +269,32 @@ static std::string getInstrName(const Value *I, StringRef Suffix) { return ""; } -static Value *GetPointerOperand(Value *MemI) { +static Value *getPointerOperandAndType(Value *MemI, + Type **PtrElementType = nullptr) { + + Value *PtrValue = nullptr; + Type *PointerElementType = nullptr; + if (LoadInst *LMemI = dyn_cast(MemI)) { - return LMemI->getPointerOperand(); + PtrValue = LMemI->getPointerOperand(); + PointerElementType = LMemI->getType(); } else if (StoreInst *SMemI = dyn_cast(MemI)) { - return SMemI->getPointerOperand(); + PtrValue = SMemI->getPointerOperand(); + PointerElementType = SMemI->getValueOperand()->getType(); } else if (IntrinsicInst *IMemI = dyn_cast(MemI)) { + PointerElementType = Type::getInt8Ty(MemI->getContext()); if (IMemI->getIntrinsicID() == Intrinsic::prefetch || - IMemI->getIntrinsicID() == Intrinsic::ppc_vsx_lxvp) - return IMemI->getArgOperand(0); - if (IMemI->getIntrinsicID() == Intrinsic::ppc_vsx_stxvp) - return IMemI->getArgOperand(1); + IMemI->getIntrinsicID() == Intrinsic::ppc_vsx_lxvp) { + PtrValue = IMemI->getArgOperand(0); + } else if (IMemI->getIntrinsicID() == Intrinsic::ppc_vsx_stxvp) { + PtrValue = IMemI->getArgOperand(1); + } } + /*Get ElementType if PtrElementType is not null.*/ + if (PtrElementType) + *PtrElementType = PointerElementType; - return nullptr; + return PtrValue; } bool PPCLoopInstrFormPrep::runOnFunction(Function &F) { @@ -309,7 +321,7 @@ bool PPCLoopInstrFormPrep::runOnFunction(Function &F) { void PPCLoopInstrFormPrep::addOneCandidate(Instruction *MemI, const SCEV *LSCEV, SmallVector &Buckets, unsigned MaxCandidateNum) { - assert((MemI && GetPointerOperand(MemI)) && + assert((MemI && getPointerOperandAndType(MemI)) && "Candidate should be a memory instruction."); assert(LSCEV && "Invalid SCEV for Ptr value."); bool FoundBucket = false; @@ -337,27 +349,14 @@ SmallVector PPCLoopInstrFormPrep::collectCandidates( SmallVector Buckets; for (const auto &BB : L->blocks()) for (auto &J : *BB) { - Value *PtrValue; - Type *PointerElementType; - - if (LoadInst *LMemI = dyn_cast(&J)) { - PtrValue = LMemI->getPointerOperand(); - PointerElementType = LMemI->getType(); - } else if (StoreInst *SMemI = dyn_cast(&J)) { - PtrValue = SMemI->getPointerOperand(); - PointerElementType = SMemI->getValueOperand()->getType(); - } else if (IntrinsicInst *IMemI = dyn_cast(&J)) { - PointerElementType = Type::getInt8Ty(J.getContext()); - if (IMemI->getIntrinsicID() == Intrinsic::prefetch || - IMemI->getIntrinsicID() == Intrinsic::ppc_vsx_lxvp) { - PtrValue = IMemI->getArgOperand(0); - } else if (IMemI->getIntrinsicID() == Intrinsic::ppc_vsx_stxvp) { - PtrValue = IMemI->getArgOperand(1); - } else continue; - } else continue; - - unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace(); - if (PtrAddrSpace) + Value *PtrValue = nullptr; + Type *PointerElementType = nullptr; + PtrValue = getPointerOperandAndType(&J, &PointerElementType); + + if (!PtrValue) + continue; + + if (PtrValue->getType()->getPointerAddressSpace()) continue; if (L->isLoopInvariant(PtrValue)) @@ -505,7 +504,7 @@ bool PPCLoopInstrFormPrep::rewriteLoadStores(Loop *L, Bucket &BucketChain, // The instruction corresponding to the Bucket's BaseSCEV must be the first // in the vector of elements. Instruction *MemI = BucketChain.Elements.begin()->Instr; - Value *BasePtr = GetPointerOperand(MemI); + Value *BasePtr = getPointerOperandAndType(MemI); assert(BasePtr && "No pointer operand"); Type *I8Ty = Type::getInt8Ty(MemI->getParent()->getContext()); @@ -627,7 +626,7 @@ bool PPCLoopInstrFormPrep::rewriteLoadStores(Loop *L, Bucket &BucketChain, for (auto I = std::next(BucketChain.Elements.begin()), IE = BucketChain.Elements.end(); I != IE; ++I) { - Value *Ptr = GetPointerOperand(I->Instr); + Value *Ptr = getPointerOperandAndType(I->Instr); assert(Ptr && "No pointer operand"); if (NewPtrs.count(Ptr)) continue;