Skip to content

Commit

Permalink
[OpenACC] Implement self clause for compute constructs (#88760)
Browse files Browse the repository at this point in the history
`self` clauses on compute constructs take an optional condition
expression. We again limit the implementation to ONLY compute constructs
to ensure we get all the rules correct for others. However, this one
will be particularly complicated, as it takes a `var-list` for `update`,
so when we get to that construct/clause combination, we need to do that
as well.

This patch also furthers uses of the `OpenACCClauses.def` as it became
useful while implementing this (as well as some other minor refactors as
I went through).

Finally, `self` and `if` clauses have an interaction with each other, if
an `if` clause evaluates to `true`, the `self` clause has no effect.
While this is intended and can be used 'meaningfully', we are warning on
this with a very granular warning, so that this edge case will be
noticed by newer users, but can be disabled trivially.
  • Loading branch information
erichkeane committed Apr 16, 2024
1 parent 254df2e commit 6133878
Show file tree
Hide file tree
Showing 19 changed files with 513 additions and 98 deletions.
65 changes: 18 additions & 47 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,17 @@ class OpenACCIfClause : public OpenACCClauseWithCondition {
SourceLocation EndLoc);
};

/// A 'self' clause, which has an optional condition expression.
class OpenACCSelfClause : public OpenACCClauseWithCondition {
OpenACCSelfClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
Expr *ConditionExpr, SourceLocation EndLoc);

public:
static OpenACCSelfClause *Create(const ASTContext &C, SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expr *ConditionExpr, SourceLocation EndLoc);
};

template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }

Expand All @@ -159,53 +170,13 @@ template <class Impl> class OpenACCClauseVisitor {
return;

switch (C->getClauseKind()) {
case OpenACCClauseKind::Default:
VisitDefaultClause(*cast<OpenACCDefaultClause>(C));
return;
case OpenACCClauseKind::If:
VisitIfClause(*cast<OpenACCIfClause>(C));
return;
case OpenACCClauseKind::Finalize:
case OpenACCClauseKind::IfPresent:
case OpenACCClauseKind::Seq:
case OpenACCClauseKind::Independent:
case OpenACCClauseKind::Auto:
case OpenACCClauseKind::Worker:
case OpenACCClauseKind::Vector:
case OpenACCClauseKind::NoHost:
case OpenACCClauseKind::Self:
case OpenACCClauseKind::Copy:
case OpenACCClauseKind::UseDevice:
case OpenACCClauseKind::Attach:
case OpenACCClauseKind::Delete:
case OpenACCClauseKind::Detach:
case OpenACCClauseKind::Device:
case OpenACCClauseKind::DevicePtr:
case OpenACCClauseKind::DeviceResident:
case OpenACCClauseKind::FirstPrivate:
case OpenACCClauseKind::Host:
case OpenACCClauseKind::Link:
case OpenACCClauseKind::NoCreate:
case OpenACCClauseKind::Present:
case OpenACCClauseKind::Private:
case OpenACCClauseKind::CopyOut:
case OpenACCClauseKind::CopyIn:
case OpenACCClauseKind::Create:
case OpenACCClauseKind::Reduction:
case OpenACCClauseKind::Collapse:
case OpenACCClauseKind::Bind:
case OpenACCClauseKind::VectorLength:
case OpenACCClauseKind::NumGangs:
case OpenACCClauseKind::NumWorkers:
case OpenACCClauseKind::DeviceNum:
case OpenACCClauseKind::DefaultAsync:
case OpenACCClauseKind::DeviceType:
case OpenACCClauseKind::DType:
case OpenACCClauseKind::Async:
case OpenACCClauseKind::Tile:
case OpenACCClauseKind::Gang:
case OpenACCClauseKind::Wait:
case OpenACCClauseKind::Invalid:
#define VISIT_CLAUSE(CLAUSE_NAME) \
case OpenACCClauseKind::CLAUSE_NAME: \
Visit##CLAUSE_NAME##Clause(*cast<OpenACC##CLAUSE_NAME##Clause>(C)); \
return;
#include "clang/Basic/OpenACCClauses.def"

default:
llvm_unreachable("Clause visitor not yet implemented");
}
llvm_unreachable("Invalid Clause kind");
Expand Down
4 changes: 1 addition & 3 deletions clang/include/clang/AST/StmtOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,7 @@ class OpenACCComputeConstruct final
Stmt *StructuredBlock)
: OpenACCAssociatedStmtConstruct(OpenACCComputeConstructClass, K, Start,
End, StructuredBlock) {
assert((K == OpenACCDirectiveKind::Parallel ||
K == OpenACCDirectiveKind::Serial ||
K == OpenACCDirectiveKind::Kernels) &&
assert(isOpenACCComputeDirectiveKind(K) &&
"Only parallel, serial, and kernels constructs should be "
"represented by this type");

Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12274,4 +12274,8 @@ def note_acc_branch_into_compute_construct
: Note<"invalid branch into OpenACC Compute Construct">;
def note_acc_branch_out_of_compute_construct
: Note<"invalid branch out of OpenACC Compute Construct">;
def warn_acc_if_self_conflict
: Warning<"OpenACC construct 'self' has no effect when an 'if' clause "
"evaluates to true">,
InGroup<DiagGroup<"openacc-self-if-potential-conflict">>;
} // end of sema component.
1 change: 1 addition & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@

VISIT_CLAUSE(Default)
VISIT_CLAUSE(If)
VISIT_CLAUSE(Self)

#undef VISIT_CLAUSE
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/OpenACCKinds.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
return printOpenACCDirectiveKind(Out, K);
}

