Skip to content

Commit

Permalink
[OpenACC] device_type clause Sema for Compute constructs
Browse files Browse the repository at this point in the history
device_type, also spelled as dtype, specifies the applicability of the
clauses following it, and takes a series of identifiers representing the
architectures it applies to.  As we don't have a source for the valid
architectures yet, this patch just accepts all.

Semantically, this also limits the list of clauses that can be applied
after the device_type, so this implements that as well.
  • Loading branch information
erichkeane committed May 13, 2024
1 parent d819772 commit c4a9a37
Show file tree
Hide file tree
Showing 19 changed files with 622 additions and 35 deletions.
59 changes: 59 additions & 0 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "clang/AST/StmtIterator.h"
#include "clang/Basic/OpenACCKinds.h"

#include <utility>

namespace clang {
/// This is the base type for all OpenACC Clauses.
class OpenACCClause {
Expand Down Expand Up @@ -75,6 +77,63 @@ class OpenACCClauseWithParams : public OpenACCClause {
}
};

using DeviceTypeArgument = std::pair<IdentifierInfo *, SourceLocation>;
/// A 'device_type' or 'dtype' clause, takes a list of either an 'asterisk' or
/// an identifier. The 'asterisk' means 'the rest'.
class OpenACCDeviceTypeClause final
: public OpenACCClauseWithParams,
public llvm::TrailingObjects<OpenACCDeviceTypeClause,
DeviceTypeArgument> {
// Data stored in trailing objects as IdentifierInfo* /SourceLocation pairs. A
// nullptr IdentifierInfo* represents an asterisk.
unsigned NumArchs;
OpenACCDeviceTypeClause(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc,
ArrayRef<DeviceTypeArgument> Archs,
SourceLocation EndLoc)
: OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
NumArchs(Archs.size()) {
assert(
(K == OpenACCClauseKind::DeviceType || K == OpenACCClauseKind::DType) &&
"Invalid clause kind for device-type");

assert(!llvm::any_of(Archs, [](const DeviceTypeArgument &Arg) {
return Arg.second.isInvalid();
}) && "Invalid SourceLocation for an argument");

assert(Archs.size() == 1 ||
!llvm::any_of(Archs,
[](const DeviceTypeArgument &Arg) {
return Arg.first == nullptr;
}) &&
"Only a single asterisk version is permitted, and must be the "
"only one");

std::uninitialized_copy(Archs.begin(), Archs.end(),
getTrailingObjects<DeviceTypeArgument>());
}

public:
static bool classof(const OpenACCClause *C) {
return C->getClauseKind() == OpenACCClauseKind::DType ||
C->getClauseKind() == OpenACCClauseKind::DeviceType;
}
bool hasAsterisk() const {
return getArchitectures().size() > 0 &&
getArchitectures()[0].first == nullptr;
}

ArrayRef<DeviceTypeArgument> getArchitectures() const {
return ArrayRef<DeviceTypeArgument>(
getTrailingObjects<DeviceTypeArgument>(), NumArchs);
}

static OpenACCDeviceTypeClause *
Create(const ASTContext &C, OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, ArrayRef<DeviceTypeArgument> Archs,
SourceLocation EndLoc);
};

/// A 'default' clause, has the optional 'none' or 'present' argument.
class OpenACCDefaultClause : public OpenACCClauseWithParams {
friend class ASTReaderStmt;
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12344,4 +12344,8 @@ def warn_acc_deprecated_alias_name
def err_acc_var_not_pointer_type
: Error<"expected pointer in '%0' clause, type is %1">;
def note_acc_expected_pointer_var : Note<"expected variable of pointer type">;
def err_acc_clause_after_device_type
: Error<"OpenACC clause '%0' may not follow a '%1' clause in a "
"compute construct">;

} // end of sema component.
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ CLAUSE_ALIAS(PCreate, Create)
CLAUSE_ALIAS(PresentOrCreate, Create)
VISIT_CLAUSE(Default)
VISIT_CLAUSE(DevicePtr)
VISIT_CLAUSE(DeviceType)
CLAUSE_ALIAS(DType, DeviceType)
VISIT_CLAUSE(FirstPrivate)
VISIT_CLAUSE(If)
VISIT_CLAUSE(NoCreate)
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3720,7 +3720,8 @@ class Parser : public CodeCompletionHandler {
SourceLocation Loc,
llvm::SmallVectorImpl<Expr *> &IntExprs);
/// Parses the 'device-type-list', which is a list of identifiers.
bool ParseOpenACCDeviceTypeList();
bool ParseOpenACCDeviceTypeList(
llvm::SmallVector<std::pair<IdentifierInfo *, SourceLocation>> &Archs);
/// Parses the 'async-argument', which is an integral value with two
/// 'special' values that are likely negative (but come from Macros).
OpenACCIntExprParseResult ParseOpenACCAsyncArgument(OpenACCDirectiveKind DK,
Expand Down
23 changes: 22 additions & 1 deletion clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class OpenACCClause;

class SemaOpenACC : public SemaBase {
public:
// Redeclaration of the version in OpenACCClause.h.
using DeviceTypeArgument = std::pair<IdentifierInfo *, SourceLocation>;

/// A type to represent all the data for an OpenACC Clause that has been
/// parsed, but not yet created/semantically analyzed. This is effectively a
/// discriminated union on the 'Clause Kind', with all of the individual
Expand Down Expand Up @@ -60,8 +63,12 @@ class SemaOpenACC : public SemaBase {
SmallVector<Expr *> QueueIdExprs;
};

struct DeviceTypeDetails {
SmallVector<DeviceTypeArgument> Archs;
};

std::variant<std::monostate, DefaultDetails, ConditionDetails,
IntExprDetails, VarListDetails, WaitDetails>
IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails>
Details = std::monostate{};

public:
Expand Down Expand Up @@ -209,6 +216,13 @@ class SemaOpenACC : public SemaBase {
return std::get<VarListDetails>(Details).IsZero;
}

ArrayRef<DeviceTypeArgument> getDeviceTypeArchitectures() const {
assert((ClauseKind == OpenACCClauseKind::DeviceType ||
ClauseKind == OpenACCClauseKind::DType) &&
"Only 'device_type'/'dtype' has a device-type-arg list");
return std::get<DeviceTypeDetails>(Details).Archs;
}

void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }

Expand Down Expand Up @@ -326,6 +340,13 @@ class SemaOpenACC : public SemaBase {
"Parsed clause kind does not have a wait-details");
Details = WaitDetails{DevNum, QueuesLoc, std::move(IntExprs)};
}

void setDeviceTypeDetails(llvm::SmallVector<DeviceTypeArgument> &&Archs) {
assert((ClauseKind == OpenACCClauseKind::DeviceType ||
ClauseKind == OpenACCClauseKind::DType) &&
"Only 'device_type'/'dtype' has a device-type-arg list");
Details = DeviceTypeDetails{std::move(Archs)};
}
};

