Skip to content

Commit

Permalink
Add a llvm.coro.end.async intrinsic
Browse files Browse the repository at this point in the history
The llvm.coro.end.async intrinsic allows specifying a function that is
called as the last action before returning. This function will be
inlined after coroutine splitting.

This function can contain a 'musttail' call to allow for guaranteed tail
calling as the last action.

Differential Revision: https://reviews.llvm.org/D93568
  • Loading branch information
aschwaighofer committed Dec 22, 2020
1 parent ae8f4b2 commit 333108e
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 33 deletions.
42 changes: 42 additions & 0 deletions llvm/docs/Coroutines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,48 @@ The following table summarizes the handling of `coro.end`_ intrinsic.
| | Landingpad | nothing | nothing |
+------------+-------------+-------------------+-------------------------------+


'llvm.coro.end.async' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
::

  declare i1 @llvm.coro.end.async(i8* <handle>, i1 <unwind>, ...)

Overview:
"""""""""

The '``llvm.coro.end.async``' intrinsic marks the point where execution of the
resume part of the coroutine should end and control should return to the
caller. As part of its variable tail arguments, this intrinsic allows
specifying a function and that function's arguments, which are tail called as
the last action before returning.


Arguments:
""""""""""

The first argument should refer to the coroutine handle of the enclosing
coroutine. A frontend is allowed to supply null as the first parameter, in
which case the `coro-early` pass will replace the null with an appropriate
coroutine handle value.

The second argument should be `true` if this coro.end is in the block that is
part of the unwind sequence leaving the coroutine body due to an exception and
`false` otherwise.

The third argument, if present, should specify a function to be called.

If the third argument is present, the remaining arguments are the arguments to
the function call.

.. code-block:: llvm

  call i1 (i8*, i1, ...) @llvm.coro.end.async(
                           i8* %hdl, i1 0,
                           void (i8*, %async.task*, %async.actor*)* @must_tail_call_return,
                           i8* %ctxt, %async.task* %task, %async.actor* %actor)
  unreachable

.. _coro.suspend:
.. _suspend points:

Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,8 @@ def int_coro_free : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
ReadOnly<ArgIndex<1>>,
NoCapture<ArgIndex<1>>]>;
def int_coro_end : Intrinsic<[llvm_i1_ty], [llvm_ptr_ty, llvm_i1_ty], []>;
// Variadic counterpart of int_coro_end: takes the same pointer and i1
// operands and returns i1, followed by a variable argument list.
def int_coro_end_async
: Intrinsic<[llvm_i1_ty], [llvm_ptr_ty, llvm_i1_ty, llvm_vararg_ty], []>;

def int_coro_frame : Intrinsic<[llvm_ptr_ty], [], [IntrNoMem]>;
def int_coro_noop : Intrinsic<[llvm_ptr_ty], [], [IntrNoMem]>;
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/Transforms/Coroutines/CoroEarly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,11 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) {
if (cast<CoroSuspendInst>(&I)->isFinal())
CB->setCannotDuplicate();
break;
case Intrinsic::coro_end_async:
case Intrinsic::coro_end:
// Make sure that fallthrough coro.end is not duplicated as CoroSplit
// pass expects that there is at most one fallthrough coro.end.
if (cast<CoroEndInst>(&I)->isFallthrough())
if (cast<AnyCoroEndInst>(&I)->isFallthrough())
CB->setCannotDuplicate();
break;
case Intrinsic::coro_noop:
Expand Down Expand Up @@ -219,8 +220,8 @@ static bool declaresCoroEarlyIntrinsics(const Module &M) {
return coro::declaresIntrinsics(
M, {"llvm.coro.id", "llvm.coro.id.retcon", "llvm.coro.id.retcon.once",
"llvm.coro.destroy", "llvm.coro.done", "llvm.coro.end",
"llvm.coro.noop", "llvm.coro.free", "llvm.coro.promise",
"llvm.coro.resume", "llvm.coro.suspend"});
"llvm.coro.end.async", "llvm.coro.noop", "llvm.coro.free",
"llvm.coro.promise", "llvm.coro.resume", "llvm.coro.suspend"});
}

PreservedAnalyses CoroEarlyPass::run(Function &F, FunctionAnalysisManager &) {
Expand Down
19 changes: 18 additions & 1 deletion llvm/lib/Transforms/Coroutines/CoroFrame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2170,9 +2170,26 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
}

