Skip to content

Commit

Permalink
[CallSiteSplitting] properly split musttail calls
Browse files Browse the repository at this point in the history
Summary:
`musttail` calls can't be naively splitted. The split blocks must
include not only the call instruction itself, but also (optional)
`bitcast` and `return` instructions that follow it.

Clone `bitcast` and `ret`, place them into the split blocks, and
remove the tail block when done.

Reviewers: junbuml, mcrosier, davidxl, davide, fhahn

Reviewed By: fhahn

Subscribers: JDevlieghere, llvm-commits

Differential Revision: https://reviews.llvm.org/D43729

llvm-svn: 326666
  • Loading branch information
indutny committed Mar 3, 2018
1 parent 1a3901c commit f9e09c1
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 4 deletions.
76 changes: 72 additions & 4 deletions llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
Expand Up @@ -209,8 +209,46 @@ static bool canSplitCallSite(CallSite CS, TargetTransformInfo &TTI) {
return CallSiteBB->canSplitPredecessors();
}

/// Return true if the CS is split into its new predecessors.
static Instruction *cloneInstForMustTail(Instruction *I, Instruction *Before,
Value *V) {
Instruction *Copy = I->clone();
Copy->setName(I->getName());
Copy->insertBefore(Before);
if (V)
Copy->setOperand(0, V);
return Copy;
}

/// Copy mandatory `musttail` return sequence that follows original `CI`, and
/// link it up to `NewCI` value instead:
///
/// * (optional) `bitcast NewCI to ...`
/// * `ret bitcast or NewCI`
///
/// Insert this sequence right before `SplitBB`'s terminator, which will be
/// cleaned up later in `splitCallSite` below.
static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI,
Instruction *NewCI) {
bool IsVoid = SplitBB->getParent()->getReturnType()->isVoidTy();
auto II = std::next(CI->getIterator());

BitCastInst* BCI = dyn_cast<BitCastInst>(&*II);
if (BCI)
++II;

ReturnInst* RI = dyn_cast<ReturnInst>(&*II);
assert(RI && "`musttail` call must be followed by `ret` instruction");

TerminatorInst *TI = SplitBB->getTerminator();
Value *V = NewCI;
if (BCI)
V = cloneInstForMustTail(BCI, TI, V);
cloneInstForMustTail(RI, TI, IsVoid ? nullptr : V);

// FIXME: remove TI here, `DuplicateInstructionsInSplitBetween` has a bug
// that prevents doing this now.
}

/// For each (predecessor, conditions from predecessors) pair, it will split the
/// basic block containing the call site, hook it up to the predecessor and
/// replace the call instruction with new call instructions, which contain
Expand Down Expand Up @@ -257,9 +295,14 @@ static void splitCallSite(
const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds) {
Instruction *Instr = CS.getInstruction();
BasicBlock *TailBB = Instr->getParent();
bool IsMustTailCall = CS.isMustTailCall();

PHINode *CallPN = nullptr;
if (Instr->getNumUses())

// `musttail` calls must be followed by optional `bitcast`, and `ret`. The
// split blocks will be terminated right after that so there're no users for
// this phi in a `TailBB`.
if (!IsMustTailCall && Instr->getNumUses())
CallPN = PHINode::Create(Instr->getType(), Preds.size(), "phi.call");

DEBUG(dbgs() << "split call-site : " << *Instr << " into \n");
Expand Down Expand Up @@ -293,6 +336,23 @@ static void splitCallSite(
<< "\n");
if (CallPN)
CallPN->addIncoming(NewCI, SplitBlock);

// Clone and place bitcast and return instructions before `TI`
if (IsMustTailCall)
copyMustTailReturn(SplitBlock, Instr, NewCI);
}

NumCallSiteSplit++;

