Skip to content

Commit

Permalink
[LoopReroll] Rewrite induction variable rewriting.
Browse files Browse the repository at this point in the history
This gets rid of a bunch of weird special cases; instead, just use SCEV
rewriting for everything.  In addition to being simpler, this fixes a
bug where we would use the wrong stride in certain edge cases.

The one bit I'm not quite sure about is the trip count handling,
specifically the FIXME about overflow.  In general, I think we need to
widen the exit condition, but that's probably not profitable if the new
type isn't legal, so we probably need a check somewhere.  That said, I
don't think I'm making the existing problem any worse.

As a followup to this, a bunch of IV-related code in root-finding could
be cleaned up; with SCEV-based rewriting, there isn't any reason to
assume a loop will have exactly one or two PHI nodes.

Differential Revision: https://reviews.llvm.org/D45191

llvm-svn: 335400
  • Loading branch information
Eli Friedman committed Jun 22, 2018
1 parent 2cbf973 commit 203eaaf
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 253 deletions.
236 changes: 59 additions & 177 deletions llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ using namespace llvm;

STATISTIC(NumRerolledLoops, "Number of rerolled loops");

static cl::opt<unsigned>
MaxInc("max-reroll-increment", cl::init(2048), cl::Hidden,
cl::desc("The maximum increment for loop rerolling"));

static cl::opt<unsigned>
NumToleratedFailedMatches("reroll-num-tolerated-failed-matches", cl::init(400),
cl::Hidden,
Expand Down Expand Up @@ -398,8 +394,8 @@ namespace {

/// Stage 3: Assuming validate() returned true, perform the
/// replacement.
/// @param IterCount The maximum iteration count of L.
void replace(const SCEV *IterCount);
/// @param BackedgeTakenCount The backedge-taken count of L.
void replace(const SCEV *BackedgeTakenCount);

protected:
using UsesTy = MapVector<Instruction *, BitVector>;
Expand Down Expand Up @@ -429,8 +425,7 @@ namespace {
bool instrDependsOn(Instruction *I,
UsesTy::iterator Start,
UsesTy::iterator End);
void replaceIV(Instruction *Inst, Instruction *IV, const SCEV *IterCount);
void updateNonLoopCtrlIncr();
void replaceIV(DAGRootSet &DRS, const SCEV *Start, const SCEV *IncrExpr);

LoopReroll *Parent;

Expand Down Expand Up @@ -483,8 +478,8 @@ namespace {
void collectPossibleIVs(Loop *L, SmallInstructionVector &PossibleIVs);
void collectPossibleReductions(Loop *L,
ReductionTracker &Reductions);
bool reroll(Instruction *IV, Loop *L, BasicBlock *Header, const SCEV *IterCount,
ReductionTracker &Reductions);
bool reroll(Instruction *IV, Loop *L, BasicBlock *Header,
const SCEV *BackedgeTakenCount, ReductionTracker &Reductions);
};

} // end anonymous namespace
Expand All @@ -511,48 +506,6 @@ static bool hasUsesOutsideLoop(Instruction *I, Loop *L) {
return false;
}

static const SCEVConstant *getIncrmentFactorSCEV(ScalarEvolution *SE,
const SCEV *SCEVExpr,
Instruction &IV) {
const SCEVMulExpr *MulSCEV = dyn_cast<SCEVMulExpr>(SCEVExpr);

// If StepRecurrence of a SCEVExpr is a constant (c1 * c2, c2 = sizeof(ptr)),
// Return c1.
if (!MulSCEV && IV.getType()->isPointerTy())
if (const SCEVConstant *IncSCEV = dyn_cast<SCEVConstant>(SCEVExpr)) {
const PointerType *PTy = cast<PointerType>(IV.getType());
Type *ElTy = PTy->getElementType();
const SCEV *SizeOfExpr =
SE->getSizeOfExpr(SE->getEffectiveSCEVType(IV.getType()), ElTy);
if (IncSCEV->getValue()->getValue().isNegative()) {
const SCEV *NewSCEV =
SE->getUDivExpr(SE->getNegativeSCEV(SCEVExpr), SizeOfExpr);
return dyn_cast<SCEVConstant>(SE->getNegativeSCEV(NewSCEV));
} else {
return dyn_cast<SCEVConstant>(SE->getUDivExpr(SCEVExpr, SizeOfExpr));
}
}

if (!MulSCEV)
return nullptr;

// If StepRecurrence of a SCEVExpr is a c * sizeof(x), where c is constant,
// Return c.
const SCEVConstant *CIncSCEV = nullptr;
for (const SCEV *Operand : MulSCEV->operands()) {
if (const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Operand)) {
CIncSCEV = Constant;
} else if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Operand)) {
Type *AllocTy;
if (!Unknown->isSizeOf(AllocTy))
break;
} else {
return nullptr;
}
}
return CIncSCEV;
}

