Skip to content

Commit

Permalink
Fix a bug when unswitching on partial LIV for SwitchInst
Browse files Browse the repository at this point in the history
Summary: Fix a bug when unswitching on partial LIV for SwitchInst.

Reviewers: hfinkel, efriedma, sanjoy

Reviewed By: sanjoy

Subscribers: david2050, mzolotukhin, llvm-commits

Differential Revision: https://reviews.llvm.org/D29107

llvm-svn: 296363
  • Loading branch information
trentxintong committed Feb 27, 2017
1 parent 08d0840 commit 16b85a6
Show file tree
Hide file tree
Showing 2 changed files with 339 additions and 32 deletions.
160 changes: 128 additions & 32 deletions llvm/lib/Transforms/Scalar/LoopUnswitch.cpp
Expand Up @@ -374,9 +374,27 @@ Pass *llvm::createLoopUnswitchPass(bool Os) {
return new LoopUnswitch(Os);
}

/// Operator chain lattice.
enum OperatorChain {
OC_OpChainNone, ///< There is no operator.
OC_OpChainOr, ///< There are only ORs.
OC_OpChainAnd, ///< There are only ANDs.
OC_OpChainMixed ///< There are ANDs and ORs.
};

/// Cond is a condition that occurs in L. If it is invariant in the loop, or has
/// an invariant piece, return the invariant. Otherwise, return null.
//
/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a
/// mixed operator chain, as we can not reliably find a value which will simplify
/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0
/// to simplify the chain.
///
/// NOTE: In case a partial LIV and a mixed operator chain, we may be able to
/// simplify the condition itself to a loop variant condition, but at the
/// cost of creating an entirely new loop.
static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
OperatorChain &ParentChain,
DenseMap<Value *, Value *> &Cache) {
auto CacheIt = Cache.find(Cond);
if (CacheIt != Cache.end())
Expand All @@ -400,31 +418,75 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
return Cond;
}

// Walk up the operator chain to find partial invariant conditions.
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond))
if (BO->getOpcode() == Instruction::And ||
BO->getOpcode() == Instruction::Or) {
// If either the left or right side is invariant, we can unswitch on this,
// which will cause the branch to go away in one loop and the condition to
// simplify in the other one.
if (Value *LHS =
FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) {
Cache[Cond] = LHS;
return LHS;
// Given the previous operator, compute the current operator chain status.
OperatorChain NewChain;
switch (ParentChain) {
case OC_OpChainNone:
NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd :
OC_OpChainOr;
break;
case OC_OpChainOr:
NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr :
OC_OpChainMixed;
break;
case OC_OpChainAnd:
NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd :
OC_OpChainMixed;
break;
case OC_OpChainMixed:
NewChain = OC_OpChainMixed;
break;
}
if (Value *RHS =
FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) {
Cache[Cond] = RHS;
return RHS;

// If we reach a Mixed state, we do not want to keep walking up as we can not
// reliably find a value that will simplify the chain. With this check, we
// will return null on the first sight of mixed chain and the caller will
// either backtrack to find partial LIV in other operand or return null.
if (NewChain != OC_OpChainMixed) {
// Update the current operator chain type before we search up the chain.
ParentChain = NewChain;
// If either the left or right side is invariant, we can unswitch on this,
// which will cause the branch to go away in one loop and the condition to
// simplify in the other one.
if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed,
ParentChain, Cache)) {
Cache[Cond] = LHS;
return LHS;
}
// We did not manage to find a partial LIV in operand(0). Backtrack and try
// operand(1).
ParentChain = NewChain;
if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed,
ParentChain, Cache)) {
Cache[Cond] = RHS;
return RHS;
}
}
}

Cache[Cond] = nullptr;
return nullptr;
}

static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) {
/// Cond is a condition that occurs in L. If it is invariant in the loop, or has
/// an invariant piece, return the invariant along with the operator chain type.
/// Otherwise, return null.
static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond,
Loop *L,
bool &Changed) {
DenseMap<Value *, Value *> Cache;
return FindLIVLoopCondition(Cond, L, Changed, Cache);
OperatorChain OpChain = OC_OpChainNone;
Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache);

// In case we do find a LIV, it can not be obtained by walking up a mixed
// operator chain.
assert((!FCond || OpChain != OC_OpChainMixed) &&
"Do not expect a partial LIV with mixed operator chain");
return {FCond, OpChain};
}

bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
Expand Down Expand Up @@ -556,7 +618,7 @@ bool LoopUnswitch::processCurrentLoop() {

for (IntrinsicInst *Guard : Guards) {
Value *LoopCond =
FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed);
FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first;
if (LoopCond &&
UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {
// NB! Unswitching (if successful) could have erased some of the
Expand Down Expand Up @@ -597,32 +659,57 @@ bool LoopUnswitch::processCurrentLoop() {
// See if this, or some part of it, is loop invariant. If so, we can
// unswitch on it if we desire.
Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
currentLoop, Changed);
currentLoop, Changed).first;
if (LoopCond &&
UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {
++NumBranches;
return true;
}
}
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
currentLoop, Changed);
Value *SC = SI->getCondition();
Value *LoopCond;
OperatorChain OpChain;
std::tie(LoopCond, OpChain) =
FindLIVLoopCondition(SC, currentLoop, Changed);

unsigned NumCases = SI->getNumCases();
if (LoopCond && NumCases) {
// Find a value to unswitch on:
// FIXME: this should chose the most expensive case!
// FIXME: scan for a case with a non-critical edge?
Constant *UnswitchVal = nullptr;

// Do not process same value again and again.
// At this point we have some cases already unswitched and
// some not yet unswitched. Let's find the first not yet unswitched one.
for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
i != e; ++i) {
Constant *UnswitchValCandidate = i.getCaseValue();
if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) {
UnswitchVal = UnswitchValCandidate;
break;
// Find a case value such that at least one case value is unswitched
// out.
if (OpChain == OC_OpChainAnd) {
// If the chain only has ANDs and the switch has a case value of 0.
// Dropping in a 0 to the chain will unswitch out the 0-casevalue.
auto *AllZero = cast<ConstantInt>(Constant::getNullValue(SC->getType()));
if (BranchesInfo.isUnswitched(SI, AllZero))
continue;
// We are unswitching 0 out.
UnswitchVal = AllZero;
} else if (OpChain == OC_OpChainOr) {
// If the chain only has ORs and the switch has a case value of ~0.
// Dropping in a ~0 to the chain will unswitch out the ~0-casevalue.
auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue(SC->getType()));
if (BranchesInfo.isUnswitched(SI, AllOne))
continue;
// We are unswitching ~0 out.
UnswitchVal = AllOne;
} else {
assert(OpChain == OC_OpChainNone &&
"Expect to unswitch on trivial chain");
// Do not process same value again and again.
// At this point we have some cases already unswitched and
// some not yet unswitched. Let's find the first not yet unswitched one.
for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
i != e; ++i) {
Constant *UnswitchValCandidate = i.getCaseValue();
if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) {
UnswitchVal = UnswitchValCandidate;
break;
}
}
}

Expand All @@ -631,6 +718,11 @@ bool LoopUnswitch::processCurrentLoop() {

if (UnswitchIfProfitable(LoopCond, UnswitchVal)) {
++NumSwitches;
// In case of a full LIV, UnswitchVal is the value we unswitched out.
// In case of a partial LIV, we only unswitch when its an AND-chain
// or OR-chain. In both cases switch input value simplifies to
// UnswitchVal.
BranchesInfo.setUnswitched(SI, UnswitchVal);
return true;
}
}
Expand All @@ -641,7 +733,7 @@ bool LoopUnswitch::processCurrentLoop() {
BBI != E; ++BBI)
if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {
Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
currentLoop, Changed);
currentLoop, Changed).first;
if (LoopCond && UnswitchIfProfitable(LoopCond,
ConstantInt::getTrue(Context))) {
++NumSelects;
Expand Down Expand Up @@ -900,7 +992,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
return false;

Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
currentLoop, Changed);
currentLoop, Changed).first;

// Unswitch only if the trivial condition itself is an LIV (not
// partial LIV which could occur in and/or)
Expand Down Expand Up @@ -931,7 +1023,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
// If this isn't switching on an invariant condition, we can't unswitch it.
Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
currentLoop, Changed);
currentLoop, Changed).first;

// Unswitch only if the trivial condition itself is an LIV (not
// partial LIV which could occur in and/or)
Expand Down Expand Up @@ -969,6 +1061,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {

UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB,
nullptr);

// We are only unswitching full LIV.
BranchesInfo.setUnswitched(SI, CondVal);
++NumSwitches;
return true;
}
Expand Down Expand Up @@ -1250,6 +1345,9 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
SwitchInst *SI = dyn_cast<SwitchInst>(UI);
if (!SI || !isa<ConstantInt>(Val)) continue;

// NOTE: if a case value for the switch is unswitched out, we record it
// after the unswitch finishes. We can not record it here as the switch
// is not a direct user of the partial LIV.
SwitchInst::CaseIt DeadCase = SI->findCaseValue(cast<ConstantInt>(Val));
// Default case is live for multiple values.
if (DeadCase == SI->case_default()) continue;
Expand All @@ -1262,8 +1360,6 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
BasicBlock *SISucc = DeadCase.getCaseSuccessor();
BasicBlock *Latch = L->getLoopLatch();

BranchesInfo.setUnswitched(SI, Val);

if (!SI->findCaseDest(SISucc)) continue; // Edge is critical.
// If the DeadCase successor dominates the loop latch, then the
// transformation isn't safe since it will delete the sole predecessor edge
Expand Down

0 comments on commit 16b85a6

Please sign in to comment.