Skip to content

Commit

Permalink
[Coroutines] Use allocator overload when available
Browse files Browse the repository at this point in the history
Summary:
Depends on https://reviews.llvm.org/D42605.

An implementation of the behavior described in `[dcl.fct.def.coroutine]/7`:
when a promise type overloads `operator new` using a "placement new"
that takes the same argument types as the coroutine function, that
overload is used when allocating the coroutine frame.

Simply passing references to the coroutine function parameters directly
to `operator new` results in invariant violations in LLVM's coroutine
splitting pass, so this implementation modifies Clang codegen to
produce allocator-specific alloc/store/loads for each parameter being
forwarded to the allocator.

Test Plan: `check-clang`

Reviewers: rsmith, GorNishanov, eric_niebler

Reviewed By: GorNishanov

Subscribers: lewissbaker, EricWF, cfe-commits

Differential Revision: https://reviews.llvm.org/D42606

llvm-svn: 325291
  • Loading branch information
modocache committed Feb 15, 2018
1 parent f3f35ef commit 9860622
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 43 deletions.
125 changes: 89 additions & 36 deletions clang/lib/Sema/SemaCoroutine.cpp
Expand Up @@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "CoroutineStmtBuilder.h"
#include "clang/AST/ASTLambda.h"
#include "clang/AST/Decl.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/StmtCXX.h"
Expand Down Expand Up @@ -506,24 +507,15 @@ VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {

auto RefExpr = ExprEmpty();
auto Move = Moves.find(PD);
if (Move != Moves.end()) {
// If a reference to the function parameter exists in the coroutine
// frame, use that reference.
auto *MoveDecl =
cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
RefExpr = BuildDeclRefExpr(MoveDecl, MoveDecl->getType(),
ExprValueKind::VK_LValue, FD->getLocation());
} else {
// If the function parameter doesn't exist in the coroutine frame, it
// must be a scalar value. Use it directly.
assert(!PD->getType()->getAsCXXRecordDecl() &&
"Non-scalar types should have been moved and inserted into the "
"parameter moves map");
RefExpr =
BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
ExprValueKind::VK_LValue, FD->getLocation());
}

assert(Move != Moves.end() &&
"Coroutine function parameter not inserted into move map");
// If a reference to the function parameter exists in the coroutine
// frame, use that reference.
auto *MoveDecl =
cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
RefExpr =
BuildDeclRefExpr(MoveDecl, MoveDecl->getType().getNonReferenceType(),
ExprValueKind::VK_LValue, FD->getLocation());
if (RefExpr.isInvalid())
return nullptr;
CtorArgExprs.push_back(RefExpr.get());
Expand Down Expand Up @@ -1050,18 +1042,75 @@ bool CoroutineStmtBuilder::makeNewAndDeleteExpr() {

const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;

// FIXME: Add support for stateful allocators.
// [dcl.fct.def.coroutine]/7
// Lookup allocation functions using a parameter list composed of the
// requested size of the coroutine state being allocated, followed by
// the coroutine function's arguments. If a matching allocation function
// exists, use it. Otherwise, use an allocation function that just takes
// the requested size.

FunctionDecl *OperatorNew = nullptr;
FunctionDecl *OperatorDelete = nullptr;
FunctionDecl *UnusedResult = nullptr;
bool PassAlignment = false;
SmallVector<Expr *, 1> PlacementArgs;

// [dcl.fct.def.coroutine]/7
// "The allocation function’s name is looked up in the scope of P.
// [...] If the lookup finds an allocation function in the scope of P,
// overload resolution is performed on a function call created by assembling
// an argument list. The first argument is the amount of space requested,
// and has type std::size_t. The lvalues p1 ... pn are the succeeding
// arguments."
//
// ...where "p1 ... pn" are defined earlier as:
//
// [dcl.fct.def.coroutine]/3
// "For a coroutine f that is a non-static member function, let P1 denote the
// type of the implicit object parameter (13.3.1) and P2 ... Pn be the types
// of the function parameters; otherwise let P1 ... Pn be the types of the
// function parameters. Let p1 ... pn be lvalues denoting those objects."
if (auto *MD = dyn_cast<CXXMethodDecl>(&FD)) {
if (MD->isInstance() && !isLambdaCallOperator(MD)) {
ExprResult ThisExpr = S.ActOnCXXThis(Loc);
if (ThisExpr.isInvalid())
return false;
ThisExpr = S.CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
if (ThisExpr.isInvalid())
return false;
PlacementArgs.push_back(ThisExpr.get());
}
}
for (auto *PD : FD.parameters()) {
if (PD->getType()->isDependentType())
continue;

// Build a reference to the parameter.
auto PDLoc = PD->getLocation();
ExprResult PDRefExpr =
S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
ExprValueKind::VK_LValue, PDLoc);
if (PDRefExpr.isInvalid())
return false;

PlacementArgs.push_back(PDRefExpr.get());
}
S.FindAllocationFunctions(Loc, SourceRange(),
/*UseGlobal*/ false, PromiseType,
/*isArray*/ false, PassAlignment, PlacementArgs,
OperatorNew, UnusedResult);
OperatorNew, UnusedResult, /*Diagnose*/ false);

