Skip to content

Commit

Permalink
[OpenACC] Implement 'num_workers' clause for compute constructs (#89151)
Browse files Browse the repository at this point in the history
This clause just takes an 'int expr', which is not optional. This patch
implements the clause on compute constructs.
  • Loading branch information
erichkeane committed Apr 18, 2024
1 parent 8ba0041 commit 76600ae
Show file tree
Hide file tree
Showing 18 changed files with 746 additions and 27 deletions.
40 changes: 40 additions & 0 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,46 @@ class OpenACCSelfClause : public OpenACCClauseWithCondition {
Expr *ConditionExpr, SourceLocation EndLoc);
};

/// Represents oen of a handful of classes that have a single integer
/// expression.
class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithParams {
Expr *IntExpr;

protected:
OpenACCClauseWithSingleIntExpr(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *IntExpr,
SourceLocation EndLoc)
: OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
IntExpr(IntExpr) {}

public:
bool hasIntExpr() const { return IntExpr; }
const Expr *getIntExpr() const { return IntExpr; }

Expr *getIntExpr() { return IntExpr; };

child_range children() {
return child_range(reinterpret_cast<Stmt **>(&IntExpr),
reinterpret_cast<Stmt **>(&IntExpr + 1));
}

const_child_range children() const {
return const_child_range(reinterpret_cast<Stmt *const *>(&IntExpr),
reinterpret_cast<Stmt *const *>(&IntExpr + 1));
}
};

class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
OpenACCNumWorkersClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
Expr *IntExpr, SourceLocation EndLoc);

public:
static OpenACCNumWorkersClause *Create(const ASTContext &C,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expr *IntExpr, SourceLocation EndLoc);
};

template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }

Expand Down
12 changes: 12 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12271,4 +12271,16 @@ def warn_acc_if_self_conflict
: Warning<"OpenACC construct 'self' has no effect when an 'if' clause "
"evaluates to true">,
InGroup<DiagGroup<"openacc-self-if-potential-conflict">>;
def err_acc_int_expr_requires_integer
: Error<"OpenACC %select{clause|directive}0 '%1' requires expression of "
"integer type (%2 invalid)">;
def err_acc_int_expr_incomplete_class_type
: Error<"OpenACC integer expression has incomplete class type %0">;
def err_acc_int_expr_explicit_conversion
: Error<"OpenACC integer expression type %0 requires explicit conversion "
"to %1">;
def note_acc_int_expr_conversion
: Note<"conversion to %select{integral|enumeration}0 type %1">;
def err_acc_int_expr_multiple_conversions
: Error<"multiple conversions from expression type %0 to an integral type">;
} // end of sema component.
1 change: 1 addition & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
VISIT_CLAUSE(Default)
VISIT_CLAUSE(If)
VISIT_CLAUSE(Self)
VISIT_CLAUSE(NumWorkers)

#undef VISIT_CLAUSE
9 changes: 5 additions & 4 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3640,13 +3640,14 @@ class Parser : public CodeCompletionHandler {
/// Parses the clause-list for an OpenACC directive.
SmallVector<OpenACCClause *>
ParseOpenACCClauseList(OpenACCDirectiveKind DirKind);
bool ParseOpenACCWaitArgument();
bool ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective);
/// Parses the clause of the 'bind' argument, which can be a string literal or
/// an ID expression.
ExprResult ParseOpenACCBindClauseArgument();
/// Parses the clause kind of 'int-expr', which can be any integral
/// expression.
ExprResult ParseOpenACCIntExpr();
ExprResult ParseOpenACCIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
SourceLocation Loc);
/// Parses the 'device-type-list', which is a list of identifiers.
bool ParseOpenACCDeviceTypeList();
/// Parses the 'async-argument', which is an integral value with two
Expand All @@ -3657,9 +3658,9 @@ class Parser : public CodeCompletionHandler {
/// Parses a comma delimited list of 'size-expr's.
bool ParseOpenACCSizeExprList();
/// Parses a 'gang-arg-list', used for the 'gang' clause.
bool ParseOpenACCGangArgList();
bool ParseOpenACCGangArgList(SourceLocation GangLoc);
/// Parses a 'gang-arg', used for the 'gang' clause.
bool ParseOpenACCGangArg();
bool ParseOpenACCGangArg(SourceLocation GangLoc);
/// Parses a 'condition' expr, ensuring it results in a
ExprResult ParseOpenACCConditionExpr();

Expand Down
36 changes: 34 additions & 2 deletions clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@ class SemaOpenACC : public SemaBase {
Expr *ConditionExpr;
};

std::variant<std::monostate, DefaultDetails, ConditionDetails> Details =
std::monostate{};
struct IntExprDetails {
SmallVector<Expr *> IntExprs;
};

std::variant<std::monostate, DefaultDetails, ConditionDetails,
IntExprDetails>
Details = std::monostate{};

public:
OpenACCParsedClause(OpenACCDirectiveKind DirKind,
Expand Down Expand Up @@ -87,6 +92,22 @@ class SemaOpenACC : public SemaBase {
return std::get<ConditionDetails>(Details).ConditionExpr;
}

unsigned getNumIntExprs() const {
assert(ClauseKind == OpenACCClauseKind::NumWorkers &&
"Parsed clause kind does not have a int exprs");
return std::get<IntExprDetails>(Details).IntExprs.size();
}

ArrayRef<Expr *> getIntExprs() {
assert(ClauseKind == OpenACCClauseKind::NumWorkers &&
"Parsed clause kind does not have a int exprs");
return std::get<IntExprDetails>(Details).IntExprs;
}

ArrayRef<Expr *> getIntExprs() const {
return const_cast<OpenACCParsedClause *>(this)->getIntExprs();
}

void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }

Expand All @@ -109,6 +130,12 @@ class SemaOpenACC : public SemaBase {

Details = ConditionDetails{ConditionExpr};
}

void setIntExprDetails(ArrayRef<Expr *> IntExprs) {
assert(ClauseKind == OpenACCClauseKind::NumWorkers &&
"Parsed clause kind does not have a int exprs");
Details = IntExprDetails{{IntExprs.begin(), IntExprs.end()}};
}
};

