diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index a96f69d6ac760..ebda201361fb0 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -12203,4 +12203,6 @@ def warn_acc_clause_unimplemented def err_acc_construct_appertainment : Error<"OpenACC construct '%0' cannot be used here; it can only " "be used in a statement context">; +def err_acc_branch_in_out + : Error<"invalid branch %select{out of|into}0 OpenACC region">; } // end of sema component. diff --git a/clang/include/clang/Sema/Scope.h b/clang/include/clang/Sema/Scope.h index 9e81706cd2aa1..e7f166fe3461f 100644 --- a/clang/include/clang/Sema/Scope.h +++ b/clang/include/clang/Sema/Scope.h @@ -150,6 +150,9 @@ class Scope { /// template scope in between), the outer scope does not increase the /// depth of recursion. LambdaScope = 0x8000000, + /// This is the scope of an OpenACC Compute Construct, which restricts + /// jumping into/out of it. + OpenACCComputeConstructScope = 0x10000000, }; private: @@ -469,6 +472,14 @@ class Scope { return false; } + /// Return true if this scope is a loop. + bool isLoopScope() const { + // 'switch' is the only loop that is not a 'break' scope as well, so we can + // just check BreakScope and not SwitchScope. + return (getFlags() & Scope::BreakScope) && + !(getFlags() & Scope::SwitchScope); + } + /// Determines whether this scope is the OpenMP directive scope bool isOpenMPDirectiveScope() const { return (getFlags() & Scope::OpenMPDirectiveScope); @@ -504,6 +515,12 @@ class Scope { return getFlags() & Scope::OpenMPOrderClauseScope; } + /// Determine whether this scope is the statement associated with an OpenACC + /// Compute construct directive. + bool isOpenACCComputeConstructScope() const { + return getFlags() & Scope::OpenACCComputeConstructScope; + } + /// Determine whether this scope is a while/do/for statement, which can have /// continue statements embedded into it. bool isContinueScope() const { diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp index 50e78e8687aea..4946a61fca007 100644 --- a/clang/lib/Parse/ParseOpenACC.cpp +++ b/clang/lib/Parse/ParseOpenACC.cpp @@ -560,6 +560,21 @@ bool doesDirectiveHaveAssociatedStmt(OpenACCDirectiveKind DirKind) { llvm_unreachable("Unhandled directive->assoc stmt"); } +unsigned getOpenACCScopeFlags(OpenACCDirectiveKind DirKind) { + switch (DirKind) { + case OpenACCDirectiveKind::Parallel: + // Mark this as a BreakScope/ContinueScope as well as a compute construct + // so that we can diagnose trying to 'break'/'continue' inside of one. + return Scope::BreakScope | Scope::ContinueScope | + Scope::OpenACCComputeConstructScope; + case OpenACCDirectiveKind::Invalid: + llvm_unreachable("Shouldn't be creating a scope for an invalid construct"); + default: + break; + } + return 0; +} + } // namespace // OpenACC 3.3, section 1.7: @@ -1228,6 +1243,8 @@ StmtResult Parser::ParseOpenACCDirectiveStmt() { if (doesDirectiveHaveAssociatedStmt(DirInfo.DirKind)) { ParsingOpenACCDirectiveRAII DirScope(*this, /*Value=*/false); + ParseScope ACCScope(this, getOpenACCScopeFlags(DirInfo.DirKind)); + AssocStmt = getActions().ActOnOpenACCAssociatedStmt(DirInfo.DirKind, ParseStatement()); } diff --git a/clang/lib/Sema/Scope.cpp b/clang/lib/Sema/Scope.cpp index 4570d8c615fe5..cea6a62e34747 100644 --- a/clang/lib/Sema/Scope.cpp +++ b/clang/lib/Sema/Scope.cpp @@ -225,6 +225,7 @@ void Scope::dumpImpl(raw_ostream &OS) const { {CompoundStmtScope, "CompoundStmtScope"}, {ClassInheritanceScope, "ClassInheritanceScope"}, {CatchScope, "CatchScope"}, + {OpenACCComputeConstructScope, "OpenACCComputeConstructScope"}, }; for (auto Info : FlagInfo) { diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp index dde3bd84e89f8..fcad09a63662b 100644 --- a/clang/lib/Sema/SemaStmt.cpp +++ b/clang/lib/Sema/SemaStmt.cpp @@ -3356,6 +3356,14 @@ Sema::ActOnContinueStmt(SourceLocation ContinueLoc, Scope *CurScope) { // initialization of that variable. return StmtError(Diag(ContinueLoc, diag::err_continue_from_cond_var_init)); } + + // A 'continue' that would normally have execution continue on a block outside + // of a compute construct counts as 'branching out of' the compute construct, + // so diagnose here. + if (S->isOpenACCComputeConstructScope()) + return StmtError(Diag(ContinueLoc, diag::err_acc_branch_in_out) + << /*out of */ 0); + CheckJumpOutOfSEHFinally(*this, ContinueLoc, *S); return new (Context) ContinueStmt(ContinueLoc); @@ -3371,6 +3379,20 @@ Sema::ActOnBreakStmt(SourceLocation BreakLoc, Scope *CurScope) { if (S->isOpenMPLoopScope()) return StmtError(Diag(BreakLoc, diag::err_omp_loop_cannot_use_stmt) << "break"); + + // OpenACC doesn't allow 'break'ing from a compute construct, so diagnose if + // we are trying to do so. This can come in 2 flavors: 1-the break'able thing + // (besides the compute construct) 'contains' the compute construct, at which + // point the 'break' scope will be the compute construct. Else it could be a + // loop of some sort that has a direct parent of the compute construct. + // However, a 'break' in a 'switch' marked as a compute construct doesn't + // count as 'branch out of' the compute construct. + if (S->isOpenACCComputeConstructScope() || + (S->isLoopScope() && S->getParent() && + S->getParent()->isOpenACCComputeConstructScope())) + return StmtError(Diag(BreakLoc, diag::err_acc_branch_in_out) + << /*out of */ 0); + CheckJumpOutOfSEHFinally(*this, BreakLoc, *S); return new (Context) BreakStmt(BreakLoc); diff --git a/clang/test/SemaOpenACC/no-branch-in-out.c b/clang/test/SemaOpenACC/no-branch-in-out.c new file mode 100644 index 0000000000000..622cf55f48473 --- /dev/null +++ b/clang/test/SemaOpenACC/no-branch-in-out.c @@ -0,0 +1,95 @@ +// RUN: %clang_cc1 %s -verify -fopenacc + +void BreakContinue() { + +#pragma acc parallel + for(int i =0; i < 5; ++i) { + switch(i) { + case 0: + break; // leaves switch, not 'for'. + default: + i +=2; + break; + } + if (i == 2) + continue; + + break; // expected-error{{invalid branch out of OpenACC region}} + } + + int j; + switch(j) { + case 0: +#pragma acc parallel + { + break; // expected-error{{invalid branch out of OpenACC region}} + } + case 1: +#pragma acc parallel + { + } + break; + } + +#pragma acc parallel + for(int i = 0; i < 5; ++i) { + if (i > 1) + break; // expected-error{{invalid branch out of OpenACC region}} + } + +#pragma acc parallel + switch(j) { + case 1: + break; + } + +#pragma acc parallel + { + for(int i = 1; i < 100; i++) { + if (i > 4) + break; + } + } + + for (int i =0; i < 5; ++i) { +#pragma acc parallel + { + continue; // expected-error{{invalid branch out of OpenACC region}} + } + } + +#pragma acc parallel + for (int i =0; i < 5; ++i) { + continue; + } + +#pragma acc parallel + for (int i =0; i < 5; ++i) { + { + continue; + } + } + + for (int i =0; i < 5; ++i) { +#pragma acc parallel + { + break; // expected-error{{invalid branch out of OpenACC region}} + } + } + +#pragma acc parallel + while (j) { + --j; + if (j > 4) + break; // expected-error{{invalid branch out of OpenACC region}} + } + +#pragma acc parallel + do { + --j; + if (j > 4) + break; // expected-error{{invalid branch out of OpenACC region}} + } while (j ); + +} +