Skip to content

Commit

Permalink
[LoopReroll] Make root-finding more aggressive.
Browse files Browse the repository at this point in the history
Allow using an instruction other than a mul or phi as the base for
root-finding. For example, the included testcase includes a loop
which requires using a getelementptr as the base for root-finding.

Differential Revision: https://reviews.llvm.org/D26529

llvm-svn: 287588
  • Loading branch information
Eli Friedman committed Nov 21, 2016
1 parent 6cad011 commit c0bba1a
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 50 deletions.
108 changes: 58 additions & 50 deletions llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,12 @@ namespace {
protected:
typedef MapVector<Instruction*, BitVector> UsesTy;

bool findRootsRecursive(Instruction *IVU,
void findRootsRecursive(Instruction *IVU,
SmallInstructionSet SubsumedInsts);
bool findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts);
bool collectPossibleRoots(Instruction *Base,
std::map<int64_t,Instruction*> &Roots);
bool validateRootSet(DAGRootSet &DRS);

bool collectUsedInstructions(SmallInstructionSet &PossibleRedSet);
void collectInLoopUserSet(const SmallInstructionVector &Roots,
Expand Down Expand Up @@ -827,7 +828,8 @@ collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) {
Roots[V] = cast<Instruction>(I);
}

if (Roots.empty())
// Make sure we have at least two roots.
if (Roots.empty() || (Roots.size() == 1 && BaseUsers.empty()))
return false;

// If we found non-loop-inc, non-root users of Base, assume they are
Expand Down Expand Up @@ -861,40 +863,61 @@ collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) {
return true;
}

bool LoopReroll::DAGRootTracker::
void LoopReroll::DAGRootTracker::
findRootsRecursive(Instruction *I, SmallInstructionSet SubsumedInsts) {
// Does the user look like it could be part of a root set?
// All its users must be simple arithmetic ops.
if (I->getNumUses() > IL_MaxRerollIterations)
return false;
return;

if ((I->getOpcode() == Instruction::Mul ||
I->getOpcode() == Instruction::PHI) &&
I != IV &&
findRootsBase(I, SubsumedInsts))
return true;
if (I != IV && findRootsBase(I, SubsumedInsts))
return;

SubsumedInsts.insert(I);

for (User *V : I->users()) {
Instruction *I = dyn_cast<Instruction>(V);
Instruction *I = cast<Instruction>(V);
if (is_contained(LoopIncs, I))
continue;

if (!I || !isSimpleArithmeticOp(I) ||
!findRootsRecursive(I, SubsumedInsts))
return false;
if (!isSimpleArithmeticOp(I))
continue;

// The recursive call makes a copy of SubsumedInsts.
findRootsRecursive(I, SubsumedInsts);
}
}

bool LoopReroll::DAGRootTracker::validateRootSet(DAGRootSet &DRS) {
if (DRS.Roots.empty())
return false;

// Consider a DAGRootSet with N-1 roots (so N different values including
// BaseInst).
// Define d = Roots[0] - BaseInst, which should be the same as
// Roots[I] - Roots[I-1] for all I in [1..N).
// Define D = BaseInst@J - BaseInst@J-1, where "@J" means the value at the
// loop iteration J.
//
// Now, For the loop iterations to be consecutive:
// D = d * N
const auto *ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst));
if (!ADR)
return false;
unsigned N = DRS.Roots.size() + 1;
const SCEV *StepSCEV = SE->getMinusSCEV(SE->getSCEV(DRS.Roots[0]), ADR);
const SCEV *ScaleSCEV = SE->getConstant(StepSCEV->getType(), N);
if (ADR->getStepRecurrence(*SE) != SE->getMulExpr(StepSCEV, ScaleSCEV))
return false;

return true;
}

bool LoopReroll::DAGRootTracker::
findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts) {

// The base instruction needs to be a multiply so
// that we can erase it.
if (IVU->getOpcode() != Instruction::Mul &&
IVU->getOpcode() != Instruction::PHI)
// The base of a RootSet must be an AddRec, so it can be erased.
const auto *IVU_ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(IVU));
if (!IVU_ADR || IVU_ADR->getLoop() != L)
return false;

