Skip to content

Commit

Permalink
[CUDA] Raise an error if a wrong-side call is codegen'ed.
Browse files Browse the repository at this point in the history
Summary:
Some function calls in CUDA are allowed to appear in
semantically-correct programs but are an error if they're ever
codegen'ed.  Specifically, a host+device function may call a host
function, but it's an error if such a function is ever codegen'ed in
device mode (and vice versa).

Previously, clang made no attempt to catch these errors.  For the most
part, they would be caught by ptxas, and reported as "call to unknown
function 'foo'".

Now we catch these errors and report them the same as we report other
illegal calls (e.g. a call from a host function to a device function).

This has a small change in error-message behavior for calls that were
previously disallowed (e.g. calls from a host to a device function).
Previously, we'd catch disallowed calls fairly early, before doing
additional semantic checking e.g. of the call's arguments.  Now we catch
these illegal calls at the very end of our semantic checks, so we'll
only emit a "illegal CUDA call" error if the call is otherwise
well-formed.

Reviewers: tra, rnk

Subscribers: cfe-commits

Differential Revision: https://reviews.llvm.org/D23242

llvm-svn: 278759
  • Loading branch information
Justin Lebar committed Aug 15, 2016
1 parent a7b04a5 commit 18e2d82
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 112 deletions.
12 changes: 12 additions & 0 deletions clang/include/clang/Sema/Sema.h
Expand Up @@ -9186,6 +9186,18 @@ class Sema {
void maybeAddCUDAHostDeviceAttrs(Scope *S, FunctionDecl *FD,
const LookupResult &Previous);

/// Check whether we're allowed to call Callee from the current context.
///
/// If the call is never allowed in a semantically-correct program
/// (CFP_Never), emits an error and returns false.
///
/// If the call is allowed in semantically-correct programs, but only if it's
/// never codegen'ed (CFP_WrongSide), creates a deferred diagnostic to be
/// emitted if and when the caller is codegen'ed, and returns true.
///
/// Otherwise, returns true without emitting any diagnostics.
bool CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee);

/// Finds a function in \p Matches with highest calling priority
/// from \p Caller context and erases all functions with lower
/// calling priority.
Expand Down
30 changes: 30 additions & 0 deletions clang/lib/Sema/SemaCUDA.cpp
Expand Up @@ -480,3 +480,33 @@ void Sema::maybeAddCUDAHostDeviceAttrs(Scope *S, FunctionDecl *NewD,
NewD->addAttr(CUDAHostAttr::CreateImplicit(Context));
NewD->addAttr(CUDADeviceAttr::CreateImplicit(Context));
}

bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) {
assert(getLangOpts().CUDA &&
"Should only be called during CUDA compilation.");
assert(Callee && "Callee may not be null.");
FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext);
if (!Caller)
return true;

Sema::CUDAFunctionPreference Pref = IdentifyCUDAPreference(Caller, Callee);
if (Pref == Sema::CFP_Never) {
Diag(Loc, diag::err_ref_bad_target) << IdentifyCUDATarget(Callee) << Callee
<< IdentifyCUDATarget(Caller);
Diag(Callee->getLocation(), diag::note_previous_decl) << Callee;
return false;
}
if (Pref == Sema::CFP_WrongSide) {
// We have to do this odd dance to create our PartialDiagnostic because we
// want its storage to be allocated with operator new, not in an arena.
PartialDiagnostic PD{PartialDiagnostic::NullDiagnostic()};
PD.Reset(diag::err_ref_bad_target);
PD << IdentifyCUDATarget(Callee) << Callee << IdentifyCUDATarget(Caller);
Caller->addDeferredDiag({Loc, std::move(PD)});
Diag(Callee->getLocation(), diag::note_previous_decl) << Callee;
// This is not immediately an error, so return true. The deferred errors
// will be emitted if and when Caller is codegen'ed.
return true;
}
return true;
}
2 changes: 2 additions & 0 deletions clang/lib/Sema/SemaDeclCXX.cpp
Expand Up @@ -12274,6 +12274,8 @@ Sema::BuildCXXConstructExpr(SourceLocation ConstructLoc, QualType DeclInitType,
DeclInitType->getBaseElementTypeUnsafe()->getAsCXXRecordDecl()) &&
"given constructor for wrong type");
MarkFunctionReferenced(ConstructLoc, Constructor);
if (getLangOpts().CUDA && !CheckCUDACall(ConstructLoc, Constructor))
return ExprError();

