diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index 4d9578934d9e6..c46db4e63bfee 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -149,6 +149,118 @@ struct FlattenInfo { return false; return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi; } + bool isInnerLoopIncrement(User *U) { + return InnerIncrement == U; + } + bool isOuterLoopIncrement(User *U) { + return OuterIncrement == U; + } + bool isInnerLoopTest(User *U) { + return InnerBranch->getCondition() == U; + } + + bool checkOuterInductionPhiUsers(SmallPtrSet &ValidOuterPHIUses) { + for (User *U : OuterInductionPHI->users()) { + if (isOuterLoopIncrement(U)) + continue; + + auto IsValidOuterPHIUses = [&] (User *U) -> bool { + LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump()); + if (!ValidOuterPHIUses.count(U)) { + LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); + return false; + } + LLVM_DEBUG(dbgs() << "Use is optimisable\n"); + return true; + }; + + if (auto *V = dyn_cast(U)) { + for (auto *K : V->users()) { + if (!IsValidOuterPHIUses(K)) + return false; + } + continue; + } + + if (!IsValidOuterPHIUses(U)) + return false; + } + return true; + } + + bool matchLinearIVUser(User *U, Value *InnerTripCount, + SmallPtrSet &ValidOuterPHIUses) { + LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump()); + Value *MatchedMul = nullptr; + Value *MatchedItCount = nullptr; + + bool IsAdd = match(U, m_c_Add(m_Specific(InnerInductionPHI), + m_Value(MatchedMul))) && + match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI), + m_Value(MatchedItCount))); + + // Matches the same pattern as above, except it also looks for truncs + // on the phi, which can be the result of widening the induction variables. + bool IsAddTrunc = + match(U, m_c_Add(m_Trunc(m_Specific(InnerInductionPHI)), + m_Value(MatchedMul))) && + match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)), + m_Value(MatchedItCount))); + + if (!MatchedItCount) + return false; + + // Look through extends if the IV has been widened. + if (Widened && + (isa(MatchedItCount) || isa(MatchedItCount))) { + assert(MatchedItCount->getType() == InnerInductionPHI->getType() && + "Unexpected type mismatch in types after widening"); + MatchedItCount = isa(MatchedItCount) + ? dyn_cast(MatchedItCount)->getOperand(0) + : dyn_cast(MatchedItCount)->getOperand(0); + } + + if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) { + LLVM_DEBUG(dbgs() << "Use is optimisable\n"); + ValidOuterPHIUses.insert(MatchedMul); + LinearIVUses.insert(U); + return true; + } + + LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); + return false; + } + + bool checkInnerInductionPhiUsers(SmallPtrSet &ValidOuterPHIUses) { + Value *SExtInnerTripCount = InnerTripCount; + if (Widened && + (isa(InnerTripCount) || isa(InnerTripCount))) + SExtInnerTripCount = cast(InnerTripCount)->getOperand(0); + + for (User *U : InnerInductionPHI->users()) { + if (isInnerLoopIncrement(U)) + continue; + + // After widening the IVs, a trunc instruction might have been introduced, + // so look through truncs. + if (isa(U)) { + if (!U->hasOneUse()) + return false; + U = *U->user_begin(); + } + + // If the use is in the compare (which is also the condition of the inner + // branch) then the compare has been altered by another transformation e.g + // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is + // a constant. Ignore this use as the compare gets removed later anyway. + if (isInnerLoopTest(U)) + continue; + + if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) + return false; + } + return true; + } }; static bool @@ -162,6 +274,77 @@ setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment, return true; } +// Given the RHS of the loop latch compare instruction, verify with SCEV +// that this is indeed the loop tripcount. +// TODO: This used to be a straightforward check but has grown to be quite +// complicated now. It is therefore worth revisiting what the additional +// benefits are of this (compared to relying on canonical loops and pattern +// matching). +static bool verifyTripCount(Value *RHS, Loop *L, + SmallPtrSetImpl &IterationInstructions, + PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, + BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) { + const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); + if (isa(BackedgeTakenCount)) { + LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n"); + return false; + } + + // The Extend=false flag is used for getTripCountFromExitCount as we want + // to verify and match it with the pattern matched tripcount. Please note + // that overflow checks are performed in checkOverflow, but are first tried + // to avoid by widening the IV. + const SCEV *SCEVTripCount = + SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false); + + const SCEV *SCEVRHS = SE->getSCEV(RHS); + if (SCEVRHS == SCEVTripCount) + return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); + ConstantInt *ConstantRHS = dyn_cast(RHS); + if (ConstantRHS) { + const SCEV *BackedgeTCExt = nullptr; + if (IsWidened) { + const SCEV *SCEVTripCountExt; + // Find the extended backedge taken count and extended trip count using + // SCEV. One of these should now match the RHS of the compare. + BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType()); + SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false); + if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } + } + // If the RHS of the compare is equal to the backedge taken count we need + // to add one to get the trip count. + if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) { + ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1); + Value *NewRHS = ConstantInt::get( + ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue()); + return setLoopComponents(NewRHS, TripCount, Increment, + IterationInstructions); + } + return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); + } + // If the RHS isn't a constant then check that the reason it doesn't match + // the SCEV trip count is because the RHS is a ZExt or SExt instruction + // (and take the trip count to be the RHS). + if (!IsWidened) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } + auto *TripCountInst = dyn_cast(RHS); + if (!TripCountInst) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } + if ((!isa(TripCountInst) && !isa(TripCountInst)) || + SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) { + LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); + return false; + } + return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); +} + // Finds the induction variable, increment and trip count for a simple loop that // we can flatten. static bool findLoopComponents( @@ -238,63 +421,9 @@ static bool findLoopComponents( // another transformation has changed the compare (e.g. icmp ult %inc, // tripcount -> icmp ult %j, tripcount-1), or both. Value *RHS = Compare->getOperand(1); - const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); - if (isa(BackedgeTakenCount)) { - LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n"); - return false; - } - // The use of the Extend=false flag on getTripCountFromExitCount was added - // during a refactoring to preserve existing behavior. However, there's - // nothing obvious in the surrounding code when handles the overflow case. - // FIXME: audit code to establish whether there's a latent bug here. - const SCEV *SCEVTripCount = - SE->getTripCountFromExitCount(BackedgeTakenCount, false); - const SCEV *SCEVRHS = SE->getSCEV(RHS); - if (SCEVRHS == SCEVTripCount) - return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); - ConstantInt *ConstantRHS = dyn_cast(RHS); - if (ConstantRHS) { - const SCEV *BackedgeTCExt = nullptr; - if (IsWidened) { - const SCEV *SCEVTripCountExt; - // Find the extended backedge taken count and extended trip count using - // SCEV. One of these should now match the RHS of the compare. - BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType()); - SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false); - if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) { - LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); - return false; - } - } - // If the RHS of the compare is equal to the backedge taken count we need - // to add one to get the trip count. - if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) { - ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1); - Value *NewRHS = ConstantInt::get( - ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue()); - return setLoopComponents(NewRHS, TripCount, Increment, - IterationInstructions); - } - return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); - } - // If the RHS isn't a constant then check that the reason it doesn't match - // the SCEV trip count is because the RHS is a ZExt or SExt instruction - // (and take the trip count to be the RHS). - if (!IsWidened) { - LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); - return false; - } - auto *TripCountInst = dyn_cast(RHS); - if (!TripCountInst) { - LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); - return false; - } - if ((!isa(TripCountInst) && !isa(TripCountInst)) || - SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) { - LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); - return false; - } - return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); + + return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount, + Increment, BackBranch, SE, IsWidened); } static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) { @@ -440,108 +569,26 @@ checkOuterLoopInsts(FlattenInfo &FI, return true; } -static bool checkIVUsers(FlattenInfo &FI) { - // We require all uses of both induction variables to match this pattern: - // - // (OuterPHI * InnerTripCount) + InnerPHI - // - // Any uses of the induction variables not matching that pattern would - // require a div/mod to reconstruct in the flattened loop, so the - // transformation wouldn't be profitable. - - Value *InnerTripCount = FI.InnerTripCount; - if (FI.Widened && - (isa(InnerTripCount) || isa(InnerTripCount))) - InnerTripCount = cast(InnerTripCount)->getOperand(0); + +// We require all uses of both induction variables to match this pattern: +// +// (OuterPHI * InnerTripCount) + InnerPHI +// +// Any uses of the induction variables not matching that pattern would +// require a div/mod to reconstruct in the flattened loop, so the +// transformation wouldn't be profitable. +static bool checkIVUsers(FlattenInfo &FI) { // Check that all uses of the inner loop's induction variable match the // expected pattern, recording the uses of the outer IV. SmallPtrSet ValidOuterPHIUses; - for (User *U : FI.InnerInductionPHI->users()) { - if (U == FI.InnerIncrement) - continue; - - // After widening the IVs, a trunc instruction might have been introduced, - // so look through truncs. - if (isa(U)) { - if (!U->hasOneUse()) - return false; - U = *U->user_begin(); - } - - // If the use is in the compare (which is also the condition of the inner - // branch) then the compare has been altered by another transformation e.g - // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is - // a constant. Ignore this use as the compare gets removed later anyway. - if (U == FI.InnerBranch->getCondition()) - continue; - - LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump()); - - Value *MatchedMul = nullptr; - Value *MatchedItCount = nullptr; - bool IsAdd = match(U, m_c_Add(m_Specific(FI.InnerInductionPHI), - m_Value(MatchedMul))) && - match(MatchedMul, m_c_Mul(m_Specific(FI.OuterInductionPHI), - m_Value(MatchedItCount))); - - // Matches the same pattern as above, except it also looks for truncs - // on the phi, which can be the result of widening the induction variables. - bool IsAddTrunc = - match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)), - m_Value(MatchedMul))) && - match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)), - m_Value(MatchedItCount))); - - if (!MatchedItCount) - return false; - // Look through extends if the IV has been widened. - if (FI.Widened && - (isa(MatchedItCount) || isa(MatchedItCount))) { - assert(MatchedItCount->getType() == FI.InnerInductionPHI->getType() && - "Unexpected type mismatch in types after widening"); - MatchedItCount = isa(MatchedItCount) - ? dyn_cast(MatchedItCount)->getOperand(0) - : dyn_cast(MatchedItCount)->getOperand(0); - } - - if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) { - LLVM_DEBUG(dbgs() << "Use is optimisable\n"); - ValidOuterPHIUses.insert(MatchedMul); - FI.LinearIVUses.insert(U); - } else { - LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); - return false; - } - } + if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses)) + return false; // Check that there are no uses of the outer IV other than the ones found // as part of the pattern above. - for (User *U : FI.OuterInductionPHI->users()) { - if (U == FI.OuterIncrement) - continue; - - auto IsValidOuterPHIUses = [&] (User *U) -> bool { - LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump()); - if (!ValidOuterPHIUses.count(U)) { - LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); - return false; - } - LLVM_DEBUG(dbgs() << "Use is optimisable\n"); - return true; - }; - - if (auto *V = dyn_cast(U)) { - for (auto *K : V->users()) { - if (!IsValidOuterPHIUses(K)) - return false; - } - continue; - } - - if (!IsValidOuterPHIUses(U)) - return false; - } + if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses)) + return false; LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n"; dbgs() << "Found " << FI.LinearIVUses.size()