inline bool isOpenACCComputeDirectiveKind(OpenACCDirectiveKind K) {
return K == OpenACCDirectiveKind::Parallel ||
K == OpenACCDirectiveKind::Serial ||
K == OpenACCDirectiveKind::Kernels;
}

enum class OpenACCAtomicKind {
Read,
Write,
Expand Down
18 changes: 15 additions & 3 deletions clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class SemaOpenACC : public SemaBase {
Expr *ConditionExpr;
};

std::variant<DefaultDetails, ConditionDetails> Details;
std::variant<std::monostate, DefaultDetails, ConditionDetails> Details =
std::monostate{};

public:
OpenACCParsedClause(OpenACCDirectiveKind DirKind,
Expand Down Expand Up @@ -72,8 +73,17 @@ class SemaOpenACC : public SemaBase {
}

Expr *getConditionExpr() {
assert(ClauseKind == OpenACCClauseKind::If &&
assert((ClauseKind == OpenACCClauseKind::If ||
(ClauseKind == OpenACCClauseKind::Self &&
DirKind != OpenACCDirectiveKind::Update)) &&
"Parsed clause kind does not have a condition expr");

// 'self' has an optional ConditionExpr, so be tolerant of that. This will
// assert in variant otherwise.
if (ClauseKind == OpenACCClauseKind::Self &&
std::holds_alternative<std::monostate>(Details))
return nullptr;

return std::get<ConditionDetails>(Details).ConditionExpr;
}

Expand All @@ -87,7 +97,9 @@ class SemaOpenACC : public SemaBase {
}

void setConditionDetails(Expr *ConditionExpr) {
assert(ClauseKind == OpenACCClauseKind::If &&
assert((ClauseKind == OpenACCClauseKind::If ||
(ClauseKind == OpenACCClauseKind::Self &&
DirKind != OpenACCDirectiveKind::Update)) &&
"Parsed clause kind does not have a condition expr");
// In C++ we can count on this being a 'bool', but in C this gets left as
// some sort of scalar that codegen will have to take care of converting.
Expand Down
26 changes: 26 additions & 0 deletions clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,26 @@ OpenACCIfClause::OpenACCIfClause(SourceLocation BeginLoc,
"Condition expression type not scalar/dependent");
}

OpenACCSelfClause *OpenACCSelfClause::Create(const ASTContext &C,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expr *ConditionExpr,
SourceLocation EndLoc) {
void *Mem = C.Allocate(sizeof(OpenACCIfClause), alignof(OpenACCIfClause));
return new (Mem)
OpenACCSelfClause(BeginLoc, LParenLoc, ConditionExpr, EndLoc);
}

OpenACCSelfClause::OpenACCSelfClause(SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expr *ConditionExpr, SourceLocation EndLoc)
: OpenACCClauseWithCondition(OpenACCClauseKind::Self, BeginLoc, LParenLoc,
ConditionExpr, EndLoc) {
assert((!ConditionExpr || ConditionExpr->isInstantiationDependent() ||
ConditionExpr->getType()->isScalarType()) &&
"Condition expression type not scalar/dependent");
}

OpenACCClause::child_range OpenACCClause::children() {
switch (getClauseKind()) {
default:
Expand All @@ -72,3 +92,9 @@ void OpenACCClausePrinter::VisitDefaultClause(const OpenACCDefaultClause &C) {
void OpenACCClausePrinter::VisitIfClause(const OpenACCIfClause &C) {
OS << "if(" << C.getConditionExpr() << ")";
}

void OpenACCClausePrinter::VisitSelfClause(const OpenACCSelfClause &C) {
OS << "self";
if (const Expr *CondExpr = C.getConditionExpr())
OS << "(" << CondExpr << ")";
}
5 changes: 5 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2491,6 +2491,11 @@ void OpenACCClauseProfiler::VisitIfClause(const OpenACCIfClause &Clause) {
"if clause requires a valid condition expr");
Profiler.VisitStmt(Clause.getConditionExpr());
}

void OpenACCClauseProfiler::VisitSelfClause(const OpenACCSelfClause &Clause) {
if (Clause.hasConditionExpr())
Profiler.VisitStmt(Clause.getConditionExpr());
}
} // namespace

void StmtProfiler::VisitOpenACCComputeConstruct(
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
OS << '(' << cast<OpenACCDefaultClause>(C)->getDefaultClauseKind() << ')';
break;
case OpenACCClauseKind::If:
case OpenACCClauseKind::Self:
// The condition expression will be printed as a part of the 'children',
// but print 'clause' here so it is clear what is happening from the dump.
OS << " clause";
Expand Down
16 changes: 11 additions & 5 deletions clang/lib/Parse/ParseOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -835,19 +835,23 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::Default: {
Token DefKindTok = getCurToken();

if (expectIdentifierOrKeyword(*this))
break;
if (expectIdentifierOrKeyword(*this)) {
Parens.skipToEnd();
return OpenACCCanContinue();
}

ConsumeToken();

OpenACCDefaultClauseKind DefKind =
getOpenACCDefaultClauseKind(DefKindTok);

if (DefKind == OpenACCDefaultClauseKind::Invalid)
if (DefKind == OpenACCDefaultClauseKind::Invalid) {
Diag(DefKindTok, diag::err_acc_invalid_default_clause_kind);
else
ParsedClause.setDefaultDetails(DefKind);
Parens.skipToEnd();
return OpenACCCanContinue();
}

ParsedClause.setDefaultDetails(DefKind);
break;
}
case OpenACCClauseKind::If: {
Expand Down Expand Up @@ -977,6 +981,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::Self: {
assert(DirKind != OpenACCDirectiveKind::Update);
ExprResult CondExpr = ParseOpenACCConditionExpr();
ParsedClause.setConditionDetails(CondExpr.isUsable() ? CondExpr.get()
: nullptr);

if (CondExpr.isInvalid()) {
Parens.skipToEnd();
Expand Down
68 changes: 60 additions & 8 deletions clang/lib/Sema/SemaOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "clang/AST/StmtOpenACC.h"
#include "clang/Basic/DiagnosticSema.h"
#include "clang/Sema/Sema.h"
#include "llvm/Support/Casting.h"

using namespace clang;

Expand Down Expand Up @@ -76,6 +77,19 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
default:
return false;
}
case OpenACCClauseKind::Self:
switch (DirectiveKind) {
case OpenACCDirectiveKind::Parallel:
case OpenACCDirectiveKind::Serial:
case OpenACCDirectiveKind::Kernels:
case OpenACCDirectiveKind::Update:
case OpenACCDirectiveKind::ParallelLoop:
case OpenACCDirectiveKind::SerialLoop:
case OpenACCDirectiveKind::KernelsLoop:
return true;
default:
return false;
}
default:
// Do nothing so we can go to the 'unimplemented' diagnostic instead.
return true;
Expand Down Expand Up @@ -121,9 +135,7 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
// Restrictions only properly implemented on 'compute' constructs, and
// 'compute' constructs are the only construct that can do anything with
// this yet, so skip/treat as unimplemented in this case.
if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Parallel &&
Clause.getDirectiveKind() != OpenACCDirectiveKind::Serial &&
Clause.getDirectiveKind() != OpenACCDirectiveKind::Kernels)
if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
break;

// Don't add an invalid clause to the AST.
Expand All @@ -146,9 +158,7 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
// Restrictions only properly implemented on 'compute' constructs, and
// 'compute' constructs are the only construct that can do anything with
// this yet, so skip/treat as unimplemented in this case.
if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Parallel &&
Clause.getDirectiveKind() != OpenACCDirectiveKind::Serial &&
Clause.getDirectiveKind() != OpenACCDirectiveKind::Kernels)
if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
break;

// There is no prose in the standard that says duplicates aren't allowed,
Expand All @@ -160,12 +170,54 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
// The parser has ensured that we have a proper condition expr, so there
// isn't really much to do here.

// TODO OpenACC: When we implement 'self', this clauses causes us to
// 'ignore' the self clause, so we should implement a warning here.
// If the 'if' clause is true, it makes the 'self' clause have no effect,
// diagnose that here.
// TODO OpenACC: When we add these two to other constructs, we might not
// want to warn on this (for example, 'update').
const auto *Itr =
llvm::find_if(ExistingClauses, llvm::IsaPred<OpenACCSelfClause>);
if (Itr != ExistingClauses.end()) {
Diag(Clause.getBeginLoc(), diag::warn_acc_if_self_conflict);
Diag((*Itr)->getBeginLoc(), diag::note_acc_previous_clause_here);
}

return OpenACCIfClause::Create(
getASTContext(), Clause.getBeginLoc(), Clause.getLParenLoc(),
Clause.getConditionExpr(), Clause.getEndLoc());
}

case OpenACCClauseKind::Self: {
// Restrictions only properly implemented on 'compute' constructs, and
// 'compute' constructs are the only construct that can do anything with
// this yet, so skip/treat as unimplemented in this case.
if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
break;

// TODO OpenACC: When we implement this for 'update', this takes a
// 'var-list' instead of a condition expression, so semantics/handling has
// to happen differently here.

// There is no prose in the standard that says duplicates aren't allowed,
// but this diagnostic is present in other compilers, as well as makes
// sense.
if (checkAlreadyHasClauseOfKind(*this, ExistingClauses, Clause))
return nullptr;

// If the 'if' clause is true, it makes the 'self' clause have no effect,
// diagnose that here.
// TODO OpenACC: When we add these two to other constructs, we might not
// want to warn on this (for example, 'update').
const auto *Itr =
llvm::find_if(ExistingClauses, llvm::IsaPred<OpenACCIfClause>);
if (Itr != ExistingClauses.end()) {
Diag(Clause.getBeginLoc(), diag::warn_acc_if_self_conflict);
Diag((*Itr)->getBeginLoc(), diag::note_acc_previous_clause_here);
}

return OpenACCSelfClause::Create(
getASTContext(), Clause.getBeginLoc(), Clause.getLParenLoc(),
Clause.getConditionExpr(), Clause.getEndLoc());
}
default:
break;
}
Expand Down

0 comments on commit 6133878

Please sign in to comment.