return CXXConstructExpr::Create(
Context, DeclInitType, ConstructLoc, Constructor, Elidable,
Expand Down
147 changes: 80 additions & 67 deletions clang/lib/Sema/SemaExpr.cpp
Expand Up @@ -1739,17 +1739,9 @@ Sema::BuildDeclRefExpr(ValueDecl *D, QualType Ty, ExprValueKind VK,
const CXXScopeSpec *SS, NamedDecl *FoundD,
const TemplateArgumentListInfo *TemplateArgs) {
if (getLangOpts().CUDA)
if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext))
if (const FunctionDecl *Callee = dyn_cast<FunctionDecl>(D)) {
if (!IsAllowedCUDACall(Caller, Callee)) {
Diag(NameInfo.getLoc(), diag::err_ref_bad_target)
<< IdentifyCUDATarget(Callee) << D->getIdentifier()
<< IdentifyCUDATarget(Caller);
Diag(D->getLocation(), diag::note_previous_decl)
<< D->getIdentifier();
return ExprError();
}
}
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(D))
if (!CheckCUDACall(NameInfo.getLoc(), Callee))
return ExprError();

bool RefersToCapturedVariable =
isa<VarDecl>(D) &&
Expand Down Expand Up @@ -5138,37 +5130,35 @@ static bool isNumberOfArgsValidForCall(Sema &S, const FunctionDecl *Callee,
return Callee->getMinRequiredArguments() <= NumArgs;
}