SemaOpenACC(Sema &S);
Expand Down Expand Up @@ -148,6 +175,11 @@ class SemaOpenACC : public SemaBase {
/// Called after the directive has been completely parsed, including the
/// declaration group or associated statement.
DeclGroupRef ActOnEndDeclDirective();

/// Called when encountering an 'int-expr' for OpenACC, and manages
/// conversions and diagnostics to 'int'.
ExprResult ActOnIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
SourceLocation Loc, Expr *IntExpr);
};

} // namespace clang
Expand Down
26 changes: 26 additions & 0 deletions clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,27 @@ OpenACCClause::child_range OpenACCClause::children() {
return child_range(child_iterator(), child_iterator());
}

OpenACCNumWorkersClause::OpenACCNumWorkersClause(SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expr *IntExpr,
SourceLocation EndLoc)
: OpenACCClauseWithSingleIntExpr(OpenACCClauseKind::NumWorkers, BeginLoc,
LParenLoc, IntExpr, EndLoc) {
assert((!IntExpr || IntExpr->isInstantiationDependent() ||
IntExpr->getType()->isIntegerType()) &&
"Condition expression type not scalar/dependent");
}

OpenACCNumWorkersClause *
OpenACCNumWorkersClause::Create(const ASTContext &C, SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *IntExpr,
SourceLocation EndLoc) {
void *Mem = C.Allocate(sizeof(OpenACCNumWorkersClause),
alignof(OpenACCNumWorkersClause));
return new (Mem)
OpenACCNumWorkersClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
}

//===----------------------------------------------------------------------===//
// OpenACC clauses printing methods
//===----------------------------------------------------------------------===//
Expand All @@ -98,3 +119,8 @@ void OpenACCClausePrinter::VisitSelfClause(const OpenACCSelfClause &C) {
if (const Expr *CondExpr = C.getConditionExpr())
OS << "(" << CondExpr << ")";
}

void OpenACCClausePrinter::VisitNumWorkersClause(
const OpenACCNumWorkersClause &C) {
OS << "num_workers(" << C.getIntExpr() << ")";
}
7 changes: 7 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2496,6 +2496,13 @@ void OpenACCClauseProfiler::VisitSelfClause(const OpenACCSelfClause &Clause) {
if (Clause.hasConditionExpr())
Profiler.VisitStmt(Clause.getConditionExpr());
}

void OpenACCClauseProfiler::VisitNumWorkersClause(
const OpenACCNumWorkersClause &Clause) {
assert(Clause.hasIntExpr() && "num_workers clause requires a valid int expr");
Profiler.VisitStmt(Clause.getIntExpr());
}

} // namespace

