Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 45 additions & 43 deletions llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ struct AllSwitchPaths {
AllSwitchPaths(const MainSwitch *MSwitch, OptimizationRemarkEmitter *ORE,
LoopInfo *LI, Loop *L)
: Switch(MSwitch->getInstr()), SwitchBlock(Switch->getParent()), ORE(ORE),
LI(LI), SwitchOuterLoop(L) {}
DefaultDest(nullptr), LI(LI), SwitchOuterLoop(L) {}

std::vector<ThreadingPath> &getThreadingPaths() { return TPaths; }
unsigned getNumThreadingPaths() { return TPaths.size(); }
Expand All @@ -587,6 +587,30 @@ struct AllSwitchPaths {
unifyTPaths();
}

/// Fast helper to get the successor corresponding to a particular case value
/// for a switch statement.
BasicBlock *getNextCaseSuccessor(const APInt &NextState) {
// Precompute the value => successor mapping
if (CaseValToDest.empty()) {
for (auto Case : Switch->cases()) {
APInt CaseVal = Case.getCaseValue()->getValue();
CaseValToDest[CaseVal] = Case.getCaseSuccessor();
}
DefaultDest = Switch->getDefaultDest();
}

auto SuccIt = CaseValToDest.find(NextState);
return SuccIt == CaseValToDest.end() ? DefaultDest : SuccIt->second;
}

void updateDefaultDest(BasicBlock *DefaultDest) {
this->DefaultDest = DefaultDest;
}

void updateNextCase(const APInt &NextState, BasicBlock *NextCase) {
CaseValToDest[NextState] = NextCase;
}

private:
// Value: an instruction that defines a switch state;
// Key: the parent basic block of that instruction.
Expand Down Expand Up @@ -818,22 +842,6 @@ struct AllSwitchPaths {
TPaths = std::move(TempList);
}

/// Fast helper to get the successor corresponding to a particular case value
/// for a switch statement.
BasicBlock *getNextCaseSuccessor(const APInt &NextState) {
// Precompute the value => successor mapping
if (CaseValToDest.empty()) {
for (auto Case : Switch->cases()) {
APInt CaseVal = Case.getCaseValue()->getValue();
CaseValToDest[CaseVal] = Case.getCaseSuccessor();
}
}

auto SuccIt = CaseValToDest.find(NextState);
return SuccIt == CaseValToDest.end() ? Switch->getDefaultDest()
: SuccIt->second;
}

// Two states are equivalent if they have the same switch destination.
// Unify the states in different threading path if the states are equivalent.
void unifyTPaths() {
Expand All @@ -858,6 +866,7 @@ struct AllSwitchPaths {
OptimizationRemarkEmitter *ORE;
std::vector<ThreadingPath> TPaths;
DenseMap<APInt, BasicBlock *> CaseValToDest;
BasicBlock *DefaultDest;
LoopInfo *LI;
Loop *SwitchOuterLoop;
};
Expand Down Expand Up @@ -1159,24 +1168,6 @@ struct TransformDFA {
SSAUpdate.RewriteAllUses(&DTU->getDomTree());
}

/// Helper to get the successor corresponding to a particular case value for
/// a switch statement.
/// TODO: Unify it with SwitchPaths->getNextCaseSuccessor(SwitchInst *Switch)
/// by updating cached value => successor mapping during threading.
static BasicBlock *getNextCaseSuccessor(SwitchInst *Switch,
const APInt &NextState) {
BasicBlock *NextCase = nullptr;
for (auto Case : Switch->cases()) {
if (Case.getCaseValue()->getValue() == NextState) {
NextCase = Case.getCaseSuccessor();
break;
}
}
if (!NextCase)
NextCase = Switch->getDefaultDest();
return NextCase;
}

/// Clones a basic block, and adds it to the CFG.
///
/// This function also includes updating phi nodes in the successors of the
Expand Down Expand Up @@ -1231,8 +1222,7 @@ struct TransformDFA {
// If BB is the last block in the path, we can simply update the one case
// successor that will be reached.
if (BB == SwitchPaths->getSwitchBlock()) {
SwitchInst *Switch = SwitchPaths->getSwitchInst();
BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
BasicBlock *NextCase = SwitchPaths->getNextCaseSuccessor(NextState);
BlocksToUpdate.push_back(NextCase);
BasicBlock *ClonedSucc = getClonedBB(NextCase, NextState, DuplicateMap);
if (ClonedSucc)
Expand Down Expand Up @@ -1283,6 +1273,15 @@ struct TransformDFA {
return;

Instruction *PrevTerm = PrevBB->getTerminator();
// Update cached value => destination mapping.
if (PrevTerm == SwitchPaths->getSwitchInst()) {
for (auto Case : SwitchPaths->getSwitchInst()->cases())
if (Case.getCaseSuccessor() == OldBB)
SwitchPaths->updateNextCase(Case.getCaseValue()->getValue(), NewBB);
if (SwitchPaths->getSwitchInst()->getDefaultDest() == OldBB)
SwitchPaths->updateDefaultDest(NewBB);
}
// Replace actual successors.
for (unsigned Idx = 0; Idx < PrevTerm->getNumSuccessors(); Idx++) {
if (PrevTerm->getSuccessor(Idx) == OldBB) {
OldBB->removePredecessor(PrevBB, /* KeepOneInputPHIs = */ true);
Expand Down Expand Up @@ -1341,17 +1340,20 @@ struct TransformDFA {
// updated yet
if (!isa<SwitchInst>(LastBlock->getTerminator()))
return;
SwitchInst *Switch = cast<SwitchInst>(LastBlock->getTerminator());
BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
assert(BB->getTerminator() == SwitchPaths->getSwitchInst() &&
"Original last block must contain the threaded switch");
BasicBlock *NextCase = SwitchPaths->getNextCaseSuccessor(NextState);

std::vector<DominatorTree::UpdateType> DTUpdates;
SmallPtrSet<BasicBlock *, 4> SuccSet;
for (BasicBlock *Succ : successors(LastBlock)) {
if (Succ != NextCase && SuccSet.insert(Succ).second)
for (BasicBlock *Succ : successors(LastBlock))
if (SuccSet.insert(Succ).second && Succ != NextCase)
DTUpdates.push_back({DominatorTree::Delete, LastBlock, Succ});
}

Switch->eraseFromParent();
if (!SuccSet.count(NextCase))
DTUpdates.push_back({DominatorTree::Insert, LastBlock, NextCase});

LastBlock->getTerminator()->eraseFromParent();
BranchInst::Create(NextCase, LastBlock);

DTU->applyUpdates(DTUpdates);
Expand Down
10 changes: 5 additions & 5 deletions llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ define i32 @test2(i32 %num) {
; CHECK: for.body.jt3:
; CHECK-NEXT: [[COUNT_JT3:%.*]] = phi i32 [ [[INC_JT3:%.*]], [[FOR_INC_JT3:%.*]] ]
; CHECK-NEXT: [[STATE_JT3:%.*]] = phi i32 [ [[STATE_NEXT_JT3:%.*]], [[FOR_INC_JT3]] ]
; CHECK-NEXT: br label [[FOR_INC]]
; CHECK-NEXT: br label [[FOR_INC_JT1]]
; CHECK: case1:
; CHECK-NEXT: [[COUNT6:%.*]] = phi i32 [ [[COUNT_JT1]], [[FOR_BODY_JT1:%.*]] ], [ [[COUNT]], [[FOR_BODY]] ]
; CHECK-NEXT: [[CMP_C1:%.*]] = icmp slt i32 [[COUNT6]], 50
Expand Down Expand Up @@ -156,8 +156,8 @@ define i32 @test2(i32 %num) {
; CHECK-NEXT: [[DOTSI_UNFOLD_PHI4_JT2:%.*]] = phi i32 [ 2, [[STATE1_1_SI_UNFOLD_TRUE:%.*]] ]
; CHECK-NEXT: br label [[FOR_INC_JT2]]
; CHECK: for.inc:
; CHECK-NEXT: [[COUNT5:%.*]] = phi i32 [ [[COUNT_JT3]], [[FOR_BODY_JT3:%.*]] ], [ undef, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ undef, [[STATE1_2_SI_UNFOLD_FALSE:%.*]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE]] ]
; CHECK-NEXT: [[STATE_NEXT]] = phi i32 [ [[STATE2_1_SI_UNFOLD_PHI]], [[STATE2_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[DOTSI_UNFOLD_PHI4]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ 1, [[FOR_BODY_JT3]] ]
; CHECK-NEXT: [[COUNT5:%.*]] = phi i32 [ undef, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ undef, [[STATE1_2_SI_UNFOLD_FALSE:%.*]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE]] ]
; CHECK-NEXT: [[STATE_NEXT]] = phi i32 [ [[STATE2_1_SI_UNFOLD_PHI]], [[STATE2_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[DOTSI_UNFOLD_PHI4]], [[STATE1_1_SI_UNFOLD_FALSE]] ]
; CHECK-NEXT: [[INC]] = add nsw i32 [[COUNT5]], 1
; CHECK-NEXT: [[CMP_EXIT:%.*]] = icmp slt i32 [[INC]], [[NUM:%.*]]
; CHECK-NEXT: br i1 [[CMP_EXIT]], label [[FOR_BODY]], label [[FOR_END:%.*]]
Expand All @@ -167,8 +167,8 @@ define i32 @test2(i32 %num) {
; CHECK-NEXT: [[CMP_EXIT_JT2:%.*]] = icmp slt i32 [[INC_JT2]], [[NUM]]
; CHECK-NEXT: br i1 [[CMP_EXIT_JT2]], label [[FOR_BODY_JT2:%.*]], label [[FOR_END]]
; CHECK: for.inc.jt1:
; CHECK-NEXT: [[COUNT7:%.*]] = phi i32 [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[COUNT]], [[FOR_BODY]] ]
; CHECK-NEXT: [[STATE_NEXT_JT1]] = phi i32 [ 1, [[FOR_BODY]] ], [ [[STATE2_1_SI_UNFOLD_PHI_JT1]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[DOTSI_UNFOLD_PHI3_JT1]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ]
; CHECK-NEXT: [[COUNT7:%.*]] = phi i32 [ [[COUNT_JT3]], [[FOR_BODY_JT3:%.*]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[COUNT]], [[FOR_BODY]] ]
; CHECK-NEXT: [[STATE_NEXT_JT1]] = phi i32 [ 1, [[FOR_BODY]] ], [ 1, [[FOR_BODY_JT3]] ], [ [[STATE2_1_SI_UNFOLD_PHI_JT1]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[DOTSI_UNFOLD_PHI3_JT1]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ]
; CHECK-NEXT: [[INC_JT1]] = add nsw i32 [[COUNT7]], 1
; CHECK-NEXT: [[CMP_EXIT_JT1:%.*]] = icmp slt i32 [[INC_JT1]], [[NUM]]
; CHECK-NEXT: br i1 [[CMP_EXIT_JT1]], label [[FOR_BODY_JT1]], label [[FOR_END]]
Expand Down