Skip to content

Commit

Permalink
[LoopFlatten] Update MemorySSA state
Browse files Browse the repository at this point in the history
I would like to move LoopFlatten from LoopPass Manager LPM2 to LPM1 (D116612),
but that is a LPM that is using MemorySSA and so LoopFlatten needs to preserve
MemorySSA and this adds that. More specifically, LoopFlatten restructures the
CFG and with this change the MSSA state is updated accordingly, where we also
update the DomTree. LoopFlatten doesn't rewrite/optimise/delete load or store
instructions, so I have not added any MSSA updates for that.

Differential Revision: https://reviews.llvm.org/D116660
  • Loading branch information
Sjoerd Meijer committed Jan 19, 2022
1 parent 93e8cd2 commit d544a89
Showing 1 changed file with 55 additions and 23 deletions.
78 changes: 55 additions & 23 deletions llvm/lib/Transforms/Scalar/LoopFlatten.cpp
Expand Up @@ -31,6 +31,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetTransformInfo.h"
Expand Down Expand Up @@ -70,8 +71,7 @@ static cl::opt<bool>
"trip counts will never overflow"));

static cl::opt<bool>
WidenIV("loop-flatten-widen-iv", cl::Hidden,
cl::init(true),
WidenIV("loop-flatten-widen-iv", cl::Hidden, cl::init(true),
cl::desc("Widen the loop induction variables, if possible, so "
"overflow checks won't reject flattening"));

Expand Down Expand Up @@ -100,7 +100,7 @@ struct FlattenInfo {
PHINode *NarrowInnerInductionPHI = nullptr;
PHINode *NarrowOuterInductionPHI = nullptr;

FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){};

bool isNarrowInductionPhi(PHINode *Phi) {
// This can't be the narrow phi if we haven't widened the IV first.
Expand Down Expand Up @@ -207,7 +207,7 @@ static bool findLoopComponents(
// nothing obvious in the surrounding code when handles the overflow case.
// FIXME: audit code to establish whether there's a latent bug here.
const SCEV *SCEVTripCount =
SE->getTripCountFromExitCount(BackedgeTakenCount, false);
SE->getTripCountFromExitCount(BackedgeTakenCount, false);
const SCEV *SCEVRHS = SE->getSCEV(RHS);
if (SCEVRHS == SCEVTripCount)
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
Expand Down Expand Up @@ -611,7 +611,8 @@ static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,

static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
ScalarEvolution *SE, AssumptionCache *AC,
const TargetTransformInfo *TTI, LPMUpdater *U) {
const TargetTransformInfo *TTI, LPMUpdater *U,
MemorySSAUpdater *MSSAU) {
Function *F = FI.OuterLoop->getHeader()->getParent();
LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");
{
Expand Down Expand Up @@ -647,7 +648,11 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock();
InnerExitingBlock->getTerminator()->eraseFromParent();
BranchInst::Create(InnerExitBlock, InnerExitingBlock);

// Update the DomTree and MemorySSA.
DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
if (MSSAU)
MSSAU->removeEdge(InnerExitingBlock, FI.InnerLoop->getHeader());

// Replace all uses of the polynomial calculated from the two induction
// variables with the one new one.
Expand All @@ -658,8 +663,8 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
"flatten.trunciv");

LLVM_DEBUG(dbgs() << "Replacing: "; V->dump();
dbgs() << "with: "; OuterValue->dump());
LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with: ";
OuterValue->dump());
V->replaceAllUsesWith(OuterValue);
}

Expand Down Expand Up @@ -698,7 +703,8 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
// (OuterTripCount * InnerTripCount) as the new trip count is safe.
if (InnerType != OuterType ||
InnerType->getScalarSizeInBits() >= MaxLegalSize ||
MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) {
MaxLegalType->getScalarSizeInBits() <
InnerType->getScalarSizeInBits() * 2) {
LLVM_DEBUG(dbgs() << "Can't widen the IV\n");
return false;
}
Expand All @@ -708,10 +714,10 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
unsigned ElimExt = 0;
unsigned Widened = 0;

auto CreateWideIV = [&] (WideIVInfo WideIV, bool &Deleted) -> bool {
PHINode *WidePhi = createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts,
ElimExt, Widened, true /* HasGuards */,
true /* UsePostIncrementRanges */);
auto CreateWideIV = [&](WideIVInfo WideIV, bool &Deleted) -> bool {
PHINode *WidePhi =
createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened,
true /* HasGuards */, true /* UsePostIncrementRanges */);
if (!WidePhi)
return false;
LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump());
Expand All @@ -721,14 +727,14 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
};

