diff --git a/clang/include/clang/Sema/Scope.h b/clang/include/clang/Sema/Scope.h index b6b5a1f3479a2..1cb2fa83e0bb3 100644 --- a/clang/include/clang/Sema/Scope.h +++ b/clang/include/clang/Sema/Scope.h @@ -534,6 +534,25 @@ class Scope { return false; } + /// Determine if this scope (or its parents) are a compute construct inside of + /// the nearest 'switch' scope. This is needed to check whether we are inside + /// of a 'duffs' device, which is an illegal branch into a compute construct. + bool isInOpenACCComputeConstructBeforeSwitch() const { + for (const Scope *S = this; S; S = S->getParent()) { + if (S->getFlags() & Scope::OpenACCComputeConstructScope) + return true; + if (S->getFlags() & Scope::SwitchScope) + return false; + + if (S->getFlags() & + (Scope::FnScope | Scope::ClassScope | Scope::BlockScope | + Scope::TemplateParamScope | Scope::FunctionPrototypeScope | + Scope::AtCatchScope | Scope::ObjCMethodScope)) + return false; + } + return false; + } + /// Determine whether this scope is a while/do/for statement, which can have /// continue statements embedded into it. bool isContinueScope() const { diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp index ca2d206752744..4a15a8f6effd3 100644 --- a/clang/lib/Sema/SemaStmt.cpp +++ b/clang/lib/Sema/SemaStmt.cpp @@ -527,6 +527,13 @@ Sema::ActOnCaseStmt(SourceLocation CaseLoc, ExprResult LHSVal, return StmtError(); } + if (LangOpts.OpenACC && + getCurScope()->isInOpenACCComputeConstructBeforeSwitch()) { + Diag(CaseLoc, diag::err_acc_branch_in_out_compute_construct) + << /*branch*/ 0 << /*into*/ 1; + return StmtError(); + } + auto *CS = CaseStmt::Create(Context, LHSVal.get(), RHSVal.get(), CaseLoc, DotDotDotLoc, ColonLoc); getCurFunction()->SwitchStack.back().getPointer()->addSwitchCase(CS); @@ -546,6 +553,13 @@ Sema::ActOnDefaultStmt(SourceLocation DefaultLoc, SourceLocation ColonLoc, return SubStmt; } + if (LangOpts.OpenACC && + getCurScope()->isInOpenACCComputeConstructBeforeSwitch()) { + Diag(DefaultLoc, diag::err_acc_branch_in_out_compute_construct) + << /*branch*/ 0 << /*into*/ 1; + return StmtError(); + } + DefaultStmt *DS = new (Context) DefaultStmt(DefaultLoc, ColonLoc, SubStmt); getCurFunction()->SwitchStack.back().getPointer()->addSwitchCase(DS); return DS; diff --git a/clang/test/SemaOpenACC/no-branch-in-out.c b/clang/test/SemaOpenACC/no-branch-in-out.c index d070247fa65b8..eccc643245004 100644 --- a/clang/test/SemaOpenACC/no-branch-in-out.c +++ b/clang/test/SemaOpenACC/no-branch-in-out.c @@ -310,3 +310,30 @@ LABEL4:{} ptr=&&LABEL5; } + +void DuffsDevice() { + int j; + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + case 0: // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } + + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + default: // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } + + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + case 'a' ... 'z': // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } +} diff --git a/clang/test/SemaOpenACC/no-branch-in-out.cpp b/clang/test/SemaOpenACC/no-branch-in-out.cpp index 9affdf733ace8..e7d5683f9bc78 100644 --- a/clang/test/SemaOpenACC/no-branch-in-out.cpp +++ b/clang/test/SemaOpenACC/no-branch-in-out.cpp @@ -18,7 +18,6 @@ void ReturnTest() { template void BreakContinue() { - #pragma acc parallel for(int i =0; i < 5; ++i) { switch(i) { @@ -109,6 +108,35 @@ void BreakContinue() { } while (j ); } +template +void DuffsDevice() { + int j; + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + case 0: // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } + + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + default: // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } + + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + case 'a' ... 'z': // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } +} + void Instantiate() { BreakContinue(); + DuffsDevice(); }