Skip to content

Commit

Permalink
[OpenMP] Support for the if-clause on the combined directive 'target …
Browse files Browse the repository at this point in the history
…parallel'.

The if-clause on the combined directive potentially applies to both the
'target' and the 'parallel' regions.  Codegen'ing the if-clause on the
combined directive requires additional support because the expression in
the clause must be captured by the 'target' capture statement but not
the 'parallel' capture statement.  Note that this situation arises for
other clauses such as num_threads.

The OMPIfClause class inherits OMPClauseWithPreInit to support capturing
of expressions in the clause.  A member CaptureRegion is added to
OMPClauseWithPreInit to indicate which captured statement (in this case
'target' but not 'parallel') captures these expressions.

To ensure correct codegen of captured expressions in the presence of
combined 'target' directives, OMPParallelScope was added to 'parallel'
codegen.

Reviewers: ABataev
Differential Revision: https://reviews.llvm.org/D28781

llvm-svn: 292437
  • Loading branch information
arpith-jacob committed Jan 18, 2017
1 parent 9c98244 commit fe4890a
Show file tree
Hide file tree
Showing 10 changed files with 652 additions and 40 deletions.
40 changes: 27 additions & 13 deletions clang/include/clang/AST/OpenMPClause.h
Expand Up @@ -76,10 +76,17 @@ class OMPClauseWithPreInit {
friend class OMPClauseReader;
/// Pre-initialization statement for the clause.
Stmt *PreInit;
/// Region that captures the associated stmt.
OpenMPDirectiveKind CaptureRegion;

protected:
/// Set pre-initialization statement for the clause.
void setPreInitStmt(Stmt *S) { PreInit = S; }
OMPClauseWithPreInit(const OMPClause *This) : PreInit(nullptr) {
void setPreInitStmt(Stmt *S, OpenMPDirectiveKind ThisRegion = OMPD_unknown) {
PreInit = S;
CaptureRegion = ThisRegion;
}
OMPClauseWithPreInit(const OMPClause *This)
: PreInit(nullptr), CaptureRegion(OMPD_unknown) {
assert(get(This) && "get is not tuned for pre-init.");
}

Expand All @@ -88,6 +95,8 @@ class OMPClauseWithPreInit {
const Stmt *getPreInitStmt() const { return PreInit; }
/// Get pre-initialization statement for the clause.
Stmt *getPreInitStmt() { return PreInit; }
/// Get capture region for the stmt in the clause.
OpenMPDirectiveKind getCaptureRegion() { return CaptureRegion; }
static OMPClauseWithPreInit *get(OMPClause *C);
static const OMPClauseWithPreInit *get(const OMPClause *C);
};
Expand Down Expand Up @@ -194,7 +203,7 @@ template <class T> class OMPVarListClause : public OMPClause {
/// In this example directive '#pragma omp parallel' has simple 'if' clause with
/// condition 'a > 5' and directive name modifier 'parallel'.
///
class OMPIfClause : public OMPClause {
class OMPIfClause : public OMPClause, public OMPClauseWithPreInit {
friend class OMPClauseReader;
/// \brief Location of '('.
SourceLocation LParenLoc;
Expand Down Expand Up @@ -225,26 +234,31 @@ class OMPIfClause : public OMPClause {
///
/// \param NameModifier [OpenMP 4.1] Directive name modifier of clause.
/// \param Cond Condition of the clause.
/// \param HelperCond Helper condition for the clause.
/// \param CaptureRegion Innermost OpenMP region where expressions in this
/// clause must be captured.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param NameModifierLoc Location of directive name modifier.
/// \param ColonLoc [OpenMP 4.1] Location of ':'.
/// \param EndLoc Ending location of the clause.
///
OMPIfClause(OpenMPDirectiveKind NameModifier, Expr *Cond,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation NameModifierLoc, SourceLocation ColonLoc,
SourceLocation EndLoc)
: OMPClause(OMPC_if, StartLoc, EndLoc), LParenLoc(LParenLoc),
Condition(Cond), ColonLoc(ColonLoc), NameModifier(NameModifier),
NameModifierLoc(NameModifierLoc) {}
OMPIfClause(OpenMPDirectiveKind NameModifier, Expr *Cond, Stmt *HelperCond,
OpenMPDirectiveKind CaptureRegion, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation NameModifierLoc,
SourceLocation ColonLoc, SourceLocation EndLoc)
: OMPClause(OMPC_if, StartLoc, EndLoc), OMPClauseWithPreInit(this),
LParenLoc(LParenLoc), Condition(Cond), ColonLoc(ColonLoc),
NameModifier(NameModifier), NameModifierLoc(NameModifierLoc) {
setPreInitStmt(HelperCond, CaptureRegion);
}

/// \brief Build an empty clause.
///
OMPIfClause()
: OMPClause(OMPC_if, SourceLocation(), SourceLocation()), LParenLoc(),
Condition(nullptr), ColonLoc(), NameModifier(OMPD_unknown),
NameModifierLoc() {}
: OMPClause(OMPC_if, SourceLocation(), SourceLocation()),
OMPClauseWithPreInit(this), LParenLoc(), Condition(nullptr), ColonLoc(),
NameModifier(OMPD_unknown), NameModifierLoc() {}

/// \brief Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Expand Up @@ -2711,6 +2711,7 @@ bool RecursiveASTVisitor<Derived>::VisitOMPClauseWithPostUpdate(

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPIfClause(OMPIfClause *C) {
TRY_TO(VisitOMPClauseWithPreInit(C));
TRY_TO(TraverseStmt(C->getCondition()));
return true;
}
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/AST/OpenMPClause.cpp
Expand Up @@ -48,9 +48,10 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
return static_cast<const OMPReductionClause *>(C);
case OMPC_linear:
return static_cast<const OMPLinearClause *>(C);
case OMPC_if:
return static_cast<const OMPIfClause *>(C);
case OMPC_default:
case OMPC_proc_bind:
case OMPC_if:
case OMPC_final:
case OMPC_num_threads:
case OMPC_safelen:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/StmtProfile.cpp
Expand Up @@ -283,6 +283,7 @@ void OMPClauseProfiler::VistOMPClauseWithPostUpdate(
}

void OMPClauseProfiler::VisitOMPIfClause(const OMPIfClause *C) {
VistOMPClauseWithPreInit(C);
if (C->getCondition())
Profiler->VisitStmt(C->getCondition());
}
Expand Down
41 changes: 30 additions & 11 deletions clang/lib/CodeGen/CGStmtOpenMP.cpp
Expand Up @@ -26,7 +26,7 @@ using namespace CodeGen;
namespace {
/// Lexical scope for OpenMP executable constructs, that handles correct codegen
/// for captured expressions.
class OMPLexicalScope final : public CodeGenFunction::LexicalScope {
class OMPLexicalScope : public CodeGenFunction::LexicalScope {
void emitPreInitStmt(CodeGenFunction &CGF, const OMPExecutableDirective &S) {
for (const auto *C : S.clauses()) {
if (auto *CPI = OMPClauseWithPreInit::get(C)) {
Expand Down Expand Up @@ -54,10 +54,11 @@ class OMPLexicalScope final : public CodeGenFunction::LexicalScope {

public:
OMPLexicalScope(CodeGenFunction &CGF, const OMPExecutableDirective &S,
bool AsInlined = false)
bool AsInlined = false, bool EmitPreInitStmt = true)
: CodeGenFunction::LexicalScope(CGF, S.getSourceRange()),
InlinedShareds(CGF) {
emitPreInitStmt(CGF, S);
if (EmitPreInitStmt)
emitPreInitStmt(CGF, S);
if (AsInlined) {
if (S.hasAssociatedStmt()) {
auto *CS = cast<CapturedStmt>(S.getAssociatedStmt());
Expand All @@ -81,6 +82,22 @@ class OMPLexicalScope final : public CodeGenFunction::LexicalScope {
}
};

/// Lexical scope for OpenMP parallel construct, that handles correct codegen
/// for captured expressions.
class OMPParallelScope final : public OMPLexicalScope {
bool EmitPreInitStmt(const OMPExecutableDirective &S) {
OpenMPDirectiveKind Kind = S.getDirectiveKind();
return !isOpenMPTargetExecutionDirective(Kind) &&
isOpenMPParallelDirective(Kind);
}

public:
OMPParallelScope(CodeGenFunction &CGF, const OMPExecutableDirective &S)
: OMPLexicalScope(CGF, S,
/*AsInlined=*/false,
/*EmitPreInitStmt=*/EmitPreInitStmt(S)) {}
};

/// Private scope for OpenMP loop-based directives, that supports capturing
/// of used expression from loop statement.
class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
Expand Down Expand Up @@ -1237,7 +1254,7 @@ static void emitCommonOMPParallelDirective(CodeGenFunction &CGF,
}
}

OMPLexicalScope Scope(CGF, S);
OMPParallelScope Scope(CGF, S);
llvm::SmallVector<llvm::Value *, 16> CapturedVars;
CGF.GenerateOpenMPCapturedVars(*CS, CapturedVars);
CGF.CGM.getOpenMPRuntime().emitParallelCall(CGF, S.getLocStart(), OutlinedFn,
Expand Down Expand Up @@ -3409,17 +3426,17 @@ static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
CodeGenModule &CGM = CGF.CGM;
const CapturedStmt &CS = *cast<CapturedStmt>(S.getAssociatedStmt());

llvm::SmallVector<llvm::Value *, 16> CapturedVars;
CGF.GenerateOpenMPCapturedVars(CS, CapturedVars);

llvm::Function *Fn = nullptr;
llvm::Constant *FnID = nullptr;

// Check if we have any if clause associated with the directive.
const Expr *IfCond = nullptr;

if (auto *C = S.getSingleClause<OMPIfClause>()) {
IfCond = C->getCondition();
// Check for the at most one if clause associated with the target region.
for (const auto *C : S.getClausesOfKind<OMPIfClause>()) {
if (C->getNameModifier() == OMPD_unknown ||
C->getNameModifier() == OMPD_target) {
IfCond = C->getCondition();
break;
}
}

// Check if we have any device clause associated with the directive.
Expand Down Expand Up @@ -3456,6 +3473,8 @@ static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
CGM.getOpenMPRuntime().emitTargetOutlinedFunction(S, ParentName, Fn, FnID,
IsOffloadEntry, CodeGen);
OMPLexicalScope Scope(CGF, S);
llvm::SmallVector<llvm::Value *, 16> CapturedVars;
CGF.GenerateOpenMPCapturedVars(CS, CapturedVars);
CGM.getOpenMPRuntime().emitTargetCall(CGF, S, Fn, FnID, IfCond, Device,
CapturedVars);
}
Expand Down

0 comments on commit fe4890a

Please sign in to comment.