// Check if an IV is only used to control the loop. There are two cases:
// 1. It only has one use which is loop increment, and the increment is only
// used by comparison and the PHI (could has sext with nsw in between), and the
Expand Down Expand Up @@ -633,16 +586,8 @@ void LoopReroll::collectPossibleIVs(Loop *L,
continue;
if (!PHISCEV->isAffine())
continue;
const SCEVConstant *IncSCEV = nullptr;
if (I->getType()->isPointerTy())
IncSCEV =
getIncrmentFactorSCEV(SE, PHISCEV->getStepRecurrence(*SE), *I);
else
IncSCEV = dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE));
auto IncSCEV = dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE));
if (IncSCEV) {
const APInt &AInt = IncSCEV->getValue()->getValue().abs();
if (IncSCEV->getValue()->isZero() || AInt.uge(MaxInc))
continue;
IVToIncMap[&*I] = IncSCEV->getValue()->getSExtValue();
LLVM_DEBUG(dbgs() << "LRR: Possible IV: " << *I << " = " << *PHISCEV
<< "\n");
Expand Down Expand Up @@ -1463,8 +1408,20 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) {
return true;
}

void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) {
void LoopReroll::DAGRootTracker::replace(const SCEV *BackedgeTakenCount) {
BasicBlock *Header = L->getHeader();

// Compute the start and increment for each BaseInst before we start erasing
// instructions.
SmallVector<const SCEV *, 8> StartExprs;
SmallVector<const SCEV *, 8> IncrExprs;
for (auto &DRS : RootSets) {
const SCEVAddRecExpr *IVSCEV =
cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst));
StartExprs.push_back(IVSCEV->getStart());
IncrExprs.push_back(SE->getMinusSCEV(SE->getSCEV(DRS.Roots[0]), IVSCEV));
}

// Remove instructions associated with non-base iterations.
for (BasicBlock::reverse_iterator J = Header->rbegin(), JE = Header->rend();
J != JE;) {
Expand All @@ -1478,74 +1435,47 @@ void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) {
++J;
}

bool HasTwoIVs = LoopControlIV && LoopControlIV != IV;
// Rewrite each BaseInst using SCEV.
for (size_t i = 0, e = RootSets.size(); i != e; ++i)
// Insert the new induction variable.
replaceIV(RootSets[i], StartExprs[i], IncrExprs[i]);