// Put CoroEnds into their own blocks.
for (CoroEndInst *CE : Shape.CoroEnds)
for (AnyCoroEndInst *CE : Shape.CoroEnds) {
splitAround(CE, "CoroEnd");

// Emit the musttail call function in a new block before the CoroEnd.
// We do this here so that the right suspend crossing info is computed for
// the uses of the musttail call function call. (Arguments to the coro.end
// instructions would be ignored)
if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(CE)) {
auto *MustTailCallFn = AsyncEnd->getMustTailCallFunction();
if (!MustTailCallFn)
continue;
IRBuilder<> Builder(AsyncEnd);
SmallVector<Value *, 8> Args(AsyncEnd->args());
auto Arguments = ArrayRef<Value *>(Args).drop_front(3);
auto *Call = createMustTailCall(AsyncEnd->getDebugLoc(), MustTailCallFn,
Arguments, Builder);
splitAround(Call, "MustTailCall.Before.CoroEnd");
}
}

// Transforms multi-edge PHI Nodes, so that any value feeding into a PHI will
// never have its definition separated from the PHI by the suspend point.
rewritePHIs(F);
Expand Down
40 changes: 38 additions & 2 deletions llvm/lib/Transforms/Coroutines/CoroInstr.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,7 @@ class LLVM_LIBRARY_VISIBILITY CoroSizeInst : public IntrinsicInst {
}
};

/// This represents the llvm.coro.end instruction.
class LLVM_LIBRARY_VISIBILITY CoroEndInst : public IntrinsicInst {
class LLVM_LIBRARY_VISIBILITY AnyCoroEndInst : public IntrinsicInst {
enum { FrameArg, UnwindArg };

public:
Expand All @@ -587,6 +586,19 @@ class LLVM_LIBRARY_VISIBILITY CoroEndInst : public IntrinsicInst {
return cast<Constant>(getArgOperand(UnwindArg))->isOneValue();
}

// Methods to support type inquiry through isa, cast, and dyn_cast:
static bool classof(const IntrinsicInst *I) {
auto ID = I->getIntrinsicID();
return ID == Intrinsic::coro_end || ID == Intrinsic::coro_end_async;
}
static bool classof(const Value *V) {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
};

/// This represents the llvm.coro.end instruction.
class LLVM_LIBRARY_VISIBILITY CoroEndInst : public AnyCoroEndInst {
public:
// Methods to support type inquiry through isa, cast, and dyn_cast:
static bool classof(const IntrinsicInst *I) {
return I->getIntrinsicID() == Intrinsic::coro_end;
Expand All @@ -596,6 +608,30 @@ class LLVM_LIBRARY_VISIBILITY CoroEndInst : public IntrinsicInst {
}
};

/// This represents the llvm.coro.end.async instruction.
///
/// In addition to the frame and unwind operands it shares with
/// AnyCoroEndInst, the call may carry a third operand naming a function,
/// which getMustTailCallFunction() exposes.
class LLVM_LIBRARY_VISIBILITY CoroAsyncEndInst : public AnyCoroEndInst {
  enum { FrameArg, UnwindArg, MustTailCallFuncArg };

public:
  void checkWellFormed() const;

  /// Returns the function passed as the third call argument, looking through
  /// pointer casts, or nullptr when the call has no third argument.
  Function *getMustTailCallFunction() const {
    // The function operand is optional; it is only present when the call has
    // an argument at index MustTailCallFuncArg.
    if (getNumArgOperands() <= MustTailCallFuncArg)
      return nullptr;

    return cast<Function>(
        getArgOperand(MustTailCallFuncArg)->stripPointerCasts());
  }

