158 changes: 131 additions & 27 deletions llvm/lib/Transforms/Coroutines/CoroElide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"

using namespace llvm;

Expand All @@ -39,11 +40,29 @@ struct CoroElide : FunctionPass {

bool runOnFunction(Function &F) override;
// Tells the pass manager what this pass needs and what it leaves intact:
// alias analysis results are required (they are queried when deciding which
// calls touch the coroutine frame), and the CFG is reported as preserved.
void getAnalysisUsage(AnalysisUsage &AU) const override {
  AU.setPreservesCFG();
  AU.addRequired<AAResultsWrapperPass>();
}
};
}

// Definition of the static pass identifier member of CoroElide.
char CoroElide::ID = 0;
// Register the pass under the name "coro-elide" and declare its dependency on
// AAResultsWrapperPass, matching the addRequired<AAResultsWrapperPass>() call
// in CoroElide::getAnalysisUsage.
INITIALIZE_PASS_BEGIN(
CoroElide, "coro-elide",
"Coroutine frame allocation elision and indirect calls replacement", false,
false)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(
CoroElide, "coro-elide",
"Coroutine frame allocation elision and indirect calls replacement", false,
false)

// Factory function: returns a newly allocated CoroElide pass instance.
Pass *llvm::createCoroElidePass() { return new CoroElide(); }

//===----------------------------------------------------------------------===//
// Implementation
//===----------------------------------------------------------------------===//

// Go through the list of coro.subfn.addr intrinsics and replace them with the
// provided constant.
static void replaceWithConstant(Constant *Value,
Expand All @@ -68,24 +87,103 @@ static void replaceWithConstant(Constant *Value,
replaceAndRecursivelySimplify(I, Value);
}

// Returns true when at least one operand of the call instruction may refer to
// the coroutine frame, i.e. alias analysis cannot prove NoAlias between that
// operand and the frame alloca.
static bool operandReferences(CallInst *CI, AllocaInst *Frame, AAResults &AA) {
  for (Value *Operand : CI->operand_values()) {
    if (AA.alias(Operand, Frame) == NoAlias)
      continue;
    return true;
  }
  return false;
}

// Look for any tail calls referencing the coroutine frame and remove the tail
// attribute from them: the coroutine frame now resides on the stack, and a
// tail call implies that the callee does not reference anything on the
// caller's stack.
//
// Frame is the alloca that now holds the coroutine frame; every instruction of
// the function containing it is scanned. AA is used (via operandReferences) to
// decide whether a call operand may refer to the frame.
//
// Aborts through report_fatal_error if such a call is marked musttail.
static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) {
  Function &F = *Frame->getFunction();
  for (Instruction &I : instructions(F)) {
    auto *Call = dyn_cast<CallInst>(&I);
    // Only tail calls that may touch the frame are of interest.
    if (!Call || !Call->isTailCall() || !operandReferences(Call, Frame, AA))
      continue;

    // FIXME: If we ever hit this check. Evaluate whether it is more
    // appropriate to retain musttail and allow the code to compile.
    if (Call->isMustTailCall())
      report_fatal_error("Call referring to the coroutine frame cannot be "
                         "marked as musttail");
    Call->setTailCall(false);
  }
}

// For a resume function of the form @f.resume(%f.frame* %frame), yields the
// pointee type of its first parameter, i.e. %f.frame.
static Type *getFrameType(Function *Resume) {
  Argument &FrameArg = Resume->getArgumentList().front();
  auto *FramePtrTy = cast<PointerType>(FrameArg.getType());
  return FramePtrTy->getElementType();
}

// Returns the first instruction in the entry block of F that is not an
// alloca. Reaching the end of the block without finding one is treated as
// unreachable.
static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
  BasicBlock &Entry = F->getEntryBlock();
  for (Instruction &Inst : Entry) {
    if (isa<AllocaInst>(&Inst))
      continue;
    return &Inst;
  }
  llvm_unreachable("no terminator in the entry block");
}