if (HasTwoIVs) {
updateNonLoopCtrlIncr();
replaceIV(LoopControlIV, LoopControlIV, IterCount);
} else
// We need to create a new induction variable for each different BaseInst.
for (auto &DRS : RootSets)
// Insert the new induction variable.
replaceIV(DRS.BaseInst, IV, IterCount);
{ // Limit the lifetime of SCEVExpander.
BranchInst *BI = cast<BranchInst>(Header->getTerminator());
const DataLayout &DL = Header->getModule()->getDataLayout();
SCEVExpander Expander(*SE, DL, "reroll");
auto Zero = SE->getZero(BackedgeTakenCount->getType());
auto One = SE->getOne(BackedgeTakenCount->getType());
auto NewIVSCEV = SE->getAddRecExpr(Zero, One, L, SCEV::FlagAnyWrap);
Value *NewIV =
Expander.expandCodeFor(NewIVSCEV, BackedgeTakenCount->getType(),
Header->getFirstNonPHIOrDbg());
// FIXME: This arithmetic can overflow.
auto TripCount = SE->getAddExpr(BackedgeTakenCount, One);
auto ScaledTripCount = SE->getMulExpr(
TripCount, SE->getConstant(BackedgeTakenCount->getType(), Scale));
auto ScaledBECount = SE->getMinusSCEV(ScaledTripCount, One);
Value *TakenCount =
Expander.expandCodeFor(ScaledBECount, BackedgeTakenCount->getType(),
Header->getFirstNonPHIOrDbg());
Value *Cond =
new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, TakenCount, "exitcond");
BI->setCondition(Cond);

if (BI->getSuccessor(1) != Header)
BI->swapSuccessors();
}

SimplifyInstructionsInBlock(Header, TLI);
DeleteDeadPHIs(Header, TLI);
}

// For non-loop-control IVs, we only need to update the last increment
// with right amount, then we are done.
void LoopReroll::DAGRootTracker::updateNonLoopCtrlIncr() {
const SCEV *NewInc = nullptr;
for (auto *LoopInc : LoopIncs) {
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LoopInc);
const SCEVConstant *COp = nullptr;
if (GEP && LoopInc->getOperand(0)->getType()->isPointerTy()) {
COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(1)));
} else {
COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(0)));
if (!COp)
COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(1)));
}

assert(COp && "Didn't find constant operand of LoopInc!\n");

const APInt &AInt = COp->getValue()->getValue();
const SCEV *ScaleSCEV = SE->getConstant(COp->getType(), Scale);
if (AInt.isNegative()) {
NewInc = SE->getNegativeSCEV(COp);
NewInc = SE->getUDivExpr(NewInc, ScaleSCEV);
NewInc = SE->getNegativeSCEV(NewInc);
} else
NewInc = SE->getUDivExpr(COp, ScaleSCEV);

LoopInc->setOperand(1, dyn_cast<SCEVConstant>(NewInc)->getValue());
}
}