void StmtProfiler::VisitOpenACCComputeConstruct(
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
break;
case OpenACCClauseKind::If:
case OpenACCClauseKind::Self:
case OpenACCClauseKind::NumWorkers:
// The condition expression will be printed as a part of the 'children',
// but print 'clause' here so it is clear what is happening from the dump.
OS << " clause";
Expand Down
57 changes: 40 additions & 17 deletions clang/lib/Parse/ParseOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,10 +632,16 @@ Parser::ParseOpenACCClauseList(OpenACCDirectiveKind DirKind) {
return Clauses;
}

ExprResult Parser::ParseOpenACCIntExpr() {
// FIXME: this is required to be an integer expression (or dependent), so we
// should ensure that is the case by passing this to SEMA here.
return getActions().CorrectDelayedTyposInExpr(ParseAssignmentExpression());
ExprResult Parser::ParseOpenACCIntExpr(OpenACCDirectiveKind DK,
OpenACCClauseKind CK,
SourceLocation Loc) {
ExprResult ER =
getActions().CorrectDelayedTyposInExpr(ParseAssignmentExpression());

if (!ER.isUsable())
return ER;

return getActions().OpenACC().ActOnIntExpr(DK, CK, Loc, ER.get());
}

bool Parser::ParseOpenACCClauseVarList(OpenACCClauseKind Kind) {
Expand Down Expand Up @@ -739,7 +745,7 @@ bool Parser::ParseOpenACCSizeExprList() {
/// [num:]int-expr
/// dim:int-expr
/// static:size-expr
bool Parser::ParseOpenACCGangArg() {
bool Parser::ParseOpenACCGangArg(SourceLocation GangLoc) {

if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::Static, getCurToken()) &&
NextToken().is(tok::colon)) {
Expand All @@ -753,7 +759,9 @@ bool Parser::ParseOpenACCGangArg() {
NextToken().is(tok::colon)) {
ConsumeToken();
ConsumeToken();
return ParseOpenACCIntExpr().isInvalid();
return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
OpenACCClauseKind::Gang, GangLoc)
.isInvalid();
}

if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::Num, getCurToken()) &&
Expand All @@ -763,11 +771,13 @@ bool Parser::ParseOpenACCGangArg() {
// Fallthrough to the 'int-expr' handling for when 'num' is omitted.
}
// This is just the 'num' case where 'num' is optional.
return ParseOpenACCIntExpr().isInvalid();
return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
OpenACCClauseKind::Gang, GangLoc)
.isInvalid();
}

bool Parser::ParseOpenACCGangArgList() {
if (ParseOpenACCGangArg()) {
bool Parser::ParseOpenACCGangArgList(SourceLocation GangLoc) {
if (ParseOpenACCGangArg(GangLoc)) {
SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
Parser::StopBeforeMatch);
return false;
Expand All @@ -776,7 +786,7 @@ bool Parser::ParseOpenACCGangArgList() {
while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
ExpectAndConsume(tok::comma);

if (ParseOpenACCGangArg()) {
if (ParseOpenACCGangArg(GangLoc)) {
SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
Parser::StopBeforeMatch);
return false;
Expand Down Expand Up @@ -941,11 +951,18 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
case OpenACCClauseKind::DeviceNum:
case OpenACCClauseKind::DefaultAsync:
case OpenACCClauseKind::VectorLength: {
ExprResult IntExpr = ParseOpenACCIntExpr();
ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
ClauseKind, ClauseLoc);
if (IntExpr.isInvalid()) {
Parens.skipToEnd();
return OpenACCCanContinue();
}

// TODO OpenACC: as we implement the 'rest' of the above, this 'if' should
// be removed leaving just the 'setIntExprDetails'.
if (ClauseKind == OpenACCClauseKind::NumWorkers)
ParsedClause.setIntExprDetails(IntExpr.get());

break;
}
case OpenACCClauseKind::DType:
Expand Down Expand Up @@ -998,7 +1015,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
? OpenACCSpecialTokenKind::Length
: OpenACCSpecialTokenKind::Num,
ClauseKind);
ExprResult IntExpr = ParseOpenACCIntExpr();
ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
ClauseKind, ClauseLoc);
if (IntExpr.isInvalid()) {
Parens.skipToEnd();
return OpenACCCanContinue();
Expand All @@ -1014,13 +1032,14 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
break;
}
case OpenACCClauseKind::Gang:
if (ParseOpenACCGangArgList()) {
if (ParseOpenACCGangArgList(ClauseLoc)) {
Parens.skipToEnd();
return OpenACCCanContinue();
}
break;
case OpenACCClauseKind::Wait:
if (ParseOpenACCWaitArgument()) {
if (ParseOpenACCWaitArgument(ClauseLoc,
/*IsDirective=*/false)) {
Parens.skipToEnd();
return OpenACCCanContinue();
}
Expand Down Expand Up @@ -1052,7 +1071,7 @@ ExprResult Parser::ParseOpenACCAsyncArgument() {
/// In this section and throughout the specification, the term wait-argument
/// means:
/// [ devnum : int-expr : ] [ queues : ] async-argument-list
bool Parser::ParseOpenACCWaitArgument() {
bool Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
// [devnum : int-expr : ]
if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::DevNum, Tok) &&
NextToken().is(tok::colon)) {
Expand All @@ -1061,7 +1080,11 @@ bool Parser::ParseOpenACCWaitArgument() {
// Consume colon.
ConsumeToken();

ExprResult IntExpr = ParseOpenACCIntExpr();
ExprResult IntExpr = ParseOpenACCIntExpr(
IsDirective ? OpenACCDirectiveKind::Wait
: OpenACCDirectiveKind::Invalid,
IsDirective ? OpenACCClauseKind::Invalid : OpenACCClauseKind::Wait,
Loc);
if (IntExpr.isInvalid())
return true;

Expand Down Expand Up @@ -1245,7 +1268,7 @@ Parser::OpenACCDirectiveParseInfo Parser::ParseOpenACCDirective() {
break;
case OpenACCDirectiveKind::Wait:
// OpenACC has an optional paren-wrapped 'wait-argument'.
if (ParseOpenACCWaitArgument())
if (ParseOpenACCWaitArgument(StartLoc, /*IsDirective=*/true))
T.skipToEnd();
else
T.consumeClose();
Expand Down

0 comments on commit 76600ae

Please sign in to comment.