Skip to content

Commit

Permalink
[OpenACC] Implement 'if' clause for Compute Constructs (#88411)
Browse files Browse the repository at this point in the history
Like with the 'default' clause, this is being applied to only Compute
Constructs for now. The 'if' clause takes a condition expression which
is used as a runtime value.

This is not a particularly complex semantic implementation, as there
isn't much to this clause, other than its interactions with 'self',
  which will be managed in the patch to implement that.
  • Loading branch information
erichkeane committed Apr 12, 2024
1 parent c6cd460 commit daa8836
Show file tree
Hide file tree
Showing 18 changed files with 557 additions and 37 deletions.
3 changes: 2 additions & 1 deletion clang/include/clang/AST/ASTNodeTraverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ class ASTNodeTraverser
void Visit(const OpenACCClause *C) {
getNodeDelegate().AddChild([=] {
getNodeDelegate().Visit(C);
// TODO OpenACC: Switch on clauses that have children, and add them.
for (const auto *S : C->children())
Visit(S);
});
}

Expand Down
81 changes: 77 additions & 4 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef LLVM_CLANG_AST_OPENACCCLAUSE_H
#define LLVM_CLANG_AST_OPENACCCLAUSE_H
#include "clang/AST/ASTContext.h"
#include "clang/AST/StmtIterator.h"
#include "clang/Basic/OpenACCKinds.h"

namespace clang {
Expand All @@ -34,6 +35,17 @@ class OpenACCClause {

static bool classof(const OpenACCClause *) { return true; }

using child_iterator = StmtIterator;
using const_child_iterator = ConstStmtIterator;
using child_range = llvm::iterator_range<child_iterator>;
using const_child_range = llvm::iterator_range<const_child_iterator>;

child_range children();
const_child_range children() const {
auto Children = const_cast<OpenACCClause *>(this)->children();
return const_child_range(Children.begin(), Children.end());
}

virtual ~OpenACCClause() = default;
};

Expand All @@ -49,6 +61,13 @@ class OpenACCClauseWithParams : public OpenACCClause {

public:
SourceLocation getLParenLoc() const { return LParenLoc; }

child_range children() {
return child_range(child_iterator(), child_iterator());
}
const_child_range children() const {
return const_child_range(const_child_iterator(), const_child_iterator());
}
};

/// A 'default' clause, has the optional 'none' or 'present' argument.
Expand Down Expand Up @@ -81,6 +100,51 @@ class OpenACCDefaultClause : public OpenACCClauseWithParams {
SourceLocation EndLoc);
};

/// Represents one of the handful of classes that has an optional/required
/// 'condition' expression as an argument.
class OpenACCClauseWithCondition : public OpenACCClauseWithParams {
Expr *ConditionExpr = nullptr;

protected:
OpenACCClauseWithCondition(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *ConditionExpr,
SourceLocation EndLoc)
: OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
ConditionExpr(ConditionExpr) {}

public:
bool hasConditionExpr() const { return ConditionExpr; }
const Expr *getConditionExpr() const { return ConditionExpr; }
Expr *getConditionExpr() { return ConditionExpr; }

child_range children() {
if (ConditionExpr)
return child_range(reinterpret_cast<Stmt **>(&ConditionExpr),
reinterpret_cast<Stmt **>(&ConditionExpr + 1));
return child_range(child_iterator(), child_iterator());
}

const_child_range children() const {
if (ConditionExpr)
return const_child_range(
reinterpret_cast<Stmt *const *>(&ConditionExpr),
reinterpret_cast<Stmt *const *>(&ConditionExpr + 1));
return const_child_range(const_child_iterator(), const_child_iterator());
}
};

/// An 'if' clause, which has a required condition expression.
class OpenACCIfClause : public OpenACCClauseWithCondition {
protected:
OpenACCIfClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
Expr *ConditionExpr, SourceLocation EndLoc);

public:
static OpenACCIfClause *Create(const ASTContext &C, SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *ConditionExpr,
SourceLocation EndLoc);
};

template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }

Expand All @@ -98,6 +162,9 @@ template <class Impl> class OpenACCClauseVisitor {
case OpenACCClauseKind::Default:
VisitOpenACCDefaultClause(*cast<OpenACCDefaultClause>(C));
return;
case OpenACCClauseKind::If:
VisitOpenACCIfClause(*cast<OpenACCIfClause>(C));
return;
case OpenACCClauseKind::Finalize:
case OpenACCClauseKind::IfPresent:
case OpenACCClauseKind::Seq:
Expand All @@ -106,7 +173,6 @@ template <class Impl> class OpenACCClauseVisitor {
case OpenACCClauseKind::Worker:
case OpenACCClauseKind::Vector:
case OpenACCClauseKind::NoHost:
case OpenACCClauseKind::If:
case OpenACCClauseKind::Self:
case OpenACCClauseKind::Copy:
case OpenACCClauseKind::UseDevice:
Expand Down Expand Up @@ -145,9 +211,13 @@ template <class Impl> class OpenACCClauseVisitor {
llvm_unreachable("Invalid Clause kind");
}

void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause) {
return getDerived().VisitOpenACCDefaultClause(Clause);
#define VISIT_CLAUSE(CLAUSE_NAME) \
void VisitOpenACC##CLAUSE_NAME##Clause( \
const OpenACC##CLAUSE_NAME##Clause &Clause) { \
return getDerived().VisitOpenACC##CLAUSE_NAME##Clause(Clause); \
}

#include "clang/Basic/OpenACCClauses.def"
};

class OpenACCClausePrinter final
Expand All @@ -165,7 +235,10 @@ class OpenACCClausePrinter final
}
OpenACCClausePrinter(raw_ostream &OS) : OS(OS) {}

void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);
#define VISIT_CLAUSE(CLAUSE_NAME) \
void VisitOpenACC##CLAUSE_NAME##Clause( \
const OpenACC##CLAUSE_NAME##Clause &Clause);
#include "clang/Basic/OpenACCClauses.def"
};

} // namespace clang
Expand Down
21 changes: 21 additions & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===-- OpenACCClauses.def - List of implemented OpenACC Clauses -- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a list of currently implemented OpenACC Clauses (and
// eventually, the entire list) in a way that makes generating 'visitor' and
// other lists easier.
//
// The primary macro is a single-argument version taking the name of the Clause
// as used in Clang source (so `Default` instead of `default`).
//
// VISIT_CLAUSE(CLAUSE_NAME)

VISIT_CLAUSE(Default)
VISIT_CLAUSE(If)

#undef VISIT_CLAUSE
5 changes: 5 additions & 0 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3611,6 +3611,9 @@ class Parser : public CodeCompletionHandler {
OpenACCClauseParseResult OpenACCCannotContinue();
OpenACCClauseParseResult OpenACCSuccess(OpenACCClause *Clause);

using OpenACCConditionExprParseResult =
std::pair<ExprResult, OpenACCParseCanContinue>;

/// Parses the OpenACC directive (the entire pragma) including the clause
/// list, but does not produce the main AST node.
OpenACCDirectiveParseInfo ParseOpenACCDirective();
Expand Down Expand Up @@ -3657,6 +3660,8 @@ class Parser : public CodeCompletionHandler {
bool ParseOpenACCGangArgList();
/// Parses a 'gang-arg', used for the 'gang' clause.
bool ParseOpenACCGangArg();
/// Parses a 'condition' expr, ensuring it results in a
ExprResult ParseOpenACCConditionExpr();

private:
//===--------------------------------------------------------------------===//
Expand Down
28 changes: 27 additions & 1 deletion clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ class SemaOpenACC : public SemaBase {
OpenACCDefaultClauseKind DefaultClauseKind;
};

std::variant<DefaultDetails> Details;
struct ConditionDetails {
Expr *ConditionExpr;
};

std::variant<DefaultDetails, ConditionDetails> Details;

public:
OpenACCParsedClause(OpenACCDirectiveKind DirKind,
Expand All @@ -63,6 +67,16 @@ class SemaOpenACC : public SemaBase {
return std::get<DefaultDetails>(Details).DefaultClauseKind;
}

const Expr *getConditionExpr() const {
return const_cast<OpenACCParsedClause *>(this)->getConditionExpr();
}

Expr *getConditionExpr() {
assert(ClauseKind == OpenACCClauseKind::If &&
"Parsed clause kind does not have a condition expr");
return std::get<ConditionDetails>(Details).ConditionExpr;
}

void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }

Expand All @@ -71,6 +85,18 @@ class SemaOpenACC : public SemaBase {
"Parsed clause is not a default clause");
Details = DefaultDetails{DefKind};
}

void setConditionDetails(Expr *ConditionExpr) {
assert(ClauseKind == OpenACCClauseKind::If &&
"Parsed clause kind does not have a condition expr");
// In C++ we can count on this being a 'bool', but in C this gets left as
// some sort of scalar that codegen will have to take care of converting.
assert((!ConditionExpr || ConditionExpr->isInstantiationDependent() ||
ConditionExpr->getType()->isScalarType()) &&
"Condition expression type not scalar/dependent");

Details = ConditionDetails{ConditionExpr};
}
};

SemaOpenACC(Sema &S);
Expand Down
39 changes: 39 additions & 0 deletions clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "clang/AST/OpenACCClause.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"

using namespace clang;

Expand All @@ -27,10 +28,48 @@ OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
return new (Mem) OpenACCDefaultClause(K, BeginLoc, LParenLoc, EndLoc);
}