  // Methods to support type inquiry through isa, cast, and dyn_cast:
  static bool classof(const IntrinsicInst *I) {
    return I->getIntrinsicID() == Intrinsic::coro_end_async;
  }
  static bool classof(const Value *V) {
    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
  }
};

/// This represents the llvm.coro.alloca.alloc instruction.
class LLVM_LIBRARY_VISIBILITY CoroAllocaAllocInst : public IntrinsicInst {
enum { SizeArg, AlignArg };
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Transforms/Coroutines/CoroInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ enum class ABI {
// values used during CoroSplit pass.
struct LLVM_LIBRARY_VISIBILITY Shape {
CoroBeginInst *CoroBegin;
SmallVector<CoroEndInst *, 4> CoroEnds;
SmallVector<AnyCoroEndInst *, 4> CoroEnds;
SmallVector<CoroSizeInst *, 2> CoroSizes;
SmallVector<AnyCoroSuspendInst *, 4> CoroSuspends;
SmallVector<CallInst*, 2> SwiftErrorOps;
Expand Down Expand Up @@ -270,6 +270,8 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
};

void buildCoroutineFrame(Function &F, Shape &Shape);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
ArrayRef<Value *> Arguments, IRBuilder<> &);
} // End namespace coro.
} // End namespace llvm

Expand Down
100 changes: 79 additions & 21 deletions llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,53 @@ static void maybeFreeRetconStorage(IRBuilder<> &Builder,
Shape.emitDealloc(Builder, FramePtr, CG);
}

/// Lower a coro.end that is reached under the async ABI.
///
/// When \p End is not an llvm.coro.end.async, or carries no must-tail-call
/// function, only a `ret void` is emitted in front of it. Otherwise the call
/// that immediately precedes the terminator of the single predecessor block
/// is moved in front of \p End, a `ret void` is emitted after it, the
/// remainder of the block is split off, and the call is inlined.
/// \returns true if cleanup of the coro.end block is needed, false otherwise.
static bool replaceCoroEndAsync(AnyCoroEndInst *End) {
  IRBuilder<> Builder(End);

  // Only the async flavour of coro.end can name a function to tail call.
  Function *TailFn = nullptr;
  if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(End))
    TailFn = AsyncEnd->getMustTailCallFunction();

  if (!TailFn) {
    Builder.CreateRetVoid();
    return true /*needs cleanup of coro.end block*/;
  }

  // Move the must tail call from the predecessor block into the end block.
  BasicBlock *EndBB = End->getParent();
  BasicBlock *PredBB = EndBB->getSinglePredecessor();
  assert(PredBB && "Must have a single predecessor block");
  auto TermIt = PredBB->getTerminator()->getIterator();
  auto *TailCall = cast<CallInst>(&*std::prev(TermIt));
  EndBB->getInstList().splice(End->getIterator(), PredBB->getInstList(),
                              TailCall);

  // Insert the return instruction.
  Builder.SetInsertPoint(End);
  Builder.CreateRetVoid();

  // Remove the rest of the block, by splitting it into an unreachable block.
  EndBB->splitBasicBlock(End);
  EndBB->getTerminator()->eraseFromParent();

  InlineFunctionInfo InlineInfo;
  auto InlineRes = InlineFunction(*TailCall, InlineInfo);
  assert(InlineRes.isSuccess() && "Expected inlining to succeed");
  (void)InlineRes;