// FIXME: remove TI in `copyMustTailReturn`
if (IsMustTailCall) {
// Remove superfluous `br` terminators from the end of the Split blocks
for (BasicBlock *SplitBlock : predecessors(TailBB))
SplitBlock->getTerminator()->eraseFromParent();

// Erase the tail block once done with musttail patching
TailBB->eraseFromParent();
return;
}

auto *OriginalBegin = &*TailBB->begin();
Expand Down Expand Up @@ -329,8 +389,6 @@ static void splitCallSite(
if (CurrentI == OriginalBegin)
break;
}

NumCallSiteSplit++;
}

// Return true if the call-site has an argument which is a PHI with only
Expand Down Expand Up @@ -415,7 +473,17 @@ static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI,
Function *Callee = CS.getCalledFunction();
if (!Callee || Callee->isDeclaration())
continue;

// Successful musttail call-site splits result in erased CI and erased BB.
// Check if such path is possible before attempting the splitting.
bool IsMustTail = CS.isMustTailCall();

Changed |= tryToSplitCallSite(CS, TTI);

// There're no interesting instructions after this. The call site
// itself might have been erased on splitting.
if (IsMustTail)
break;
}
}
return Changed;
Expand Down
75 changes: 75 additions & 0 deletions llvm/test/Transforms/CallSiteSplitting/musttail.ll
@@ -0,0 +1,75 @@
; RUN: opt < %s -callsite-splitting -S | FileCheck %s

;CHECK-LABEL: @caller
;CHECK-LABEL: Top.split:
;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
;CHECK: %cb2 = bitcast i8* %ca1 to i8*
;CHECK: ret i8* %cb2
;CHECK-LABEL: TBB.split
;CHECK: %ca3 = musttail call i8* @callee(i8* nonnull %a, i8* null)
;CHECK: %cb4 = bitcast i8* %ca3 to i8*
;CHECK: ret i8* %cb4
define i8* @caller(i8* %a, i8* %b) {
Top:
%c = icmp eq i8* %a, null
br i1 %c, label %Tail, label %TBB
TBB:
%c2 = icmp eq i8* %b, null
br i1 %c2, label %Tail, label %End
Tail:
%ca = musttail call i8* @callee(i8* %a, i8* %b)
%cb = bitcast i8* %ca to i8*
ret i8* %cb
End:
ret i8* null
}

define i8* @callee(i8* %a, i8* %b) noinline {
ret i8* %a
}

;CHECK-LABEL: @no_cast_caller
;CHECK-LABEL: Top.split:
;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
;CHECK: ret i8* %ca1
;CHECK-LABEL: TBB.split
;CHECK: %ca2 = musttail call i8* @callee(i8* nonnull %a, i8* null)
;CHECK: ret i8* %ca2
define i8* @no_cast_caller(i8* %a, i8* %b) {
Top:
%c = icmp eq i8* %a, null
br i1 %c, label %Tail, label %TBB
TBB:
%c2 = icmp eq i8* %b, null
br i1 %c2, label %Tail, label %End
Tail:
%ca = musttail call i8* @callee(i8* %a, i8* %b)
ret i8* %ca
End:
ret i8* null
}

;CHECK-LABEL: @void_caller
;CHECK-LABEL: Top.split:
;CHECK: musttail call void @void_callee(i8* null, i8* %b)
;CHECK: ret void
;CHECK-LABEL: TBB.split
;CHECK: musttail call void @void_callee(i8* nonnull %a, i8* null)
;CHECK: ret void
define void @void_caller(i8* %a, i8* %b) {
Top:
%c = icmp eq i8* %a, null
br i1 %c, label %Tail, label %TBB
TBB:
%c2 = icmp eq i8* %b, null
br i1 %c2, label %Tail, label %End
Tail:
musttail call void @void_callee(i8* %a, i8* %b)
ret void
End:
ret void
}

define void @void_callee(i8* %a, i8* %b) noinline {
ret void
}

0 comments on commit f9e09c1

Please sign in to comment.