OpenACCIfClause *OpenACCIfClause::Create(const ASTContext &C,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
Expr *ConditionExpr,
SourceLocation EndLoc) {
void *Mem = C.Allocate(sizeof(OpenACCIfClause), alignof(OpenACCIfClause));
return new (Mem) OpenACCIfClause(BeginLoc, LParenLoc, ConditionExpr, EndLoc);
}

OpenACCIfClause::OpenACCIfClause(SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *ConditionExpr,
SourceLocation EndLoc)
: OpenACCClauseWithCondition(OpenACCClauseKind::If, BeginLoc, LParenLoc,
ConditionExpr, EndLoc) {
assert(ConditionExpr && "if clause requires condition expr");
assert((ConditionExpr->isInstantiationDependent() ||
ConditionExpr->getType()->isScalarType()) &&
"Condition expression type not scalar/dependent");
}

OpenACCClause::child_range OpenACCClause::children() {
switch (getClauseKind()) {
default:
assert(false && "Clause children function not implemented");
break;
#define VISIT_CLAUSE(CLAUSE_NAME) \
case OpenACCClauseKind::CLAUSE_NAME: \
return cast<OpenACC##CLAUSE_NAME##Clause>(this)->children();

#include "clang/Basic/OpenACCClauses.def"
}
return child_range(child_iterator(), child_iterator());
}

//===----------------------------------------------------------------------===//
// OpenACC clauses printing methods
//===----------------------------------------------------------------------===//
void OpenACCClausePrinter::VisitOpenACCDefaultClause(
const OpenACCDefaultClause &C) {
OS << "default(" << C.getDefaultClauseKind() << ")";
}

void OpenACCClausePrinter::VisitOpenACCIfClause(const OpenACCIfClause &C) {
OS << "if(" << C.getConditionExpr() << ")";
}
19 changes: 16 additions & 3 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2445,9 +2445,10 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
namespace {
class OpenACCClauseProfiler
: public OpenACCClauseVisitor<OpenACCClauseProfiler> {
StmtProfiler &Profiler;

public:
OpenACCClauseProfiler() = default;
OpenACCClauseProfiler(StmtProfiler &P) : Profiler(P) {}

void VisitOpenACCClauseList(ArrayRef<const OpenACCClause *> Clauses) {
for (const OpenACCClause *Clause : Clauses) {
Expand All @@ -2456,20 +2457,32 @@ class OpenACCClauseProfiler
Visit(Clause);
}
}
void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);

#define VISIT_CLAUSE(CLAUSE_NAME) \
void VisitOpenACC##CLAUSE_NAME##Clause( \
const OpenACC##CLAUSE_NAME##Clause &Clause);

#include "clang/Basic/OpenACCClauses.def"
};

/// Nothing to do here, there are no sub-statements.
void OpenACCClauseProfiler::VisitOpenACCDefaultClause(
const OpenACCDefaultClause &Clause) {}

void OpenACCClauseProfiler::VisitOpenACCIfClause(
const OpenACCIfClause &Clause) {
assert(Clause.hasConditionExpr() &&
"if clause requires a valid condition expr");
Profiler.VisitStmt(Clause.getConditionExpr());
}
} // namespace

void StmtProfiler::VisitOpenACCComputeConstruct(
const OpenACCComputeConstruct *S) {
// VisitStmt handles children, so the AssociatedStmt is handled.
VisitStmt(S);

OpenACCClauseProfiler P;
OpenACCClauseProfiler P{*this};
P.VisitOpenACCClauseList(S->clauses());
}

Expand Down
5 changes: 5 additions & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,11 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
case OpenACCClauseKind::Default:
OS << '(' << cast<OpenACCDefaultClause>(C)->getDefaultClauseKind() << ')';
break;
case OpenACCClauseKind::If:
// The condition expression will be printed as a part of the 'children',
// but print 'clause' here so it is clear what is happening from the dump.
OS << " clause";
break;
default:
// Nothing to do here.
break;
Expand Down

0 comments on commit daa8836

Please sign in to comment.