diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index a5848646c0d084..5d00fa56e888bd 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -307,25 +307,6 @@ class LoopIdiomRecognizeLegacyPass : public LoopPass { } }; -// The Folder will fold expressions that are guarded by the loop entry. -class SCEVSignToZeroExtentionRewriter - : public SCEVRewriteVisitor { -public: - ScalarEvolution &SE; - const Loop *CurLoop; - SCEVSignToZeroExtentionRewriter(ScalarEvolution &SE, const Loop *CurLoop) - : SCEVRewriteVisitor(SE), SE(SE), CurLoop(CurLoop) {} - - const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - // If expression is guarded by CurLoop to be greater or equal to zero - // then convert sext to zext. Otherwise return the original expression. - if (SE.isLoopEntryGuardedByCond(CurLoop, ICmpInst::ICMP_SGE, Expr, - SE.getZero(Expr->getType()))) - return SE.getZeroExtendExpr(visit(Expr->getOperand()), Expr->getType()); - return Expr; - } -}; - } // end anonymous namespace char LoopIdiomRecognizeLegacyPass::ID = 0; @@ -986,12 +967,12 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, << "\n"); if (PositiveStrideSCEV != MemsetSizeSCEV) { - // The folding is to fold an expression that is covered by the loop guard - // at loop entry. After the folding, compare again and proceed with - // optimization, if equal. - SCEVSignToZeroExtentionRewriter Folder(*SE, CurLoop); - const SCEV *FoldedPositiveStride = Folder.visit(PositiveStrideSCEV); - const SCEV *FoldedMemsetSize = Folder.visit(MemsetSizeSCEV); + // If an expression is covered by the loop guard, compare again and + // proceed with optimization if equal. + const SCEV *FoldedPositiveStride = + SE->applyLoopGuards(PositiveStrideSCEV, CurLoop); + const SCEV *FoldedMemsetSize = + SE->applyLoopGuards(MemsetSizeSCEV, CurLoop); LLVM_DEBUG(dbgs() << " Try to fold SCEV based on loop guard\n" << " FoldedMemsetSize: " << *FoldedMemsetSize << "\n"