SemaOpenACC(Sema &S);
Expand Down
28 changes: 27 additions & 1 deletion clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
using namespace clang;

bool OpenACCClauseWithParams::classof(const OpenACCClause *C) {
return OpenACCClauseWithCondition::classof(C) ||
return OpenACCDeviceTypeClause::classof(C) ||
OpenACCClauseWithCondition::classof(C) ||
OpenACCClauseWithExprs::classof(C);
}
bool OpenACCClauseWithExprs::classof(const OpenACCClause *C) {
Expand Down Expand Up @@ -298,6 +299,17 @@ OpenACCCreateClause::Create(const ASTContext &C, OpenACCClauseKind Spelling,
VarList, EndLoc);
}

OpenACCDeviceTypeClause *OpenACCDeviceTypeClause::Create(
const ASTContext &C, OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, ArrayRef<DeviceTypeArgument> Archs,
SourceLocation EndLoc) {
void *Mem =
C.Allocate(OpenACCDeviceTypeClause::totalSizeToAlloc<DeviceTypeArgument>(
Archs.size()));
return new (Mem)
OpenACCDeviceTypeClause(K, BeginLoc, LParenLoc, Archs, EndLoc);
}

//===----------------------------------------------------------------------===//
// OpenACC clauses printing methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -451,3 +463,17 @@ void OpenACCClausePrinter::VisitWaitClause(const OpenACCWaitClause &C) {
OS << ")";
}
}