void LoopReroll::DAGRootTracker::replaceIV(Instruction *Inst,
Instruction *InstIV,
const SCEV *IterCount) {
void LoopReroll::DAGRootTracker::replaceIV(DAGRootSet &DRS,
const SCEV *Start,
const SCEV *IncrExpr) {
BasicBlock *Header = L->getHeader();
int64_t Inc = IVToIncMap[InstIV];
bool NeedNewIV = InstIV == LoopControlIV;
bool Negative = !NeedNewIV && Inc < 0;

const SCEVAddRecExpr *RealIVSCEV = cast<SCEVAddRecExpr>(SE->getSCEV(Inst));
const SCEV *Start = RealIVSCEV->getStart();

if (NeedNewIV)
Start = SE->getConstant(Start->getType(), 0);

const SCEV *SizeOfExpr = nullptr;
const SCEV *IncrExpr =
SE->getConstant(RealIVSCEV->getType(), Negative ? -1 : 1);
if (auto *PTy = dyn_cast<PointerType>(Inst->getType())) {
Type *ElTy = PTy->getElementType();
SizeOfExpr =
SE->getSizeOfExpr(SE->getEffectiveSCEVType(Inst->getType()), ElTy);
IncrExpr = SE->getMulExpr(IncrExpr, SizeOfExpr);
}
Instruction *Inst = DRS.BaseInst;

const SCEV *NewIVSCEV =
SE->getAddRecExpr(Start, IncrExpr, L, SCEV::FlagAnyWrap);

Expand All @@ -1558,54 +1488,6 @@ void LoopReroll::DAGRootTracker::replaceIV(Instruction *Inst,
for (auto &KV : Uses)
if (KV.second.find_first() == 0)
KV.first->replaceUsesOfWith(Inst, NewIV);

if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) {
// FIXME: Why do we need this check?
if (Uses[BI].find_first() == IL_All) {
const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE);

if (NeedNewIV)
ICSCEV = SE->getMulExpr(IterCount,
SE->getConstant(IterCount->getType(), Scale));

// Iteration count SCEV minus or plus 1
const SCEV *MinusPlus1SCEV =
SE->getConstant(ICSCEV->getType(), Negative ? -1 : 1);
if (Inst->getType()->isPointerTy()) {
assert(SizeOfExpr && "SizeOfExpr is not initialized");
MinusPlus1SCEV = SE->getMulExpr(MinusPlus1SCEV, SizeOfExpr);
}

const SCEV *ICMinusPlus1SCEV = SE->getMinusSCEV(ICSCEV, MinusPlus1SCEV);
// Iteration count minus 1
Instruction *InsertPtr = nullptr;
if (isa<SCEVConstant>(ICMinusPlus1SCEV)) {
InsertPtr = BI;
} else {
BasicBlock *Preheader = L->getLoopPreheader();
if (!Preheader)
Preheader = InsertPreheaderForLoop(L, DT, LI, PreserveLCSSA);
InsertPtr = Preheader->getTerminator();
}

if (!isa<PointerType>(NewIV->getType()) && NeedNewIV &&
(SE->getTypeSizeInBits(NewIV->getType()) <
SE->getTypeSizeInBits(ICMinusPlus1SCEV->getType()))) {
IRBuilder<> Builder(BI);
Builder.SetCurrentDebugLocation(BI->getDebugLoc());
NewIV = Builder.CreateSExt(NewIV, ICMinusPlus1SCEV->getType());
}
Value *ICMinusPlus1 = Expander.expandCodeFor(
ICMinusPlus1SCEV, NewIV->getType(), InsertPtr);

Value *Cond =
new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, ICMinusPlus1, "exitcond");
BI->setCondition(Cond);

if (BI->getSuccessor(1) != Header)
BI->swapSuccessors();
}
}
}
}

Expand Down Expand Up @@ -1722,7 +1604,7 @@ void LoopReroll::ReductionTracker::replaceSelected() {
// f(%iv) or part of some f(%iv.i). If all of that is true (and all reductions
// have been validated), then we reroll the loop.
bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header,
const SCEV *IterCount,
const SCEV *BackedgeTakenCount,
ReductionTracker &Reductions) {
DAGRootTracker DAGRoots(this, L, IV, SE, AA, TLI, DT, LI, PreserveLCSSA,
IVToIncMap, LoopControlIV);
Expand All @@ -1740,7 +1622,7 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header,
// making changes!

Reductions.replaceSelected();
DAGRoots.replace(IterCount);
DAGRoots.replace(BackedgeTakenCount);

++NumRerolledLoops;
return true;
Expand Down Expand Up @@ -1769,10 +1651,10 @@ bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) {
if (!SE->hasLoopInvariantBackedgeTakenCount(L))
return false;

const SCEV *LIBETC = SE->getBackedgeTakenCount(L);
const SCEV *IterCount = SE->getAddExpr(LIBETC, SE->getOne(LIBETC->getType()));
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
LLVM_DEBUG(dbgs() << "\n Before Reroll:\n" << *(L->getHeader()) << "\n");
LLVM_DEBUG(dbgs() << "LRR: iteration count = " << *IterCount << "\n");
LLVM_DEBUG(dbgs() << "LRR: backedge-taken count = " << *BackedgeTakenCount
<< "\n");

// First, we need to find the induction variable with respect to which we can
// reroll (there may be several possible options).
Expand All @@ -1793,7 +1675,7 @@ bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) {
// For each possible IV, collect the associated possible set of 'root' nodes
// (i+1, i+2, etc.).
for (Instruction *PossibleIV : PossibleIVs)
if (reroll(PossibleIV, L, Header, IterCount, Reductions)) {
if (reroll(PossibleIV, L, Header, BackedgeTakenCount, Reductions)) {
Changed = true;
break;
}
Expand Down
Loading

0 comments on commit 203eaaf

Please sign in to comment.