// To elide heap allocations we need to suppress code blocks guarded by
// llvm.coro.alloc and llvm.coro.free instructions.
//
// CoroBegin - the coro.begin being lowered; it is erased by this function.
// FrameTy   - type of the stack slot created to hold the coroutine frame.
// AllocInst - the coro.alloc associated with CoroBegin; it is erased here.
// AA        - alias analysis used to find tail calls that may touch the frame.
static void elideHeapAllocations(CoroBeginInst *CoroBegin, Type *FrameTy,
CoroAllocInst *AllocInst, AAResults &AA) {
LLVMContext &C = CoroBegin->getContext();
// New instructions go before the first non-alloca instruction of the entry
// block of the function that contains CoroBegin.
auto *InsertPt = getFirstNonAllocaInTheEntryBlock(CoroBegin->getFunction());

// FIXME: Design how to transmit alignment information for every alloca that
// is spilled into the coroutine frame and recreate the alignment information
// here. Possibly we will need to do a mini SROA here and break the coroutine
// frame into individual AllocaInst recreating the original alignment.
auto *Frame = new AllocaInst(FrameTy, "", InsertPt);
// i8* view of the frame alloca; this is the value that replaces both the
// coro.alloc result and the coro.frame results below.
auto *FrameVoidPtr =
new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt);

// Replacing llvm.coro.alloc with non-null value will suppress dynamic
// allocation as it is expected for the frontend to generate the code that
// looks like:
// mem = coro.alloc();
// if (!mem) mem = malloc(coro.size());
// coro.begin(mem, ...)
AllocInst->replaceAllUsesWith(FrameVoidPtr);
AllocInst->eraseFromParent();

// To suppress deallocation code, we replace all llvm.coro.free intrinsics
// associated with this coro.begin with null constant.
auto *NullPtr = ConstantPointerNull::get(Type::getInt8PtrTy(C));
coro::replaceAllCoroFrees(CoroBegin, NullPtr);
// The coro.free calls are found through the users of CoroBegin, so they must
// be replaced before lowerTo() erases CoroBegin and its coro.frame users.
CoroBegin->lowerTo(FrameVoidPtr);

// Since now coroutine frame lives on the stack we need to make sure that
// any tail call referencing it, must be made non-tail call.
removeTailCallAttribute(Frame, AA);
}