/// ActOnCallExpr - Handle a call to Fn with the specified array of arguments.
/// This provides the location of the left/right parens and a list of comma
/// locations.
ExprResult
Sema::ActOnCallExpr(Scope *S, Expr *Fn, SourceLocation LParenLoc,
MultiExprArg ArgExprs, SourceLocation RParenLoc,
Expr *ExecConfig, bool IsExecConfig) {
static ExprResult ActOnCallExprImpl(Sema &S, Scope *Scope, Expr *Fn,
SourceLocation LParenLoc,
MultiExprArg ArgExprs,
SourceLocation RParenLoc, Expr *ExecConfig,
bool IsExecConfig) {
// Since this might be a postfix expression, get rid of ParenListExprs.
ExprResult Result = MaybeConvertParenListExprToParenExpr(S, Fn);
ExprResult Result = S.MaybeConvertParenListExprToParenExpr(Scope, Fn);
if (Result.isInvalid()) return ExprError();
Fn = Result.get();

if (checkArgsForPlaceholders(*this, ArgExprs))
if (checkArgsForPlaceholders(S, ArgExprs))
return ExprError();

if (getLangOpts().CPlusPlus) {
if (S.getLangOpts().CPlusPlus) {
// If this is a pseudo-destructor expression, build the call immediately.
if (isa<CXXPseudoDestructorExpr>(Fn)) {
if (!ArgExprs.empty()) {
// Pseudo-destructor calls should not have any arguments.
Diag(Fn->getLocStart(), diag::err_pseudo_dtor_call_with_args)
<< FixItHint::CreateRemoval(
SourceRange(ArgExprs.front()->getLocStart(),
ArgExprs.back()->getLocEnd()));
S.Diag(Fn->getLocStart(), diag::err_pseudo_dtor_call_with_args)
<< FixItHint::CreateRemoval(
SourceRange(ArgExprs.front()->getLocStart(),
ArgExprs.back()->getLocEnd()));
}

return new (Context)
CallExpr(Context, Fn, None, Context.VoidTy, VK_RValue, RParenLoc);
return new (S.Context)
CallExpr(S.Context, Fn, None, S.Context.VoidTy, VK_RValue, RParenLoc);
}
if (Fn->getType() == Context.PseudoObjectTy) {
ExprResult result = CheckPlaceholderExpr(Fn);
if (Fn->getType() == S.Context.PseudoObjectTy) {
ExprResult result = S.CheckPlaceholderExpr(Fn);
if (result.isInvalid()) return ExprError();
Fn = result.get();
}
Expand All @@ -5183,50 +5173,53 @@ Sema::ActOnCallExpr(Scope *S, Expr *Fn, SourceLocation LParenLoc,

if (Dependent) {
if (ExecConfig) {
return new (Context) CUDAKernelCallExpr(
Context, Fn, cast<CallExpr>(ExecConfig), ArgExprs,
Context.DependentTy, VK_RValue, RParenLoc);
return new (S.Context) CUDAKernelCallExpr(
S.Context, Fn, cast<CallExpr>(ExecConfig), ArgExprs,
S.Context.DependentTy, VK_RValue, RParenLoc);
} else {
return new (Context) CallExpr(
Context, Fn, ArgExprs, Context.DependentTy, VK_RValue, RParenLoc);
return new (S.Context)
CallExpr(S.Context, Fn, ArgExprs, S.Context.DependentTy, VK_RValue,
RParenLoc);
}
}

// Determine whether this is a call to an object (C++ [over.call.object]).
if (Fn->getType()->isRecordType())
return BuildCallToObjectOfClassType(S, Fn, LParenLoc, ArgExprs,
RParenLoc);
return S.BuildCallToObjectOfClassType(Scope, Fn, LParenLoc, ArgExprs,
RParenLoc);

if (Fn->getType() == Context.UnknownAnyTy) {
ExprResult result = rebuildUnknownAnyFunction(*this, Fn);
if (Fn->getType() == S.Context.UnknownAnyTy) {
ExprResult result = rebuildUnknownAnyFunction(S, Fn);
if (result.isInvalid()) return ExprError();
Fn = result.get();
}

if (Fn->getType() == Context.BoundMemberTy) {
return BuildCallToMemberFunction(S, Fn, LParenLoc, ArgExprs, RParenLoc);
if (Fn->getType() == S.Context.BoundMemberTy) {
return S.BuildCallToMemberFunction(Scope, Fn, LParenLoc, ArgExprs,
RParenLoc);
}
}

// Check for overloaded calls. This can happen even in C due to extensions.
if (Fn->getType() == Context.OverloadTy) {
if (Fn->getType() == S.Context.OverloadTy) {
OverloadExpr::FindResult find = OverloadExpr::find(Fn);

// We aren't supposed to apply this logic for if there's an '&' involved.
// We aren't supposed to apply this logic for if there'Scope an '&'
// involved.
if (!find.HasFormOfMemberPointer) {
OverloadExpr *ovl = find.Expression;
if (UnresolvedLookupExpr *ULE = dyn_cast<UnresolvedLookupExpr>(ovl))
return BuildOverloadedCallExpr(S, Fn, ULE, LParenLoc, ArgExprs,
RParenLoc, ExecConfig,
/*AllowTypoCorrection=*/true,
find.IsAddressOfOperand);
return BuildCallToMemberFunction(S, Fn, LParenLoc, ArgExprs, RParenLoc);
return S.BuildOverloadedCallExpr(
Scope, Fn, ULE, LParenLoc, ArgExprs, RParenLoc, ExecConfig,
/*AllowTypoCorrection=*/true, find.IsAddressOfOperand);
return S.BuildCallToMemberFunction(Scope, Fn, LParenLoc, ArgExprs,
RParenLoc);
}
}

// If we're directly calling a function, get the appropriate declaration.
if (Fn->getType() == Context.UnknownAnyTy) {
ExprResult result = rebuildUnknownAnyFunction(*this, Fn);
if (Fn->getType() == S.Context.UnknownAnyTy) {
ExprResult result = rebuildUnknownAnyFunction(S, Fn);
if (result.isInvalid()) return ExprError();
Fn = result.get();
}
Expand All @@ -5250,21 +5243,21 @@ Sema::ActOnCallExpr(Scope *S, Expr *Fn, SourceLocation LParenLoc,
// Rewrite the function decl for this builtin by replacing parameters
// with no explicit address space with the address space of the arguments
// in ArgExprs.
if ((FDecl = rewriteBuiltinFunctionDecl(this, Context, FDecl, ArgExprs))) {
if ((FDecl =
rewriteBuiltinFunctionDecl(&S, S.Context, FDecl, ArgExprs))) {
NDecl = FDecl;
Fn = DeclRefExpr::Create(Context, FDecl->getQualifierLoc(),
SourceLocation(), FDecl, false,
SourceLocation(), FDecl->getType(),
Fn->getValueKind(), FDecl);
Fn = DeclRefExpr::Create(
S.Context, FDecl->getQualifierLoc(), SourceLocation(), FDecl, false,
SourceLocation(), FDecl->getType(), Fn->getValueKind(), FDecl);
}
}
} else if (isa<MemberExpr>(NakedFn))
NDecl = cast<MemberExpr>(NakedFn)->getMemberDecl();

if (FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(NDecl)) {
if (CallingNDeclIndirectly &&
!checkAddressOfFunctionIsAvailable(FD, /*Complain=*/true,
Fn->getLocStart()))
!S.checkAddressOfFunctionIsAvailable(FD, /*Complain=*/true,
Fn->getLocStart()))
return ExprError();

// CheckEnableIf assumes that the we're passing in a sane number of args for
Expand All @@ -5274,22 +5267,42 @@ Sema::ActOnCallExpr(Scope *S, Expr *Fn, SourceLocation LParenLoc,
// number of args looks incorrect, don't do enable_if checks; we should've
// already emitted an error about the bad call.
if (FD->hasAttr<EnableIfAttr>() &&
isNumberOfArgsValidForCall(*this, FD, ArgExprs.size())) {
if (const EnableIfAttr *Attr = CheckEnableIf(FD, ArgExprs, true)) {
Diag(Fn->getLocStart(),
isa<CXXMethodDecl>(FD) ?
diag::err_ovl_no_viable_member_function_in_call :
diag::err_ovl_no_viable_function_in_call)
<< FD << FD->getSourceRange();
Diag(FD->getLocation(),
diag::note_ovl_candidate_disabled_by_enable_if_attr)
isNumberOfArgsValidForCall(S, FD, ArgExprs.size())) {
if (const EnableIfAttr *Attr = S.CheckEnableIf(FD, ArgExprs, true)) {
S.Diag(Fn->getLocStart(),
isa<CXXMethodDecl>(FD)
? diag::err_ovl_no_viable_member_function_in_call
: diag::err_ovl_no_viable_function_in_call)
<< FD << FD->getSourceRange();
S.Diag(FD->getLocation(),
diag::note_ovl_candidate_disabled_by_enable_if_attr)
<< Attr->getCond()->getSourceRange() << Attr->getMessage();
}
}
}

return BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc,
ExecConfig, IsExecConfig);
return S.BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc,
ExecConfig, IsExecConfig);
}

/// ActOnCallExpr - Handle a call to Fn with the specified array of arguments.
/// This provides the location of the left/right parens and a list of comma
/// locations.
ExprResult Sema::ActOnCallExpr(Scope *S, Expr *Fn, SourceLocation LParenLoc,
MultiExprArg ArgExprs, SourceLocation RParenLoc,
Expr *ExecConfig, bool IsExecConfig) {
ExprResult Ret = ActOnCallExprImpl(*this, S, Fn, LParenLoc, ArgExprs,
RParenLoc, ExecConfig, IsExecConfig);

// If appropriate, check that this is a valid CUDA call (and emit an error if
// the call is not allowed).
if (getLangOpts().CUDA && Ret.isUsable())
if (auto *Call = dyn_cast<CallExpr>(Ret.get()))
if (auto *FD = Call->getDirectCallee())
if (!CheckCUDACall(Call->getLocStart(), FD))
return ExprError();

return Ret;
}

/// ActOnAsTypeExpr - create a new asType (bitcast) from the arguments.
Expand Down
13 changes: 0 additions & 13 deletions clang/lib/Sema/SemaOverload.cpp
Expand Up @@ -12331,19 +12331,6 @@ Sema::BuildCallToMemberFunction(Scope *S, Expr *MemExprE,
new (Context) CXXMemberCallExpr(Context, MemExprE, Args,
ResultType, VK, RParenLoc);

// (CUDA B.1): Check for invalid calls between targets.
if (getLangOpts().CUDA) {
if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext)) {
if (!IsAllowedCUDACall(Caller, Method)) {
Diag(MemExpr->getMemberLoc(), diag::err_ref_bad_target)
<< IdentifyCUDATarget(Method) << Method->getIdentifier()
<< IdentifyCUDATarget(Caller);
Diag(Method->getLocation(), diag::note_previous_decl) << Method;
return ExprError();
}
}
}

// Check for a valid return type.
if (CheckCallReturnType(Method->getReturnType(), MemExpr->getMemberLoc(),
TheCall, Method))
Expand Down
32 changes: 0 additions & 32 deletions clang/test/CodeGenCUDA/host-device-calls-host.cu

This file was deleted.

5 changes: 5 additions & 0 deletions clang/test/SemaCUDA/Inputs/cuda.h
Expand Up @@ -21,4 +21,9 @@ typedef struct cudaStream *cudaStream_t;

int cudaConfigureCall(dim3 gridSize, dim3 blockSize, size_t sharedSize = 0,
cudaStream_t stream = 0);

// Device-side placement new overloads.
__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; }
__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; }

#endif // !__NVCC__

0 comments on commit 18e2d82

Please sign in to comment.