std::map<int64_t, Instruction*> V;
Expand All @@ -910,6 +933,8 @@ findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts) {
DAGRootSet DRS;
DRS.BaseInst = nullptr;

SmallVector<DAGRootSet, 16> PotentialRootSets;

for (auto &KV : V) {
if (!DRS.BaseInst) {
DRS.BaseInst = KV.second;
Expand All @@ -920,13 +945,22 @@ findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts) {
DRS.Roots.push_back(KV.second);
} else {
// Linear sequence terminated.
RootSets.push_back(DRS);
if (!validateRootSet(DRS))
return false;

// Construct a new DAGRootSet with the next sequence.
PotentialRootSets.push_back(DRS);
DRS.BaseInst = KV.second;
DRS.SubsumedInsts = SubsumedInsts;
DRS.Roots.clear();
}
}
RootSets.push_back(DRS);

if (!validateRootSet(DRS))
return false;

PotentialRootSets.push_back(DRS);

RootSets.append(PotentialRootSets.begin(), PotentialRootSets.end());

return true;
}
Expand All @@ -940,8 +974,7 @@ bool LoopReroll::DAGRootTracker::findRoots() {
if (isLoopIncrement(IVU, IV))
LoopIncs.push_back(cast<Instruction>(IVU));
}
if (!findRootsRecursive(IV, SmallInstructionSet()))
return false;
findRootsRecursive(IV, SmallInstructionSet());
LoopIncs.push_back(IV);
} else {
if (!findRootsBase(IV, SmallInstructionSet()))
Expand All @@ -961,31 +994,6 @@ bool LoopReroll::DAGRootTracker::findRoots() {
}
}

// And ensure all loop iterations are consecutive. We rely on std::map
// providing ordered traversal.
for (auto &V : RootSets) {
const auto *ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(V.BaseInst));
if (!ADR)
return false;

// Consider a DAGRootSet with N-1 roots (so N different values including
// BaseInst).
// Define d = Roots[0] - BaseInst, which should be the same as
// Roots[I] - Roots[I-1] for all I in [1..N).
// Define D = BaseInst@J - BaseInst@J-1, where "@J" means the value at the
// loop iteration J.
//
// Now, For the loop iterations to be consecutive:
// D = d * N

unsigned N = V.Roots.size() + 1;
const SCEV *StepSCEV = SE->getMinusSCEV(SE->getSCEV(V.Roots[0]), ADR);
const SCEV *ScaleSCEV = SE->getConstant(StepSCEV->getType(), N);
if (ADR->getStepRecurrence(*SE) != SE->getMulExpr(StepSCEV, ScaleSCEV)) {
DEBUG(dbgs() << "LRR: Aborting because iterations are not consecutive\n");
return false;
}
}
Scale = RootSets[0].Roots.size() + 1;

if (Scale > IL_MaxRerollIterations) {
Expand Down Expand Up @@ -1498,8 +1506,8 @@ void LoopReroll::DAGRootTracker::replaceIV(Instruction *Inst,
{ // Limit the lifetime of SCEVExpander.
const DataLayout &DL = Header->getModule()->getDataLayout();
SCEVExpander Expander(*SE, DL, "reroll");
Value *NewIV =
Expander.expandCodeFor(NewIVSCEV, InstIV->getType(), &Header->front());
Value *NewIV = Expander.expandCodeFor(NewIVSCEV, Inst->getType(),
Header->getFirstNonPHIOrDbg());

for (auto &KV : Uses)
if (KV.second.find_first() == 0)
Expand Down
31 changes: 31 additions & 0 deletions llvm/test/Transforms/LoopReroll/basic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,37 @@ for.end: ; preds = %for.body
ret void
}

define void @gep-indexing(i32* nocapture %x) {
entry:
%call = tail call i32 @foo(i32 0) #1
br label %for.body

for.body: ; preds = %for.body, %entry
%indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
%0 = mul nsw i64 %indvars.iv, 3
%arrayidx = getelementptr inbounds i32, i32* %x, i64 %0
store i32 %call, i32* %arrayidx, align 4
%arrayidx4 = getelementptr inbounds i32, i32* %arrayidx, i64 1
store i32 %call, i32* %arrayidx4, align 4
%arrayidx9 = getelementptr inbounds i32, i32* %arrayidx, i64 2
store i32 %call, i32* %arrayidx9, align 4
%indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
%exitcond = icmp eq i64 %indvars.iv.next, 500
br i1 %exitcond, label %for.end, label %for.body

; CHECK-LABEL: @gep-indexing
; CHECK: for.body:
; CHECK-NEXT: %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
; CHECK-NEXT: %scevgep = getelementptr i32, i32* %x, i64 %indvars.iv
; CHECK-NEXT: store i32 %call, i32* %scevgep, align 4
; CHECK-NEXT: %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
; CHECK-NEXT: %exitcond2 = icmp eq i32* %scevgep, %scevgep1
; CHECK-NEXT: br i1 %exitcond2, label %for.end, label %for.body

for.end: ; preds = %for.body
ret void
}


define void @unordered_atomic_ops(i32* noalias %buf_0, i32* noalias %buf_1) {
; CHECK-LABEL: @unordered_atomic_ops(
Expand Down

0 comments on commit c0bba1a

Please sign in to comment.