// See if there are any coro.subfn.addr intrinsics directly referencing
// the coro.begin. If found, replace them with an appropriate coroutine
// subfunction associated with that coro.begin.
static bool replaceIndirectCalls(CoroBeginInst *CoroBegin) {
static bool replaceIndirectCalls(CoroBeginInst *CoroBegin, AAResults &AA) {
SmallVector<CoroSubFnInst *, 8> ResumeAddr;
SmallVector<CoroSubFnInst *, 8> DestroyAddr;

for (User *U : CoroBegin->users()) {
if (auto *II = dyn_cast<CoroSubFnInst>(U)) {
switch (II->getIndex()) {
case CoroSubFnInst::ResumeIndex:
ResumeAddr.push_back(II);
break;
case CoroSubFnInst::DestroyIndex:
DestroyAddr.push_back(II);
break;
default:
llvm_unreachable("unexpected coro.subfn.addr constant");
for (User *CF : CoroBegin->users()) {
assert(isa<CoroFrameInst>(CF) &&
"CoroBegin can be only used by coro.frame instructions");
for (User *U : CF->users()) {
if (auto *II = dyn_cast<CoroSubFnInst>(U)) {
switch (II->getIndex()) {
case CoroSubFnInst::ResumeIndex:
ResumeAddr.push_back(II);
break;
case CoroSubFnInst::DestroyIndex:
DestroyAddr.push_back(II);
break;
default:
llvm_unreachable("unexpected coro.subfn.addr constant");
}
}
}
}
Expand All @@ -99,11 +197,28 @@ static bool replaceIndirectCalls(CoroBeginInst *CoroBegin) {
"of coroutine subfunctions");
auto *ResumeAddrConstant =
ConstantExpr::getExtractValue(Resumers, CoroSubFnInst::ResumeIndex);
replaceWithConstant(ResumeAddrConstant, ResumeAddr);

if (DestroyAddr.empty())
return true;

auto *DestroyAddrConstant =
ConstantExpr::getExtractValue(Resumers, CoroSubFnInst::DestroyIndex);

replaceWithConstant(ResumeAddrConstant, ResumeAddr);
replaceWithConstant(DestroyAddrConstant, DestroyAddr);

// If llvm.coro.begin refers to llvm.coro.alloc, we can elide the allocation.
if (auto *AllocInst = CoroBegin->getAlloc()) {
// FIXME: The check above is overly lax. It only checks for whether we have
// an ability to elide heap allocations, not whether it is safe to do so.
// We need to do something like:
// If for every exit from the function where coro.begin is
// live, there is a coro.free or coro.destroy dominating that exit block,
// then it is safe to elide heap allocation, since the lifetime of coroutine
// is fully enclosed in its caller.
auto *FrameTy = getFrameType(cast<Function>(ResumeAddrConstant));
elideHeapAllocations(CoroBegin, FrameTy, AllocInst, AA);
}

return true;
}

Expand Down Expand Up @@ -143,20 +258,9 @@ bool CoroElide::runOnFunction(Function &F) {
if (CoroBegins.empty())
return Changed;

AAResults &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
for (auto *CB : CoroBegins)
Changed |= replaceIndirectCalls(CB);
Changed |= replaceIndirectCalls(CB, AA);

return Changed;
}

// Definition of the static pass identifier member of CoroElide.
char CoroElide::ID = 0;
// Register the pass under the name "coro-elide"; no analysis dependencies are
// declared between the BEGIN and END macros here.
INITIALIZE_PASS_BEGIN(
CoroElide, "coro-elide",
"Coroutine frame allocation elision and indirect calls replacement", false,
false)
INITIALIZE_PASS_END(
CoroElide, "coro-elide",
"Coroutine frame allocation elision and indirect calls replacement", false,
false)

// Factory function: returns a newly allocated CoroElide pass instance.
Pass *llvm::createCoroElidePass() { return new CoroElide(); }
64 changes: 63 additions & 1 deletion llvm/lib/Transforms/Coroutines/CoroInstr.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,57 @@ class LLVM_LIBRARY_VISIBILITY CoroSubFnInst : public IntrinsicInst {
}
};

/// Wrapper for calls to the llvm.coro.alloc intrinsic.
class LLVM_LIBRARY_VISIBILITY CoroAllocInst : public IntrinsicInst {
public:
  // Support for isa<>, cast<> and dyn_cast<>: an intrinsic call belongs to
  // this class exactly when its intrinsic ID is coro_alloc.
  static inline bool classof(const IntrinsicInst *I) {
    return Intrinsic::coro_alloc == I->getIntrinsicID();
  }
  static inline bool classof(const Value *V) {
    if (const auto *II = dyn_cast<IntrinsicInst>(V))
      return classof(II);
    return false;
  }
};

/// Wrapper for calls to the llvm.coro.frame intrinsic.
class LLVM_LIBRARY_VISIBILITY CoroFrameInst : public IntrinsicInst {
public:
  // Support for isa<>, cast<> and dyn_cast<>: an intrinsic call belongs to
  // this class exactly when its intrinsic ID is coro_frame.
  static inline bool classof(const IntrinsicInst *I) {
    return Intrinsic::coro_frame == I->getIntrinsicID();
  }
  static inline bool classof(const Value *V) {
    if (const auto *II = dyn_cast<IntrinsicInst>(V))
      return classof(II);
    return false;
  }
};

/// Wrapper for calls to the llvm.coro.free intrinsic.
class LLVM_LIBRARY_VISIBILITY CoroFreeInst : public IntrinsicInst {
public:
  // Support for isa<>, cast<> and dyn_cast<>: an intrinsic call belongs to
  // this class exactly when its intrinsic ID is coro_free.
  static inline bool classof(const IntrinsicInst *I) {
    return Intrinsic::coro_free == I->getIntrinsicID();
  }
  static inline bool classof(const Value *V) {
    if (const auto *II = dyn_cast<IntrinsicInst>(V))
      return classof(II);
    return false;
  }
};

/// This class represents the llvm.coro.begin instruction.
class LLVM_LIBRARY_VISIBILITY CoroBeginInst : public IntrinsicInst {
enum { MemArg, AlignArg, PromiseArg, InfoArg };
enum { MemArg, ElideArg, AlignArg, PromiseArg, InfoArg };

public:
// Returns the llvm.coro.alloc call that the elide argument refers to (looking
// through pointer casts), or nullptr when that argument is anything else.
CoroAllocInst *getAlloc() const {
  Value *ElideOperand = getArgOperand(ElideArg)->stripPointerCasts();
  return dyn_cast<CoroAllocInst>(ElideOperand);
}

// Returns the memory argument of this coro.begin exactly as written (no casts
// are stripped).
Value *getMem() const {
  return getArgOperand(MemArg);
}

// Returns the info argument with pointer casts stripped; the result is
// expected to be a Constant (enforced by cast<>).
Constant *getRawInfo() const {
  Value *Info = getArgOperand(InfoArg)->stripPointerCasts();
  return cast<Constant>(Info);
}
Expand Down Expand Up @@ -108,6 +154,22 @@ class LLVM_LIBRARY_VISIBILITY CoroBeginInst : public IntrinsicInst {
return Result;
}

// Substitutes Replacement for every coro.frame intrinsic that uses this
// coro.begin, then deletes those coro.frame calls and finally this coro.begin
// itself. Every user of this instruction must be a coro.frame (enforced by
// cast<>).
void lowerTo(Value* Replacement) {
  // Collect the users first so that erasing them does not disturb the walk
  // over this instruction's use list.
  SmallVector<CoroFrameInst*, 4> Frames;
  for (User *U : users())
    Frames.push_back(cast<CoroFrameInst>(U));

  for (CoroFrameInst *Frame : Frames) {
    Frame->replaceAllUsesWith(Replacement);
    Frame->eraseFromParent();
  }

  eraseFromParent();
}

// Methods for support type inquiry through isa, cast, and dyn_cast:
static inline bool classof(const IntrinsicInst *I) {
return I->getIntrinsicID() == Intrinsic::coro_begin;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Coroutines/CoroInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void initializeCoroCleanupPass(PassRegistry &);
namespace coro {

bool declaresIntrinsics(Module &M, std::initializer_list<StringRef>);
// Replaces every llvm.coro.free instruction associated with the coro.begin CB
// with Replacement and erases those instructions.
void replaceAllCoroFrees(CoroBeginInst *CB, Value *Replacement);

// Keeps data and helper functions for lowering coroutine intrinsics.
struct LowererBase {
Expand Down
18 changes: 18 additions & 0 deletions llvm/lib/Transforms/Coroutines/Coroutines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,21 @@ bool coro::declaresIntrinsics(Module &M,

return false;
}

// Locates the llvm.coro.free instructions tied to the given coro.begin, i.e.
// the coro.free calls found among the users of CB's own users, substitutes
// Replacement for each of them and deletes them. Does nothing when there are
// none.
void coro::replaceAllCoroFrees(CoroBeginInst *CB, Value *Replacement) {
  // Gather first, mutate afterwards, so use lists are not modified while they
  // are being walked.
  SmallVector<CoroFreeInst *, 4> Frees;
  for (User *FrameUser : CB->users()) {
    for (User *Candidate : FrameUser->users()) {
      auto *Free = dyn_cast<CoroFreeInst>(Candidate);
      if (Free)
        Frees.push_back(Free);
    }
  }

  for (CoroFreeInst *Free : Frees) {
    Free->replaceAllUsesWith(Replacement);
    Free->eraseFromParent();
  }
}
15 changes: 9 additions & 6 deletions llvm/test/Transforms/Coroutines/coro-elide.ll
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
; Tests that the coro.destroy and coro.resume are devirtualized where possible,
; SCC pipeline restarts and inlines the direct calls.
; RUN: opt < %s -S -inline -coro-elide | FileCheck %s
; RUN: opt < %s -S -inline -coro-elide -dce | FileCheck %s

declare void @print(i32) nounwind

Expand All @@ -22,15 +22,16 @@ define fastcc void @f.destroy(i8*) {
; a coroutine start function
define i8* @f() {
entry:
%hdl = call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
%tok = call token @llvm.coro.begin(i8* null, i8* null, i32 0, i8* null,
i8* bitcast ([2 x void (i8*)*]* @f.resumers to i8*))
%hdl = call i8* @llvm.coro.frame(token %tok)
ret i8* %hdl
}

; CHECK-LABEL: @callResume(
define void @callResume() {
entry:
; CHECK: call i8* @llvm.coro.begin
; CHECK: call token @llvm.coro.begin
%hdl = call i8* @f()

; CHECK-NEXT: call void @print(i32 0)
Expand All @@ -50,7 +51,7 @@ entry:
; CHECK-LABEL: @eh(
define void @eh() personality i8* null {
entry:
; CHECK: call i8* @llvm.coro.begin
; CHECK: call token @llvm.coro.begin
%hdl = call i8* @f()

; CHECK-NEXT: call void @print(i32 0)
Expand All @@ -70,7 +71,8 @@ ehcleanup:
; no devirtualization here, since coro.begin info parameter is null
define void @no_devirt_info_null() {
entry:
%hdl = call i8* @llvm.coro.begin(i8* null, i32 0, i8* null, i8* null)
%tok = call token @llvm.coro.begin(i8* null, i8* null, i32 0, i8* null, i8* null)
%hdl = call i8* @llvm.coro.frame(token %tok)

; CHECK: call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
%0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
Expand Down Expand Up @@ -106,5 +108,6 @@ entry:
}


declare i8* @llvm.coro.begin(i8*, i32, i8*, i8*)
declare token @llvm.coro.begin(i8*, i8*, i32, i8*, i8*)
declare i8* @llvm.coro.frame(token)
declare i8* @llvm.coro.subfn.addr(i8*, i8)
125 changes: 125 additions & 0 deletions llvm/test/Transforms/Coroutines/coro-heap-elide.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
; Tests that the dynamic allocation and deallocation of the coroutine frame are
; elided and that any tail calls referencing the coroutine frame have the tail
; call attribute removed.
; RUN: opt < %s -S -inline -coro-elide -instsimplify -simplifycfg | FileCheck %s

declare void @print(i32) nounwind

%f.frame = type {i32}

declare void @bar(i8*)

declare fastcc void @f.resume(%f.frame*)
declare fastcc void @f.destroy(%f.frame*)

declare void @may_throw()
declare i8* @CustomAlloc(i32)
declare void @CustomFree(i8*)

@f.resumers = internal constant
[2 x void (%f.frame*)*] [void (%f.frame*)* @f.resume, void (%f.frame*)* @f.destroy]

; A coroutine start function whose frame allocation is guarded by
; llvm.coro.alloc and whose deallocation is guarded by llvm.coro.free.
define i8* @f() personality i8* null {
entry:
%elide = call i8* @llvm.coro.alloc()
; Despite its name, %need.dyn.alloc is true when %elide is non-null, in which
; case control goes straight to coro.begin and @CustomAlloc is skipped.
%need.dyn.alloc = icmp ne i8* %elide, null
br i1 %need.dyn.alloc, label %coro.begin, label %dyn.alloc
dyn.alloc:
%alloc = call i8* @CustomAlloc(i32 4)
br label %coro.begin
coro.begin:
%phi = phi i8* [ %elide, %entry ], [ %alloc, %dyn.alloc ]
; The second argument passes the llvm.coro.alloc result to coro.begin.
%beg = call token @llvm.coro.begin(i8* %phi, i8* %elide, i32 0, i8* null,
i8* bitcast ([2 x void (%f.frame*)*]* @f.resumers to i8*))
%hdl = call i8* @llvm.coro.frame(token %beg)
invoke void @may_throw()
to label %ret unwind label %ehcleanup
ret:
ret i8* %hdl

ehcleanup:
%tok = cleanuppad within none []
; @CustomFree is called only when llvm.coro.free yields a non-null pointer.
%mem = call i8* @llvm.coro.free(i8* %hdl)
%need.dyn.free = icmp ne i8* %mem, null
br i1 %need.dyn.free, label %dyn.free, label %if.end
dyn.free:
call void @CustomFree(i8* %mem)
br label %if.end
if.end:
cleanupret from %tok unwind to caller
}

; CHECK-LABEL: @callResume(
; Caller of @f. Per the RUN line, @f is inlined here before coro-elide runs;
; the expected output has a stack alloca of %f.frame and no coro.begin or
; @CustomAlloc call.
define void @callResume() {
entry:
; CHECK: alloca %f.frame
; CHECK-NOT: coro.begin
; CHECK-NOT: CustomAlloc
; CHECK: call void @may_throw()
%hdl = call i8* @f()

; Need to remove 'tail' from the first call to @bar
; CHECK-NOT: tail call void @bar(
; CHECK: call void @bar(
tail call void @bar(i8* %hdl)
; The second call passes null rather than the frame and is expected to keep
; its 'tail' marker.
; CHECK: tail call void @bar(
tail call void @bar(i8* null)

; Index 0 of coro.subfn.addr is expected to resolve to @f.resume, called on
; the stack frame pointer %vFrame.
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.resume to void (i8*)*)(i8* %vFrame)
%0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
%1 = bitcast i8* %0 to void (i8*)*
call fastcc void %1(i8* %hdl)

; Index 1 of coro.subfn.addr is expected to resolve to @f.destroy.
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.destroy to void (i8*)*)(i8* %vFrame)
%2 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
%3 = bitcast i8* %2 to void (i8*)*
call fastcc void %3(i8* %hdl)

; CHECK-NEXT: ret void
ret void
}

; a coroutine start function (cannot elide heap alloc, due to second argument to
; coro.begin not pointing to coro.alloc)
define i8* @f_no_elision() personality i8* null {
entry:
; The frame memory always comes from @CustomAlloc; there is no llvm.coro.alloc.
%alloc = call i8* @CustomAlloc(i32 4)
; The second argument is null instead of a llvm.coro.alloc result.
%beg = call token @llvm.coro.begin(i8* %alloc, i8* null, i32 0, i8* null,
i8* bitcast ([2 x void (%f.frame*)*]* @f.resumers to i8*))
%hdl = call i8* @llvm.coro.frame(token %beg)
ret i8* %hdl
}

; CHECK-LABEL: @callResume_no_elision(
; Caller of @f_no_elision: the @CustomAlloc call is expected to remain in the
; output, and both calls to @bar are expected to stay tail calls.
define void @callResume_no_elision() {
entry:
; CHECK: call i8* @CustomAlloc(
%hdl = call i8* @f_no_elision()

; Tail call should remain tail calls
; CHECK: tail call void @bar(
tail call void @bar(i8* %hdl)
; CHECK: tail call void @bar(
tail call void @bar(i8* null)

; The resume and destroy calls are still expected to be devirtualized to
; @f.resume and @f.destroy.
; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.resume to void (i8*)*)(i8*
%0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
%1 = bitcast i8* %0 to void (i8*)*
call fastcc void %1(i8* %hdl)

; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.destroy to void (i8*)*)(i8*
%2 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
%3 = bitcast i8* %2 to void (i8*)*
call fastcc void %3(i8* %hdl)

; CHECK-NEXT: ret void
ret void
}


declare i8* @llvm.coro.alloc()
declare i8* @llvm.coro.free(i8*)
declare token @llvm.coro.begin(i8*, i8*, i32, i8*, i8*)
declare i8* @llvm.coro.frame(token)
declare i8* @llvm.coro.subfn.addr(i8*, i8)
16 changes: 16 additions & 0 deletions llvm/test/Transforms/Coroutines/restart-trigger.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
; Verifies that the restart trigger forces the IPO pipeline to restart and that
; the same coroutine is processed by the CoroSplit pass twice.
; REQUIRES: asserts
; RUN: opt < %s -S -O0 -enable-coroutines -debug-only=coro-split 2>&1 | FileCheck %s
; RUN: opt < %s -S -O1 -enable-coroutines -debug-only=coro-split 2>&1 | FileCheck %s

; CHECK: CoroSplit: Processing coroutine 'f' state: 0
; CHECK-NEXT: CoroSplit: Processing coroutine 'f' state: 1

declare token @llvm.coro.begin(i8*, i8*, i32, i8*, i8*)

; A minimal coroutine start function: its only coroutine intrinsic is a
; llvm.coro.begin call with null/zero arguments, and the resulting token is
; unused.
define void @f() {
call token @llvm.coro.begin(i8* null, i8* null, i32 0, i8* null, i8* null)
ret void
}