void OpenACCClausePrinter::VisitDeviceTypeClause(
const OpenACCDeviceTypeClause &C) {
OS << C.getClauseKind();
OS << "(";
llvm::interleaveComma(C.getArchitectures(), OS,
[&](const DeviceTypeArgument &Arch) {
if (Arch.first == nullptr)
OS << "*";
else
OS << Arch.first;
});
OS << ")";
}
3 changes: 3 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2585,6 +2585,9 @@ void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
for (auto *E : Clause.getQueueIdExprs())
Profiler.VisitStmt(E);
}
/// Nothing to do here, there are no sub-statements.
void OpenACCClauseProfiler::VisitDeviceTypeClause(
const OpenACCDeviceTypeClause &Clause) {}
} // namespace

void StmtProfiler::VisitOpenACCComputeConstruct(
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,19 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
if (cast<OpenACCWaitClause>(C)->hasQueuesTag())
OS << " has queues tag";
break;
case OpenACCClauseKind::DeviceType:
case OpenACCClauseKind::DType:
OS << "(";
llvm::interleaveComma(
cast<OpenACCDeviceTypeClause>(C)->getArchitectures(), OS,
[&](const DeviceTypeArgument &Arch) {
if (Arch.first == nullptr)
OS << "*";
else
OS << Arch.first->getName();
});
OS << ")";
break;
default:
// Nothing to do here.
break;
Expand Down
21 changes: 13 additions & 8 deletions clang/lib/Parse/ParseOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,24 +711,25 @@ bool Parser::ParseOpenACCIntExprList(OpenACCDirectiveKind DK,
/// device_type( device-type-list )
///
/// The device_type clause may be abbreviated to dtype.
bool Parser::ParseOpenACCDeviceTypeList() {
bool Parser::ParseOpenACCDeviceTypeList(
llvm::SmallVector<std::pair<IdentifierInfo *, SourceLocation>> &Archs) {

if (expectIdentifierOrKeyword(*this)) {
SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
Parser::StopBeforeMatch);
return false;
return true;
}
ConsumeToken();
Archs.emplace_back(getCurToken().getIdentifierInfo(), ConsumeToken());

while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
ExpectAndConsume(tok::comma);

if (expectIdentifierOrKeyword(*this)) {
SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
Parser::StopBeforeMatch);
return false;
return true;
}
ConsumeToken();
Archs.emplace_back(getCurToken().getIdentifierInfo(), ConsumeToken());
}
return false;
}
Expand Down Expand Up @@ -1021,16 +1022,20 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
break;
}
case OpenACCClauseKind::DType:
case OpenACCClauseKind::DeviceType:
case OpenACCClauseKind::DeviceType: {
llvm::SmallVector<std::pair<IdentifierInfo *, SourceLocation>> Archs;
if (getCurToken().is(tok::star)) {
// FIXME: We want to mark that this is an 'everything else' type of
// device_type in Sema.
ConsumeToken();
} else if (ParseOpenACCDeviceTypeList()) {
ParsedClause.setDeviceTypeDetails({{nullptr, ConsumeToken()}});
} else if (!ParseOpenACCDeviceTypeList(Archs)) {
ParsedClause.setDeviceTypeDetails(std::move(Archs));
} else {
Parens.skipToEnd();
return OpenACCCanContinue();
}
break;
}
case OpenACCClauseKind::Tile:
if (ParseOpenACCSizeExprList()) {
Parens.skipToEnd();
Expand Down
55 changes: 55 additions & 0 deletions clang/lib/Sema/SemaOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,33 @@ bool checkAlreadyHasClauseOfKind(
return false;
}

/// Implement check from OpenACC3.3: section 2.5.4:
/// Only the async, wait, num_gangs, num_workers, and vector_length clauses may
/// follow a device_type clause.
bool checkValidAfterDeviceType(
SemaOpenACC &S, const OpenACCDeviceTypeClause &DeviceTypeClause,
const SemaOpenACC::OpenACCParsedClause &NewClause) {
// This is only a requirement on compute constructs so far, so this is fine
// otherwise.
if (!isOpenACCComputeDirectiveKind(NewClause.getDirectiveKind()))
return false;
switch (NewClause.getClauseKind()) {
case OpenACCClauseKind::Async:
case OpenACCClauseKind::Wait:
case OpenACCClauseKind::NumGangs:
case OpenACCClauseKind::NumWorkers:
case OpenACCClauseKind::VectorLength:
case OpenACCClauseKind::DType:
case OpenACCClauseKind::DeviceType:
return false;
default:
S.Diag(NewClause.getBeginLoc(), diag::err_acc_clause_after_device_type)
<< NewClause.getClauseKind() << DeviceTypeClause.getClauseKind();
S.Diag(DeviceTypeClause.getBeginLoc(), diag::note_acc_previous_clause_here);
return true;
}
}

} // namespace

