Skip to content

Commit

Permalink
[LoopPred] Stop passing around builders [NFC]
Browse files Browse the repository at this point in the history
This is a preparatory patch for D60093. This patch itself is NFC, but while preparing this I noticed and committed a small hoisting change in rL358419.

The basic structure of the new scheme is that we pass around the guard ("the using instruction"), and select an optimal insert point by examining operands at each construction point. This seems conceptually a bit cleaner to start with as it isolates the knowledge about insertion safety at the actual insertion point.

Note that the non-hoisting path is not actually used at the moment. That's not exercised until D60093 is rebased on this one.

Differential Revision: https://reviews.llvm.org/D60718

llvm-svn: 358434
  • Loading branch information
preames committed Apr 15, 2019
1 parent e1e1bd7 commit e46d77d
Showing 1 changed file with 49 additions and 31 deletions.
80 changes: 49 additions & 31 deletions llvm/lib/Transforms/Scalar/LoopPredication.cpp
Expand Up @@ -269,24 +269,29 @@ class LoopPredication {
/// trivial result would be the at the User itself, but we try to return a
/// loop invariant location if possible.
Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
/// Same as above, *except* that this uses the SCEV definition of invariant
/// which is that an expression *can be made* invariant via SCEVExpander.
/// Thus, this version is only suitable for finding an insert point to be be
/// passed to SCEVExpander!
Instruction *findInsertPt(Instruction *User, ArrayRef<const SCEV*> Ops);

bool CanExpand(const SCEV* S);
Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder,
Value *expandCheck(SCEVExpander &Expander, Instruction *Guard,
ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS);

Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
IRBuilder<> &Builder);
Instruction *Guard);
Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck,
LoopICmp RangeCheck,
SCEVExpander &Expander,
IRBuilder<> &Builder);
Instruction *Guard);
Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck,
LoopICmp RangeCheck,
SCEVExpander &Expander,
IRBuilder<> &Builder);
Instruction *Guard);
unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition,
SCEVExpander &Expander, IRBuilder<> &Builder);
SCEVExpander &Expander, Instruction *Guard);
bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander);
// If the loop always exits through another block in the loop, we should not
Expand Down Expand Up @@ -394,21 +399,24 @@ LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS,
}

Value *LoopPredication::expandCheck(SCEVExpander &Expander,
IRBuilder<> &Builder,
Instruction *Guard,
ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS) {
Type *Ty = LHS->getType();
assert(Ty == RHS->getType() && "expandCheck operands have different types?");

if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
return Builder.getTrue();
if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
LHS, RHS))
return Builder.getFalse();
if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) {
IRBuilder<> Builder(Guard);
if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
return Builder.getTrue();
if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
LHS, RHS))
return Builder.getFalse();
}

Instruction *InsertAt = &*Builder.GetInsertPoint();
Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt);
Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt);
Value *LHSV = Expander.expandCodeFor(LHS, Ty, findInsertPt(Guard, {LHS}));
Value *RHSV = Expander.expandCodeFor(RHS, Ty, findInsertPt(Guard, {RHS}));
IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV}));
return Builder.CreateICmp(Pred, LHSV, RHSV);
}

Expand Down Expand Up @@ -452,13 +460,22 @@ Instruction *LoopPredication::findInsertPt(Instruction *Use,
return Preheader->getTerminator();
}

Instruction *LoopPredication::findInsertPt(Instruction *Use,
ArrayRef<const SCEV*> Ops) {
for (const SCEV *Op : Ops)
if (!SE->isLoopInvariant(Op, L))
return Use;
return Preheader->getTerminator();
}


bool LoopPredication::CanExpand(const SCEV* S) {
return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE);
}

Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
SCEVExpander &Expander, IRBuilder<> &Builder) {
SCEVExpander &Expander, Instruction *Guard) {
auto *Ty = RangeCheck.IV->getType();
// Generate the widened condition for the forward loop:
// guardStart u< guardLimit &&
Expand Down Expand Up @@ -488,15 +505,16 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");

auto *LimitCheck =
expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS);
auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck.Pred,
expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);
auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
GuardStart, GuardLimit);
IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
}

Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
SCEVExpander &Expander, IRBuilder<> &Builder) {
SCEVExpander &Expander, Instruction *Guard) {
auto *Ty = RangeCheck.IV->getType();
const SCEV *GuardStart = RangeCheck.IV->getStart();
const SCEV *GuardLimit = RangeCheck.Limit;
Expand All @@ -522,10 +540,12 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
// See the header comment for reasoning of the checks.
auto LimitCheckPred =
ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);
auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT,
auto *FirstIterationCheck = expandCheck(Expander, Guard,
ICmpInst::ICMP_ULT,
GuardStart, GuardLimit);
auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit,
auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
SE->getOne(Ty));
IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
}

Expand All @@ -534,7 +554,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
/// returns None.
Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
SCEVExpander &Expander,
IRBuilder<> &Builder) {
Instruction *Guard) {
LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
LLVM_DEBUG(ICI->dump());

Expand Down Expand Up @@ -588,18 +608,18 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,

if (Step->isOne())
return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
Expander, Builder);
Expander, Guard);
else {
assert(Step->isAllOnesValue() && "Step should be -1!");
return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
Expander, Builder);
Expander, Guard);
}
}

unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks,
Value *Condition,
SCEVExpander &Expander,
IRBuilder<> &Builder) {
Instruction *Guard) {
unsigned NumWidened = 0;
// The guard condition is expected to be in form of:
// cond1 && cond2 && cond3 ...
Expand Down Expand Up @@ -631,7 +651,7 @@ unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks,

if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander,
Builder)) {
Guard)) {
Checks.push_back(NewRangeCheck.getValue());
NumWidened++;
continue;
Expand All @@ -657,16 +677,15 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,

TotalConsidered++;
SmallVector<Value *, 4> Checks;
IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
unsigned NumWidened = collectChecks(Checks, Guard->getOperand(0), Expander,
Builder);
Guard);
if (NumWidened == 0)
return false;

TotalWidened += NumWidened;

// Emit the new guard condition
Builder.SetInsertPoint(findInsertPt(Guard, Checks));
IRBuilder<> Builder(findInsertPt(Guard, Checks));
Value *LastCheck = nullptr;
for (auto *Check : Checks)
if (!LastCheck)
Expand All @@ -689,16 +708,15 @@ bool LoopPredication::widenWidenableBranchGuardConditions(

TotalConsidered++;
SmallVector<Value *, 4> Checks;
IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
unsigned NumWidened = collectChecks(Checks, BI->getCondition(),
Expander, Builder);
Expander, BI);
if (NumWidened == 0)
return false;

TotalWidened += NumWidened;

// Emit the new guard condition
Builder.SetInsertPoint(findInsertPt(BI, Checks));
IRBuilder<> Builder(findInsertPt(BI, Checks));
Value *LastCheck = nullptr;
for (auto *Check : Checks)
if (!LastCheck)
Expand Down

0 comments on commit e46d77d

Please sign in to comment.