bool Deleted;
if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType, false }, Deleted))
if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType, false}, Deleted))
return false;
// Add the narrow phi to list, so that it will be adjusted later when the
// the transformation is performed.
if (!Deleted)
FI.InnerPHIsToTransform.insert(FI.InnerInductionPHI);

if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType, false }, Deleted))
if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType, false}, Deleted))
return false;

assert(Widened && "Widened IV expected");
Expand All @@ -744,7 +750,8 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,

static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
ScalarEvolution *SE, AssumptionCache *AC,
const TargetTransformInfo *TTI, LPMUpdater *U) {
const TargetTransformInfo *TTI, LPMUpdater *U,
MemorySSAUpdater *MSSAU) {
LLVM_DEBUG(
dbgs() << "Loop flattening running on outer loop "
<< FI.OuterLoop->getHeader()->getName() << " and inner loop "
Expand Down Expand Up @@ -773,7 +780,7 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,

// If we have widened and can perform the transformation, do that here.
if (CanFlatten)
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U);
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);

// Otherwise, if we haven't widened the IV, check if the new iteration
// variable might overflow. In this case, we need to version the loop, and
Expand All @@ -791,18 +798,19 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
}

LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U);
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
}

bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U) {
AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U,
MemorySSAUpdater *MSSAU) {
bool Changed = false;
for (Loop *InnerLoop : LN.getLoops()) {
auto *OuterLoop = InnerLoop->getParentLoop();
if (!OuterLoop)
continue;
FlattenInfo FI(OuterLoop, InnerLoop);
Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U);
Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
}
return Changed;
}
Expand All @@ -813,16 +821,30 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,

bool Changed = false;

Optional<MemorySSAUpdater> MSSAU;
if (AR.MSSA) {
MSSAU = MemorySSAUpdater(AR.MSSA);
if (VerifyMemorySSA)
AR.MSSA->verifyMemorySSA();
}

// The loop flattening pass requires loops to be
// in simplified form, and also needs LCSSA. Running
// this pass will simplify all loops that contain inner loops,
// regardless of whether anything ends up being flattened.
Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U);
Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
MSSAU.hasValue() ? MSSAU.getPointer() : nullptr);

if (!Changed)
return PreservedAnalyses::all();

return getLoopPassPreservedAnalyses();
if (AR.MSSA && VerifyMemorySSA)
AR.MSSA->verifyMemorySSA();

auto PA = getLoopPassPreservedAnalyses();
if (AR.MSSA)
PA.preserve<MemorySSAAnalysis>();
return PA;
}

namespace {
Expand All @@ -842,6 +864,7 @@ class LoopFlattenLegacyPass : public FunctionPass {
AU.addPreserved<TargetTransformInfoWrapperPass>();
AU.addRequired<AssumptionCacheTracker>();
AU.addPreserved<AssumptionCacheTracker>();
AU.addPreserved<MemorySSAWrapperPass>();
}
};
} // namespace
Expand All @@ -854,7 +877,9 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
false, false)

FunctionPass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); }
FunctionPass *llvm::createLoopFlattenPass() {
return new LoopFlattenLegacyPass();
}

bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
Expand All @@ -864,10 +889,17 @@ bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
auto *TTI = &TTIP.getTTI(F);
auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
auto *MSSA = getAnalysisIfAvailable<MemorySSAWrapperPass>();

Optional<MemorySSAUpdater> MSSAU;
if (MSSA)
MSSAU = MemorySSAUpdater(&MSSA->getMSSA());

bool Changed = false;
for (Loop *L : *LI) {
auto LN = LoopNest::getLoopNest(*L, *SE);
Changed |= Flatten(*LN, DT, LI, SE, AC, TTI, nullptr);
Changed |= Flatten(*LN, DT, LI, SE, AC, TTI, nullptr,
MSSAU.hasValue() ? MSSAU.getPointer() : nullptr);
}
return Changed;
}

0 comments on commit d544a89

Please sign in to comment.