Skip to content

Commit

Permalink
[OpenMP] atomic compare fail : Parser & AST support
Browse files Browse the repository at this point in the history
This is a support for " #pragma omp atomic compare fail ". It has Parser & AST support for now.

Reviewed By: tianshilei1992

Differential Revision: https://reviews.llvm.org/D123235
  • Loading branch information
SunilKuravinakop authored and chichunchen committed May 25, 2022
1 parent 66db531 commit 232bf81
Show file tree
Hide file tree
Showing 19 changed files with 444 additions and 2 deletions.
11 changes: 11 additions & 0 deletions clang/include/clang/AST/ASTNodeTraverser.h
Expand Up @@ -214,13 +214,24 @@ class ASTNodeTraverser
}

void Visit(const OMPClause *C) {
if (OMPFailClause::classof(C)) {
Visit(static_cast<const OMPFailClause *>(C));
return;
}
getNodeDelegate().AddChild([=] {
getNodeDelegate().Visit(C);
for (const auto *S : C->children())
Visit(S);
});
}

void Visit(const OMPFailClause *C) {
getNodeDelegate().AddChild([=] {
getNodeDelegate().Visit(C);
const OMPClause *MOC = C->const_getMemoryOrderClause();
Visit(MOC);
});
}
void Visit(const GenericSelectionExpr::ConstAssociation &A) {
getNodeDelegate().AddChild([=] {
getNodeDelegate().Visit(A);
Expand Down
127 changes: 127 additions & 0 deletions clang/include/clang/AST/OpenMPClause.h
Expand Up @@ -2266,6 +2266,133 @@ class OMPCompareClause final : public OMPClause {
}
};

/// This represents 'fail' clause in the '#pragma omp atomic'
/// directive.
///
/// \code
/// #pragma omp atomic compare fail
/// \endcode
/// In this example directive '#pragma omp atomic compare' has 'fail' clause.
class OMPFailClause final
: public OMPClause,
private llvm::TrailingObjects<OMPFailClause, SourceLocation,
OpenMPClauseKind> {
OMPClause *MemoryOrderClause;

friend class OMPClauseReader;
friend TrailingObjects;

/// Define the sizes of each trailing object array except the last one. This
/// is required for TrailingObjects to work properly.
size_t numTrailingObjects(OverloadToken<SourceLocation>) const {
// 2 locations: for '(' and argument location.
return 2;
}

/// Sets the location of '(' in fail clause.
void setLParenLoc(SourceLocation Loc) {
*getTrailingObjects<SourceLocation>() = Loc;
}

/// Sets the location of memoryOrder clause argument in fail clause.
void setArgumentLoc(SourceLocation Loc) {
*std::next(getTrailingObjects<SourceLocation>(), 1) = Loc;
}

/// Sets the mem_order clause for 'atomic compare fail' directive.
void setMemOrderClauseKind(OpenMPClauseKind MemOrder) {
OpenMPClauseKind *MOCK = getTrailingObjects<OpenMPClauseKind>();
*MOCK = MemOrder;
}

/// Sets the mem_order clause for 'atomic compare fail' directive.
void setMemOrderClause(OMPClause *MemoryOrderClauseParam) {
MemoryOrderClause = MemoryOrderClauseParam;
}
public:
/// Build 'fail' clause.
///
/// \param StartLoc Starting location of the clause.
/// \param EndLoc Ending location of the clause.
OMPFailClause(SourceLocation StartLoc, SourceLocation EndLoc)
: OMPClause(llvm::omp::OMPC_fail, StartLoc, EndLoc) {}

/// Build an empty clause.
OMPFailClause()
: OMPClause(llvm::omp::OMPC_fail, SourceLocation(), SourceLocation()) {}

static OMPFailClause *CreateEmpty(const ASTContext &C);
static OMPFailClause *Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation EndLoc);

child_range children() {
return child_range(child_iterator(), child_iterator());
}


const_child_range children() const {
return const_child_range(const_child_iterator(), const_child_iterator());
}

child_range used_children() {
return child_range(child_iterator(), child_iterator());
}
const_child_range used_children() const {
return const_child_range(const_child_iterator(), const_child_iterator());
}

static bool classof(const OMPClause *T) {
return T->getClauseKind() == llvm::omp::OMPC_fail;
}

void initFailClause(SourceLocation LParenLoc, OMPClause *MemOClause,
SourceLocation MemOrderLoc) {

setLParenLoc(LParenLoc);
MemoryOrderClause = MemOClause;
setArgumentLoc(MemOrderLoc);

OpenMPClauseKind ClauseKind = MemoryOrderClause->getClauseKind();
OpenMPClauseKind MemClauseKind = llvm::omp::OMPC_unknown;
switch(ClauseKind) {
case llvm::omp::OMPC_acq_rel:
case llvm::omp::OMPC_acquire:
MemClauseKind = llvm::omp::OMPC_acquire;
break;
case llvm::omp::OMPC_relaxed:
case llvm::omp::OMPC_release:
MemClauseKind = llvm::omp::OMPC_relaxed;
break;
case llvm::omp::OMPC_seq_cst:
MemClauseKind = llvm::omp::OMPC_seq_cst;
break;
default : break;
}
setMemOrderClauseKind(MemClauseKind);
}

/// Gets the location of '(' in fail clause.
SourceLocation getLParenLoc() const {
return *getTrailingObjects<SourceLocation>();
}

OMPClause *getMemoryOrderClause() { return MemoryOrderClause; }

const OMPClause *const_getMemoryOrderClause() const {
return static_cast<const OMPClause *>(MemoryOrderClause);
}

/// Gets the location of memoryOrder clause argument in fail clause.
SourceLocation getArgumentLoc() const {
return *std::next(getTrailingObjects<SourceLocation>(), 1);
}

/// Gets the dependence kind in clause for 'depobj' directive.
OpenMPClauseKind getMemOrderClauseKind() const {
return *getTrailingObjects<OpenMPClauseKind>();
}
};

/// This represents 'seq_cst' clause in the '#pragma omp atomic'
/// directive.
///
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Expand Up @@ -3346,6 +3346,11 @@ bool RecursiveASTVisitor<Derived>::VisitOMPCompareClause(OMPCompareClause *) {
return true;
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPFailClause(OMPFailClause *) {
return true;
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPSeqCstClause(OMPSeqCstClause *) {
return true;
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Expand Up @@ -10633,6 +10633,10 @@ def note_omp_atomic_compare: Note<
"expect binary operator in conditional expression|expect '<', '>' or '==' as order operator|expect comparison in a form of 'x == e', 'e == x', 'x ordop expr', or 'expr ordop x'|"
"expect lvalue for result value|expect scalar value|expect integer value|unexpected 'else' statement|expect '==' operator|expect an assignment statement 'v = x'|"
"expect a 'if' statement|expect no more than two statements|expect a compound statement|expect 'else' statement|expect a form 'r = x == e; if (r) ...'}0">;
def err_omp_atomic_fail_wrong_or_no_clauses : Error<"expected a memory order clause">;
def err_omp_atomic_fail_extra_mem_order_clauses : Error<"directive '#pragma omp atomic compare fail' cannot contain more than one memory order clause">;
def err_omp_atomic_fail_extra_clauses : Error<"directive '#pragma omp atomic compare' cannot contain more than one fail clause">;
def err_omp_atomic_fail_no_compare : Error<"expected 'compare' clause with the 'fail' modifier">;
def err_omp_atomic_several_clauses : Error<
"directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update', 'capture', or 'compare' clause">;
def err_omp_several_mem_order_clauses : Error<
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Parse/Parser.h
Expand Up @@ -433,6 +433,8 @@ class Parser : public CodeCompletionHandler {
/// a statement expression and builds a suitable expression statement.
StmtResult handleExprStmt(ExprResult E, ParsedStmtContext StmtCtx);

OMPClause *ParseOpenMPFailClause(OMPClause *Clause);

public:
Parser(Preprocessor &PP, Sema &Actions, bool SkipFunctionBodies);
~Parser() override;
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Sema/Sema.h
Expand Up @@ -11357,6 +11357,9 @@ class Sema final {
/// Called on well-formed 'compare' clause.
OMPClause *ActOnOpenMPCompareClause(SourceLocation StartLoc,
SourceLocation EndLoc);
/// Called on well-formed 'fail' clause.
OMPClause *ActOnOpenMPFailClause(SourceLocation StartLoc,
SourceLocation EndLoc);
/// Called on well-formed 'seq_cst' clause.
OMPClause *ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
SourceLocation EndLoc);
Expand Down
31 changes: 31 additions & 0 deletions clang/lib/AST/OpenMPClause.cpp
Expand Up @@ -127,6 +127,7 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
case OMPC_update:
case OMPC_capture:
case OMPC_compare:
case OMPC_fail:
case OMPC_seq_cst:
case OMPC_acq_rel:
case OMPC_acquire:
Expand Down Expand Up @@ -220,6 +221,7 @@ const OMPClauseWithPostUpdate *OMPClauseWithPostUpdate::get(const OMPClause *C)
case OMPC_update:
case OMPC_capture:
case OMPC_compare:
case OMPC_fail:
case OMPC_seq_cst:
case OMPC_acq_rel:
case OMPC_acquire:
Expand Down Expand Up @@ -412,6 +414,25 @@ OMPUpdateClause *OMPUpdateClause::CreateEmpty(const ASTContext &C,
return Clause;
}

OMPFailClause *OMPFailClause::Create(const ASTContext &C,
SourceLocation StartLoc,
SourceLocation EndLoc) {
void *Mem =
C.Allocate(totalSizeToAlloc<SourceLocation, OpenMPClauseKind>(2, 1),
alignof(OMPFailClause));
auto *Clause =
new (Mem) OMPFailClause(StartLoc, EndLoc);
return Clause;
}

OMPFailClause *OMPFailClause::CreateEmpty(const ASTContext &C) {
void *Mem =
C.Allocate(totalSizeToAlloc<SourceLocation, OpenMPClauseKind>(2, 1),
alignof(OMPFailClause));
auto *Clause = new (Mem) OMPFailClause();
return Clause;
}

void OMPPrivateClause::setPrivateCopies(ArrayRef<Expr *> VL) {
assert(VL.size() == varlist_size() &&
"Number of private copies is not the same as the preallocated buffer");
Expand Down Expand Up @@ -1847,6 +1868,16 @@ void OMPClausePrinter::VisitOMPCompareClause(OMPCompareClause *) {
OS << "compare";
}

void OMPClausePrinter::VisitOMPFailClause(OMPFailClause *Node) {
OS << "fail";
if(Node) {
OS << "(";
OS << getOpenMPSimpleClauseTypeName(Node->getClauseKind(),
static_cast<int>(Node->getMemOrderClauseKind()));
OS << ")";
}
}

void OMPClausePrinter::VisitOMPSeqCstClause(OMPSeqCstClause *) {
OS << "seq_cst";
}
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Expand Up @@ -557,6 +557,8 @@ void OMPClauseProfiler::VisitOMPCaptureClause(const OMPCaptureClause *) {}

void OMPClauseProfiler::VisitOMPCompareClause(const OMPCompareClause *) {}

void OMPClauseProfiler::VisitOMPFailClause(const OMPFailClause *) {}

void OMPClauseProfiler::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}

void OMPClauseProfiler::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/Basic/OpenMPKinds.cpp
Expand Up @@ -366,6 +366,20 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind,
#include "clang/Basic/OpenMPKinds.def"
}
llvm_unreachable("Invalid OpenMP 'depend' clause type");
case OMPC_fail: {
OpenMPClauseKind CK = static_cast<OpenMPClauseKind>(Type);
switch (CK) {
case OMPC_acquire:
return "acquire";
case OMPC_relaxed:
return "relaxed";
case OMPC_seq_cst:
return "seq_cst";
default:
return "unknown";
}
llvm_unreachable("Invalid OpenMP 'fail' clause modifier");
}
case OMPC_device:
switch (Type) {
case OMPC_DEVICE_unknown:
Expand Down
46 changes: 45 additions & 1 deletion clang/lib/Parse/ParseOpenMP.cpp
Expand Up @@ -3234,6 +3234,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_write:
case OMPC_capture:
case OMPC_compare:
case OMPC_fail:
case OMPC_seq_cst:
case OMPC_acq_rel:
case OMPC_acquire:
Expand Down Expand Up @@ -3620,6 +3621,45 @@ OMPClause *Parser::ParseOpenMPSimpleClause(OpenMPClauseKind Kind,
Val.getValue().Loc, Val.getValue().RLoc);
}

OMPClause *Parser::ParseOpenMPFailClause(OMPClause *Clause) {

OMPFailClause *FailClause = static_cast<OMPFailClause *>(Clause);
SourceLocation LParenLoc;
if (Tok.is(tok::l_paren)) {
LParenLoc = Tok.getLocation();
ConsumeAnyToken();
} else {
Diag(diag::err_expected_lparen_after) << getOpenMPClauseName(OMPC_fail);
return Clause;
}


OpenMPClauseKind CKind = Tok.isAnnotation()
? OMPC_unknown
: getOpenMPClauseKind(PP.getSpelling(Tok));
if (CKind == OMPC_unknown) {
Diag(diag::err_omp_expected_clause) << ("atomic compare fail");
return Clause;
}
OMPClause *MemoryOrderClause = ParseOpenMPClause(CKind, false);
SourceLocation MemOrderLoc;
// Store Memory Order SubClause for Sema.
if (MemoryOrderClause) {
MemOrderLoc = Tok.getLocation();
}

if (Tok.is(tok::r_paren)) {
FailClause->initFailClause(LParenLoc,MemoryOrderClause,MemOrderLoc);
ConsumeAnyToken();
} else {
const IdentifierInfo *Arg = Tok.getIdentifierInfo();
Diag(Tok, diag::err_expected_rparen_after)
<< (Arg ? Arg->getName() : "atomic compare fail");
}

return Clause;
}

/// Parsing of OpenMP clauses like 'ordered'.
///
/// ordered-clause:
Expand Down Expand Up @@ -3652,7 +3692,11 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPClauseKind Kind, bool ParseOnly) {

if (ParseOnly)
return nullptr;
return Actions.ActOnOpenMPClause(Kind, Loc, Tok.getLocation());
OMPClause *Clause = Actions.ActOnOpenMPClause(Kind, Loc, Tok.getLocation());
if (Kind == llvm::omp::Clause::OMPC_fail) {
Clause = ParseOpenMPFailClause(Clause);
}
return Clause;
}

/// Parsing of OpenMP clauses with single expressions and some additional
Expand Down

0 comments on commit 232bf81

Please sign in to comment.