// [dcl.fct.def.coroutine]/7
// "If no matching function is found, overload resolution is performed again
// on a function call created by passing just the amount of space required as
// an argument of type std::size_t."
if (!OperatorNew && !PlacementArgs.empty()) {
PlacementArgs.clear();
S.FindAllocationFunctions(Loc, SourceRange(),
/*UseGlobal*/ false, PromiseType,
/*isArray*/ false, PassAlignment,
PlacementArgs, OperatorNew, UnusedResult);
}

bool IsGlobalOverload =
OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
Expand All @@ -1080,7 +1129,8 @@ bool CoroutineStmtBuilder::makeNewAndDeleteExpr() {
OperatorNew, UnusedResult);
}

assert(OperatorNew && "expected definition of operator new to be found");
if (!OperatorNew)
return false;

if (RequiresNoThrowAlloc) {
const auto *FT = OperatorNew->getType()->getAs<FunctionProtoType>();
Expand Down Expand Up @@ -1386,25 +1436,28 @@ bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) {
if (PD->getType()->isDependentType())
continue;

// No need to copy scalars, LLVM will take care of them.
if (PD->getType()->getAsCXXRecordDecl()) {
ExprResult PDRefExpr = BuildDeclRefExpr(
PD, PD->getType(), ExprValueKind::VK_LValue, Loc); // FIXME: scope?
if (PDRefExpr.isInvalid())
return false;
ExprResult PDRefExpr =
BuildDeclRefExpr(PD, PD->getType().getNonReferenceType(),
ExprValueKind::VK_LValue, Loc); // FIXME: scope?
if (PDRefExpr.isInvalid())
return false;

Expr *CExpr = castForMoving(*this, PDRefExpr.get());
Expr *CExpr = nullptr;
if (PD->getType()->getAsCXXRecordDecl() ||
PD->getType()->isRValueReferenceType())
CExpr = castForMoving(*this, PDRefExpr.get());
else
CExpr = PDRefExpr.get();

auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);

// Convert decl to a statement.
StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
if (Stmt.isInvalid())
return false;
// Convert decl to a statement.
StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
if (Stmt.isInvalid())
return false;

ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
}
ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
}
return true;
}
Expand Down
28 changes: 28 additions & 0 deletions clang/test/CodeGenCoroutines/coro-alloc.cpp
Expand Up @@ -106,6 +106,34 @@ extern "C" void f1(promise_new_tag ) {
co_return;
}

struct promise_matching_placement_new_tag {};

template<>
struct std::experimental::coroutine_traits<void, promise_matching_placement_new_tag, int, float, double> {
struct promise_type {
void *operator new(unsigned long, promise_matching_placement_new_tag,
int, float, double);
void get_return_object() {}
suspend_always initial_suspend() { return {}; }
suspend_always final_suspend() { return {}; }
void return_void() {}
};
};

// CHECK-LABEL: f1a(
extern "C" void f1a(promise_matching_placement_new_tag, int x, float y , double z) {
// CHECK: store i32 %x, i32* %x.addr, align 4
// CHECK: store float %y, float* %y.addr, align 4
// CHECK: store double %z, double* %z.addr, align 8
// CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
// CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
// CHECK: %[[INT:.+]] = load i32, i32* %x.addr, align 4
// CHECK: %[[FLOAT:.+]] = load float, float* %y.addr, align 4
// CHECK: %[[DOUBLE:.+]] = load double, double* %z.addr, align 8
// CHECK: call i8* @_ZNSt12experimental16coroutine_traitsIJv34promise_matching_placement_new_tagifdEE12promise_typenwEmS1_ifd(i64 %[[SIZE]], i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]])
co_return;
}

