Skip to content

Commit

Permalink
[SimpleLoopUnswitch] unswitch selects
Browse files Browse the repository at this point in the history
The old LoopUnswitch pass unswitched selects, but the changes were never
ported to the new SimpleLoopUnswitch.

We unswitch by turning:

```
S = select %cond, %a, %b
```

into:

```
head:
br %cond, label %then, label %tail

then:
br label %tail

tail:
S = phi [ %a, %then ], [ %b, %head ]
```

Unswitching a select is always nontrivial, since neither successor of the
resulting branch exits the loop and the loop body always needs to be cloned.

Differential Revision: https://reviews.llvm.org/D138526

Co-authored-by: Sergey Kachkov <sergey.kachkov@syntacore.com>
  • Loading branch information
caojoshua and skachkov-sc committed Apr 30, 2023
1 parent 831c221 commit e479ed9
Show file tree
Hide file tree
Showing 4 changed files with 477 additions and 83 deletions.
100 changes: 85 additions & 15 deletions llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
Expand Up @@ -19,6 +19,7 @@
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GuardUtils.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
Expand Down Expand Up @@ -74,6 +75,7 @@ using namespace llvm::PatternMatch;

STATISTIC(NumBranches, "Number of branches unswitched");
STATISTIC(NumSwitches, "Number of switches unswitched");
STATISTIC(NumSelects, "Number of selects turned into branches for unswitching");
STATISTIC(NumGuards, "Number of guards turned into branches for unswitching");
STATISTIC(NumTrivial, "Number of unswitches that are trivial");
STATISTIC(
Expand Down Expand Up @@ -2642,6 +2644,61 @@ static InstructionCost computeDomSubtreeCost(
return Cost;
}

/// Turns a select instruction into implicit control flow branch,
/// making the following replacement:
///
///   head:
///     --code before select--
///     select %cond, %trueval, %falseval
///     --code after select--
///
/// into
///
///   head:
///     --code before select--
///     br i1 %cond, label %then, label %tail
///
///   then:
///     br %tail
///
///   tail:
///     phi [ %trueval, %then ], [ %falseval, %head ]
///     --code after select--
///
/// It also makes all relevant DT and LI updates, so that all structures are in
/// valid state after this transform.
///
/// \param SI    The select to replace; it is erased before returning.
/// \param DT    Dominator tree, updated eagerly while the block is split.
/// \param LI    Loop info, updated while the block is split.
/// \param MSSAU Optional MemorySSA updater; may be null.
/// \param AC    Currently unused; kept so existing callers stay unchanged.
/// \returns The new conditional branch that terminates the head block.
static BranchInst *turnSelectIntoBranch(SelectInst *SI, DominatorTree &DT,
                                        LoopInfo &LI, MemorySSAUpdater *MSSAU,
                                        AssumptionCache *AC) {
  // The assumption cache was only consulted to build a frozen copy of the
  // condition that nothing ever used, which left a dead `freeze` instruction
  // in the head block. The branch is (and always was) created on the select's
  // own condition, so that dead instruction is no longer emitted.
  // NOTE(review): if the condition must be frozen, that presumably belongs in
  // the unswitching code that consumes the returned branch -- confirm.
  (void)AC;

  LLVM_DEBUG(dbgs() << "Turning " << *SI << " into a branch.\n");
  BasicBlock *HeadBB = SI->getParent();

  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
  // Split the head block right before the select. `false` asks for a `then`
  // block that falls through to the tail instead of ending in `unreachable`;
  // the select's !prof metadata is forwarded for the new branch.
  SplitBlockAndInsertIfThen(SI->getCondition(), SI, /*Unreachable=*/false,
                            SI->getMetadata(LLVMContext::MD_prof), &DTU, &LI);

  // After the split the head block ends in the new conditional branch:
  // successor 0 is the `then` block, successor 1 is the tail.
  auto *CondBr = cast<BranchInst>(HeadBB->getTerminator());
  BasicBlock *ThenBB = CondBr->getSuccessor(0);
  BasicBlock *TailBB = CondBr->getSuccessor(1);

  // Tell MemorySSA that everything from the select onwards now lives in the
  // tail block.
  if (MSSAU)
    MSSAU->moveAllAfterSpliceBlocks(HeadBB, TailBB, SI);

  // The select's value becomes a phi in the tail: the true value arrives via
  // the `then` block, the false value directly from the head block.
  PHINode *Phi = PHINode::Create(SI->getType(), 2, "unswitched.select", SI);
  Phi->addIncoming(SI->getTrueValue(), ThenBB);
  Phi->addIncoming(SI->getFalseValue(), HeadBB);
  SI->replaceAllUsesWith(Phi);
  SI->eraseFromParent();

  if (MSSAU && VerifyMemorySSA)
    MSSAU->getMemorySSA()->verifyMemorySSA();

  ++NumSelects;
  return CondBr;
}

/// Turns a llvm.experimental.guard intrinsic into implicit control flow branch,
/// making the following replacement:
///
Expand Down Expand Up @@ -2749,9 +2806,10 @@ static int CalculateUnswitchCostMultiplier(
const BasicBlock *CondBlock = TI.getParent();
if (DT.dominates(CondBlock, Latch) &&
(isGuard(&TI) ||
llvm::count_if(successors(&TI), [&L](const BasicBlock *SuccBB) {
return L.contains(SuccBB);
}) <= 1)) {
(TI.isTerminator() &&
llvm::count_if(successors(&TI), [&L](const BasicBlock *SuccBB) {
return L.contains(SuccBB);
}) <= 1))) {
NumCostMultiplierSkipped++;
return 1;
}
Expand All @@ -2760,12 +2818,17 @@ static int CalculateUnswitchCostMultiplier(
int SiblingsCount = (ParentL ? ParentL->getSubLoopsVector().size()
: std::distance(LI.begin(), LI.end()));
// Count amount of clones that all the candidates might cause during
// unswitching. Branch/guard counts as 1, switch counts as log2 of its cases.
// unswitching. Branch/guard/select counts as 1, switch counts as log2 of its
// cases.
int UnswitchedClones = 0;
for (const auto &Candidate : UnswitchCandidates) {
const Instruction *CI = Candidate.TI;
const BasicBlock *CondBlock = CI->getParent();
bool SkipExitingSuccessors = DT.dominates(CondBlock, Latch);
if (isa<SelectInst>(CI)) {
UnswitchedClones++;
continue;
}
if (isGuard(CI)) {
if (!SkipExitingSuccessors)
UnswitchedClones++;
Expand Down Expand Up @@ -2828,15 +2891,19 @@ static bool collectUnswitchCandidates(
if (LI.getLoopFor(BB) != &L)
continue;

if (CollectGuards)
for (auto &I : *BB)
if (isGuard(&I)) {
auto *Cond =
skipTrivialSelect(cast<IntrinsicInst>(&I)->getArgOperand(0));
// TODO: Support AND, OR conditions and partial unswitching.
if (!isa<Constant>(Cond) && L.isLoopInvariant(Cond))
UnswitchCandidates.push_back({&I, {Cond}});
}
for (auto &I : *BB) {
if (auto *SI = dyn_cast<SelectInst>(&I)) {
auto *Cond = SI->getCondition();
if (!isa<Constant>(Cond) && L.isLoopInvariant(Cond))
UnswitchCandidates.push_back({&I, {Cond}});
} else if (CollectGuards && isGuard(&I)) {
auto *Cond =
skipTrivialSelect(cast<IntrinsicInst>(&I)->getArgOperand(0));
// TODO: Support AND, OR conditions and partial unswitching.
if (!isa<Constant>(Cond) && L.isLoopInvariant(Cond))
UnswitchCandidates.push_back({&I, {Cond}});
}
}

if (auto *SI = dyn_cast<SwitchInst>(BB->getTerminator())) {
// We can only consider fully loop-invariant switch conditions as we need
Expand Down Expand Up @@ -3338,7 +3405,8 @@ static NonTrivialUnswitchCandidate findBestNonTrivialUnswitchCandidate(
// loop. This is computing the new cost of unswitching a condition.
// Note that guards and selects always have 2 unique successors that are
// implicit and will be materialized if we decide to unswitch them.
int SuccessorsCount = isGuard(&TI) ? 2 : Visited.size();
int SuccessorsCount =
isGuard(&TI) || isa<SelectInst>(TI) ? 2 : Visited.size();
assert(SuccessorsCount > 1 &&
"Cannot unswitch a condition without multiple distinct successors!");
return (LoopCost - Cost) * (SuccessorsCount - 1);
Expand Down Expand Up @@ -3425,7 +3493,9 @@ static bool unswitchBestCondition(
PartialIVInfo.InstToDuplicate.clear();

// If the best candidate is a select or a guard, turn it into a branch.
if (isGuard(Best.TI))
if (auto *SI = dyn_cast<SelectInst>(Best.TI))
Best.TI = turnSelectIntoBranch(SI, DT, LI, MSSAU, &AC);
else if (isGuard(Best.TI))
Best.TI =
turnGuardIntoBranch(cast<IntrinsicInst>(Best.TI), L, DT, LI, MSSAU);

Expand Down
Expand Up @@ -2332,21 +2332,26 @@ exit:
define i32 @test_partial_unswitch_all_conds_guaranteed_non_poison(i1 noundef %c.1, i1 noundef %c.2) {
; CHECK-LABEL: @test_partial_unswitch_all_conds_guaranteed_non_poison(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = and i1 [[C_1:%.*]], [[C_2:%.*]]
; CHECK-NEXT: br i1 [[TMP0]], label [[ENTRY_SPLIT:%.*]], label [[ENTRY_SPLIT_US:%.*]]
; CHECK-NEXT: br i1 [[C_1:%.*]], label [[ENTRY_SPLIT_US:%.*]], label [[ENTRY_SPLIT:%.*]]
; CHECK: entry.split.us:
; CHECK-NEXT: br label [[LOOP_US:%.*]]
; CHECK: loop.us:
; CHECK-NEXT: [[TMP1:%.*]] = call i32 @a()
; CHECK-NEXT: br label [[EXIT_SPLIT_US:%.*]]
; CHECK-NEXT: [[TMP0:%.*]] = call i32 @a()
; CHECK-NEXT: br label [[TMP1:%.*]]
; CHECK: 1:
; CHECK-NEXT: br label [[TMP2:%.*]]
; CHECK: 2:
; CHECK-NEXT: [[UNSWITCHED_SELECT_US:%.*]] = phi i1 [ [[C_2:%.*]], [[TMP1]] ]
; CHECK-NEXT: br i1 [[UNSWITCHED_SELECT_US]], label [[LOOP_US]], label [[EXIT_SPLIT_US:%.*]]
; CHECK: exit.split.us:
; CHECK-NEXT: br label [[EXIT:%.*]]
; CHECK: entry.split:
; CHECK-NEXT: br label [[LOOP:%.*]]
; CHECK: loop:
; CHECK-NEXT: [[TMP2:%.*]] = call i32 @a()
; CHECK-NEXT: [[SEL:%.*]] = select i1 true, i1 true, i1 false
; CHECK-NEXT: br i1 [[SEL]], label [[LOOP]], label [[EXIT_SPLIT:%.*]]
; CHECK-NEXT: [[TMP3:%.*]] = call i32 @a()
; CHECK-NEXT: br label [[TMP4:%.*]]
; CHECK: 4:
; CHECK-NEXT: br i1 false, label [[LOOP]], label [[EXIT_SPLIT:%.*]]
; CHECK: exit.split:
; CHECK-NEXT: br label [[EXIT]]
; CHECK: exit:
Expand Down

0 comments on commit e479ed9

Please sign in to comment.