  // We have cleaned up the coro.end block above.
  return false;
}

/// Replace a non-unwind call to llvm.coro.end.
static void replaceFallthroughCoroEnd(CoroEndInst *End,
static void replaceFallthroughCoroEnd(AnyCoroEndInst *End,
const coro::Shape &Shape, Value *FramePtr,
bool InResume, CallGraph *CG) {
// Start inserting right before the coro.end.
Expand All @@ -192,9 +237,12 @@ static void replaceFallthroughCoroEnd(CoroEndInst *End,
break;

// In async lowering this returns.
case coro::ABI::Async:
Builder.CreateRetVoid();
case coro::ABI::Async: {
bool CoroEndBlockNeedsCleanup = replaceCoroEndAsync(End);
if (!CoroEndBlockNeedsCleanup)
return;
break;
}

// In unique continuation lowering, the continuations always return void.
// But we may have implicitly allocated storage.
Expand Down Expand Up @@ -229,8 +277,9 @@ static void replaceFallthroughCoroEnd(CoroEndInst *End,
}

/// Replace an unwind call to llvm.coro.end.
static void replaceUnwindCoroEnd(CoroEndInst *End, const coro::Shape &Shape,
Value *FramePtr, bool InResume, CallGraph *CG){
static void replaceUnwindCoroEnd(AnyCoroEndInst *End, const coro::Shape &Shape,
Value *FramePtr, bool InResume,
CallGraph *CG) {
IRBuilder<> Builder(End);

switch (Shape.ABI) {
Expand Down Expand Up @@ -258,7 +307,7 @@ static void replaceUnwindCoroEnd(CoroEndInst *End, const coro::Shape &Shape,
}
}

static void replaceCoroEnd(CoroEndInst *End, const coro::Shape &Shape,
static void replaceCoroEnd(AnyCoroEndInst *End, const coro::Shape &Shape,
Value *FramePtr, bool InResume, CallGraph *CG) {
if (End->isUnwind())
replaceUnwindCoroEnd(End, Shape, FramePtr, InResume, CG);
Expand Down Expand Up @@ -511,10 +560,10 @@ void CoroCloner::replaceCoroSuspends() {
}

void CoroCloner::replaceCoroEnds() {
for (CoroEndInst *CE : Shape.CoroEnds) {
for (AnyCoroEndInst *CE : Shape.CoroEnds) {
// We use a null call graph because there's no call graph node for
// the cloned function yet. We'll just be rebuilding that later.
auto NewCE = cast<CoroEndInst>(VMap[CE]);
auto *NewCE = cast<AnyCoroEndInst>(VMap[CE]);
replaceCoroEnd(NewCE, Shape, NewFramePtr, /*in resume*/ true, nullptr);
}
}
Expand Down Expand Up @@ -1385,6 +1434,23 @@ static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy,
}
}

/// Emit a `musttail` call to \p MustTailCallFn at the insertion point of
/// \p Builder.
///
/// \param Loc debug location attached to the emitted call.
/// \param MustTailCallFn the callee; its calling convention is copied onto
///        the call.
/// \param Arguments the values to pass; they are run through
///        coerceArguments against the callee's function type first.
/// \param Builder the builder used to emit the call.
/// \returns the newly created call instruction.
CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
                                   ArrayRef<Value *> Arguments,
                                   IRBuilder<> &Builder) {
  // Ask the function for its type directly instead of peeling it off the
  // pointer type of the function value.
  auto *FnTy = MustTailCallFn->getFunctionType();
  // Coerce the arguments: LLVM optimizations seem to ignore the types in
  // vararg functions and throw away casts in optimized mode.
  SmallVector<Value *, 8> CallArgs;
  coerceArguments(Builder, FnTy, Arguments, CallArgs);

  auto *TailCall = Builder.CreateCall(FnTy, MustTailCallFn, CallArgs);
  TailCall->setTailCallKind(CallInst::TCK_MustTail);
  TailCall->setDebugLoc(Loc);
  TailCall->setCallingConv(MustTailCallFn->getCallingConv());
  return TailCall;
}

static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones) {
assert(Shape.ABI == coro::ABI::Async);
Expand Down Expand Up @@ -1443,18 +1509,10 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,

// Insert the call to the tail call function and inline it.
auto *Fn = Suspend->getMustTailCallFunction();
auto DbgLoc = Suspend->getDebugLoc();
SmallVector<Value *, 8> Args(Suspend->operand_values());
auto FnArgs = ArrayRef<Value *>(Args).drop_front(3).drop_back(1);
auto FnTy = cast<FunctionType>(Fn->getType()->getPointerElementType());
// Coerce the arguments, llvm optimizations seem to ignore the types in
// vaarg functions and throws away casts in optimized mode.
SmallVector<Value *, 8> CallArgs;
coerceArguments(Builder, FnTy, FnArgs, CallArgs);
auto *TailCall = Builder.CreateCall(FnTy, Fn, CallArgs);
TailCall->setDebugLoc(DbgLoc);
TailCall->setTailCall();
TailCall->setCallingConv(Fn->getCallingConv());
SmallVector<Value *, 8> Args(Suspend->args());
auto FnArgs = ArrayRef<Value *>(Args).drop_front(3);
auto *TailCall =
coro::createMustTailCall(Suspend->getDebugLoc(), Fn, FnArgs, Builder);
Builder.CreateRetVoid();
InlineFunctionInfo FnInfo;
auto InlineRes = InlineFunction(*TailCall, FnInfo);
Expand Down Expand Up @@ -1683,7 +1741,7 @@ static void updateCallGraphAfterCoroutineSplit(
if (!Shape.CoroBegin)
return;

for (llvm::CoroEndInst *End : Shape.CoroEnds) {
for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) {
auto &Context = End->getContext();
End->replaceAllUsesWith(ConstantInt::getFalse(Context));
End->eraseFromParent();
Expand Down
Loading

0 comments on commit 333108e

Please sign in to comment.