struct promise_delete_tag {};

template<>
Expand Down
4 changes: 2 additions & 2 deletions clang/test/CodeGenCoroutines/coro-gro-nrvo.cpp
Expand Up @@ -41,7 +41,7 @@ coro f(int) {

// CHECK: {{.*}}[[CoroInit]]:
// CHECK: store i1 false, i1* %gro.active
// CHECK-NEXT: call void @{{.*get_return_objectEv}}(%struct.coro* sret %agg.result
// CHECK: call void @{{.*get_return_objectEv}}(%struct.coro* sret %agg.result
// CHECK-NEXT: store i1 true, i1* %gro.active
co_return;
}
Expand Down Expand Up @@ -78,7 +78,7 @@ struct coro_two {

// CHECK: {{.*}}[[InitOnSuccess]]:
// CHECK: store i1 false, i1* %gro.active
// CHECK-NEXT: call void @{{.*get_return_objectEv}}(%struct.coro_two* sret %agg.result
// CHECK: call void @{{.*get_return_objectEv}}(%struct.coro_two* sret %agg.result
// CHECK-NEXT: store i1 true, i1* %gro.active

// CHECK: [[RetLabel]]:
Expand Down
10 changes: 5 additions & 5 deletions clang/test/CodeGenCoroutines/coro-params.cpp
Expand Up @@ -69,12 +69,12 @@ void f(int val, MoveOnly moParam, MoveAndCopy mcParam) {
// CHECK: store i32 %val, i32* %[[ValAddr:.+]]

// CHECK: call i8* @llvm.coro.begin(
// CHECK-NEXT: call void @_ZN8MoveOnlyC1EOS_(%struct.MoveOnly* %[[MoCopy]], %struct.MoveOnly* dereferenceable(4) %[[MoParam]])
// CHECK: call void @_ZN8MoveOnlyC1EOS_(%struct.MoveOnly* %[[MoCopy]], %struct.MoveOnly* dereferenceable(4) %[[MoParam]])
// CHECK-NEXT: call void @_ZN11MoveAndCopyC1EOS_(%struct.MoveAndCopy* %[[McCopy]], %struct.MoveAndCopy* dereferenceable(4) %[[McParam]]) #
// CHECK-NEXT: invoke void @_ZNSt12experimental16coroutine_traitsIJvi8MoveOnly11MoveAndCopyEE12promise_typeC1Ev(

// CHECK: call void @_ZN14suspend_always12await_resumeEv(
// CHECK: %[[IntParam:.+]] = load i32, i32* %val.addr
// CHECK: %[[IntParam:.+]] = load i32, i32* %val1
// CHECK: %[[MoGep:.+]] = getelementptr inbounds %struct.MoveOnly, %struct.MoveOnly* %[[MoCopy]], i32 0, i32 0
// CHECK: %[[MoVal:.+]] = load i32, i32* %[[MoGep]]
// CHECK: %[[McGep:.+]] = getelementptr inbounds %struct.MoveAndCopy, %struct.MoveAndCopy* %[[McCopy]], i32 0, i32 0
Expand Down Expand Up @@ -150,9 +150,9 @@ struct std::experimental::coroutine_traits<void, promise_matching_constructor, i

// CHECK-LABEL: void @_Z38coroutine_matching_promise_constructor28promise_matching_constructorifd(i32, float, double)
void coroutine_matching_promise_constructor(promise_matching_constructor, int, float, double) {
// CHECK: %[[INT:.+]] = load i32, i32* %.addr, align 4
// CHECK: %[[FLOAT:.+]] = load float, float* %.addr1, align 4
// CHECK: %[[DOUBLE:.+]] = load double, double* %.addr2, align 8
// CHECK: %[[INT:.+]] = load i32, i32* %5, align 4
// CHECK: %[[FLOAT:.+]] = load float, float* %6, align 4
// CHECK: %[[DOUBLE:.+]] = load double, double* %7, align 8
// CHECK: invoke void @_ZNSt12experimental16coroutine_traitsIJv28promise_matching_constructorifdEE12promise_typeC1ES1_ifd(%"struct.std::experimental::coroutine_traits<void, promise_matching_constructor, int, float, double>::promise_type"* %__promise, i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]])
co_return;
}
96 changes: 96 additions & 0 deletions clang/test/SemaCXX/coroutines.cpp
Expand Up @@ -781,6 +781,102 @@ coro<T> dependent_uses_nothrow_new(T) {
}
template coro<good_promise_13> dependent_uses_nothrow_new(good_promise_13);

struct good_promise_custom_new_operator {
coro<good_promise_custom_new_operator> get_return_object();
suspend_always initial_suspend();
suspend_always final_suspend();
void return_void();
void unhandled_exception();
void *operator new(unsigned long, double, float, int);
};

coro<good_promise_custom_new_operator>
good_coroutine_calls_custom_new_operator(double, float, int) {
co_return;
}

struct coroutine_nonstatic_member_struct;

struct good_promise_nonstatic_member_custom_new_operator {
coro<good_promise_nonstatic_member_custom_new_operator> get_return_object();
suspend_always initial_suspend();
suspend_always final_suspend();
void return_void();
void unhandled_exception();
void *operator new(unsigned long, coroutine_nonstatic_member_struct &, double);
};

struct bad_promise_nonstatic_member_mismatched_custom_new_operator {
coro<bad_promise_nonstatic_member_mismatched_custom_new_operator> get_return_object();
suspend_always initial_suspend();
suspend_always final_suspend();
void return_void();
void unhandled_exception();
// expected-note@+1 {{candidate function not viable: requires 2 arguments, but 1 was provided}}
void *operator new(unsigned long, double);
};

struct coroutine_nonstatic_member_struct {
coro<good_promise_nonstatic_member_custom_new_operator>
good_coroutine_calls_nonstatic_member_custom_new_operator(double) {
co_return;
}

coro<bad_promise_nonstatic_member_mismatched_custom_new_operator>
bad_coroutine_calls_nonstatic_member_mistmatched_custom_new_operator(double) {
// expected-error@-1 {{no matching function for call to 'operator new'}}
co_return;
}
};

struct bad_promise_mismatched_custom_new_operator {
coro<bad_promise_mismatched_custom_new_operator> get_return_object();
suspend_always initial_suspend();
suspend_always final_suspend();
void return_void();
void unhandled_exception();
// expected-note@+1 {{candidate function not viable: requires 4 arguments, but 1 was provided}}
void *operator new(unsigned long, double, float, int);
};

coro<bad_promise_mismatched_custom_new_operator>
bad_coroutine_calls_mismatched_custom_new_operator(double) {
// expected-error@-1 {{no matching function for call to 'operator new'}}
co_return;
}

struct bad_promise_throwing_custom_new_operator {
static coro<bad_promise_throwing_custom_new_operator> get_return_object_on_allocation_failure();
coro<bad_promise_throwing_custom_new_operator> get_return_object();
suspend_always initial_suspend();
suspend_always final_suspend();
void return_void();
void unhandled_exception();
// expected-error@+1 {{'operator new' is required to have a non-throwing noexcept specification when the promise type declares 'get_return_object_on_allocation_failure()'}}
void *operator new(unsigned long, double, float, int);
};

coro<bad_promise_throwing_custom_new_operator>
bad_coroutine_calls_throwing_custom_new_operator(double, float, int) {
// expected-note@-1 {{call to 'operator new' implicitly required by coroutine function here}}
co_return;
}

struct good_promise_noexcept_custom_new_operator {
static coro<good_promise_noexcept_custom_new_operator> get_return_object_on_allocation_failure();
coro<good_promise_noexcept_custom_new_operator> get_return_object();
suspend_always initial_suspend();
suspend_always final_suspend();
void return_void();
void unhandled_exception();
void *operator new(unsigned long, double, float, int) noexcept;
};

coro<good_promise_noexcept_custom_new_operator>
good_coroutine_calls_noexcept_custom_new_operator(double, float, int) {
co_return;
}

struct mismatch_gro_type_tag1 {};
template<>
struct std::experimental::coroutine_traits<int, mismatch_gro_type_tag1> {
Expand Down

0 comments on commit 9860622

Please sign in to comment.