diff --git a/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h b/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h index 26f5fe185dd526..5fa7427b260319 100644 --- a/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h +++ b/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h @@ -114,7 +114,7 @@ class NaryReassociatePass : public PassInfoMixin { bool doOneIteration(Function &F); // Reassociates I for better CSE. - Instruction *tryReassociate(Instruction *I); + Instruction *tryReassociate(Instruction *I, const SCEV *&OrigSCEV); // Reassociate GEP for better CSE. Instruction *tryReassociateGEP(GetElementPtrInst *GEP); diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index b1bfc03c1e739f..bc1b58611dd118 100644 --- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -213,18 +213,6 @@ bool NaryReassociatePass::runImpl(Function &F, AssumptionCache *AC_, return Changed; } -// Explicitly list the instruction types NaryReassociate handles for now. -static bool isPotentiallyNaryReassociable(Instruction *I) { - switch (I->getOpcode()) { - case Instruction::Add: - case Instruction::GetElementPtr: - case Instruction::Mul: - return true; - default: - return false; - } -} - bool NaryReassociatePass::doOneIteration(Function &F) { bool Changed = false; SeenExprs.clear(); @@ -236,13 +224,8 @@ bool NaryReassociatePass::doOneIteration(Function &F) { BasicBlock *BB = Node->getBlock(); for (auto I = BB->begin(); I != BB->end(); ++I) { Instruction *OrigI = &*I; - - if (!SE->isSCEVable(OrigI->getType()) || - !isPotentiallyNaryReassociable(OrigI)) - continue; - - const SCEV *OrigSCEV = SE->getSCEV(OrigI); - if (Instruction *NewI = tryReassociate(OrigI)) { + const SCEV *OrigSCEV = nullptr; + if (Instruction *NewI = tryReassociate(OrigI, OrigSCEV)) { Changed = true; OrigI->replaceAllUsesWith(NewI); @@ -274,7 +257,7 @@ bool NaryReassociatePass::doOneIteration(Function &F) { // nary-gep.ll. if (NewSCEV != OrigSCEV) SeenExprs[OrigSCEV].push_back(WeakTrackingVH(NewI)); - } else + } else if (OrigSCEV) SeenExprs[OrigSCEV].push_back(WeakTrackingVH(OrigI)); } } @@ -286,16 +269,26 @@ bool NaryReassociatePass::doOneIteration(Function &F) { return Changed; } -Instruction *NaryReassociatePass::tryReassociate(Instruction *I) { +Instruction *NaryReassociatePass::tryReassociate(Instruction * I, + const SCEV *&OrigSCEV) { + + if (!SE->isSCEVable(I->getType())) + return nullptr; + switch (I->getOpcode()) { case Instruction::Add: case Instruction::Mul: + OrigSCEV = SE->getSCEV(I); return tryReassociateBinaryOp(cast(I)); case Instruction::GetElementPtr: + OrigSCEV = SE->getSCEV(I); return tryReassociateGEP(cast(I)); default: - llvm_unreachable("should be filtered out by isPotentiallyNaryReassociable"); + return nullptr; } + + llvm_unreachable("should not be reached"); + return nullptr; } static bool isGEPFoldable(GetElementPtrInst *GEP,