SemaOpenACC::SemaOpenACC(Sema &S) : SemaBase(S) {}
Expand All @@ -273,6 +300,17 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
return nullptr;
}

if (const auto *DevTypeClause =
llvm::find_if(ExistingClauses,
[&](const OpenACCClause *C) {
return isa<OpenACCDeviceTypeClause>(C);
});
DevTypeClause != ExistingClauses.end()) {
if (checkValidAfterDeviceType(
*this, *cast<OpenACCDeviceTypeClause>(*DevTypeClause), Clause))
return nullptr;
}

switch (Clause.getClauseKind()) {
case OpenACCClauseKind::Default: {
// Restrictions only properly implemented on 'compute' constructs, and
Expand Down Expand Up @@ -651,6 +689,23 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
Clause.getDevNumExpr(), Clause.getQueuesLoc(), Clause.getQueueIdExprs(),
Clause.getEndLoc());
}
case OpenACCClauseKind::DType:
case OpenACCClauseKind::DeviceType: {
// Restrictions only properly implemented on 'compute' constructs, and
// 'compute' constructs are the only construct that can do anything with
// this yet, so skip/treat as unimplemented in this case.
if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
break;

// TODO OpenACC: Once we get enough of the CodeGen implemented that we have
// a source for the list of valid architectures, we need to warn on unknown
// identifiers here.

return OpenACCDeviceTypeClause::Create(
getASTContext(), Clause.getClauseKind(), Clause.getBeginLoc(),
Clause.getLParenLoc(), Clause.getDeviceTypeArchitectures(),
Clause.getEndLoc());
}
default:
break;
}
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/Sema/TreeTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -11480,6 +11480,16 @@ void OpenACCClauseTransform<Derived>::VisitWaitClause(
ParsedClause.getQueuesLoc(), ParsedClause.getQueueIdExprs(),
ParsedClause.getEndLoc());
}

template <typename Derived>
void OpenACCClauseTransform<Derived>::VisitDeviceTypeClause(
const OpenACCDeviceTypeClause &C) {
// Nothing to transform here, just create a new version of 'C'.
NewClause = OpenACCDeviceTypeClause::Create(
Self.getSema().getASTContext(), C.getClauseKind(),
ParsedClause.getBeginLoc(), ParsedClause.getLParenLoc(),
C.getArchitectures(), ParsedClause.getEndLoc());
}
} // namespace
template <typename Derived>
OpenACCClause *TreeTransform<Derived>::TransformOpenACCClause(
Expand Down
Loading

0 comments on commit c4a9a37

Please sign in to comment.