diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index ccf2c40bc5efa..e18dea7d9bd47 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -8541,6 +8541,101 @@ class OMPIsDevicePtrClause final } }; +/// This represents clause 'graph_id' in the '#pragma omp taskgraph" +/// directives. +/// +/// \code +/// #pragma omp taskgraph graph_id(a) +class OMPGraphIdClause final + : public OMPOneStmtClause { + friend class OMPClauseReader; + + /// Set condition. + void setId(Expr *Id) { setStmt(Id); } + +public: + /// Build 'graph_id' clause with identifier value \a Id. + /// + /// \param Id Id value for the clause. + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + OMPGraphIdClause(Expr *Id, SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc) + : OMPOneStmtClause(Id, StartLoc, LParenLoc, EndLoc) {} + + /// Build an empty clause. + OMPGraphIdClause() : OMPOneStmtClause() {} + + /// Returns condition. + Expr *getId() const { return getStmtAs(); } +}; + +// This represents clause 'graph_reset' in the '#pragma omp taskgraph" +/// directives. +/// +/// \code +/// #pragma omp taskgraph graph_reset(true) +class OMPGraphResetClause final : public OMPClause { + friend class OMPClauseReader; + + /// Location of '('. + SourceLocation LParenLoc; + + /// Condition of the 'graph_reset' clause. + Stmt *Condition = nullptr; + +public: + /// Build 'graph_reset' clause with condition \a Cond. + /// + /// \param Cond Condition of the clause. + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + OMPGraphResetClause(Expr *Cond, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc) + : OMPClause(llvm::omp::OMPC_graph_reset, StartLoc, EndLoc), + LParenLoc(LParenLoc), Condition(Cond) {} + + /// Build an empty clause. + OMPGraphResetClause() + : OMPClause(llvm::omp::OMPC_graph_reset, SourceLocation(), + SourceLocation()) {} + + /// Sets the location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + + /// Returns the location of '('. + SourceLocation getLParenLoc() const { return LParenLoc; } + + /// Set condition. + void setCondition(Expr *Cond) { Condition = Cond; } + + /// Returns condition. + Expr *getCondition() const { return cast_or_null(Condition); } + + child_range children() { + if (Condition) + return child_range(&Condition, &Condition + 1); + return child_range(child_iterator(), child_iterator()); + } + + const_child_range children() const { + if (Condition) + return const_child_range(&Condition, &Condition + 1); + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + child_range used_children(); + const_child_range used_children() const { + return const_cast(this)->used_children(); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_graph_reset; + } +}; + /// This represents clause 'has_device_ptr' in the '#pragma omp ...' /// directives. /// diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index b5be0910194bd..b86f122f75460 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -4162,6 +4162,19 @@ bool RecursiveASTVisitor::VisitOMPIsDevicePtrClause( return true; } +template +bool RecursiveASTVisitor::VisitOMPGraphIdClause(OMPGraphIdClause *C) { + TRY_TO(TraverseStmt(C->getId())); + return true; +} + +template +bool RecursiveASTVisitor::VisitOMPGraphResetClause( + OMPGraphResetClause *C) { + TRY_TO(TraverseStmt(C->getCondition())); + return true; +} + template bool RecursiveASTVisitor::VisitOMPHasDeviceAddrClause( OMPHasDeviceAddrClause *C) { diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h index 3621ce96b8724..480b18960fd67 100644 --- a/clang/include/clang/Sema/SemaOpenMP.h +++ b/clang/include/clang/Sema/SemaOpenMP.h @@ -951,6 +951,15 @@ class SemaOpenMP : public SemaBase { ActOnOpenMPOrderedClause(SourceLocation StartLoc, SourceLocation EndLoc, SourceLocation LParenLoc = SourceLocation(), Expr *NumForLoops = nullptr); + /// Called on well-formed 'graph_id' clause. + OMPClause *ActOnOpenMPGraphIdClause(Expr *Id, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc); + /// Called on well-formed 'graph_reset' clause. + OMPClause *ActOnOpenMPGraphResetClause(Expr *Condition, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc); /// Called on well-formed 'grainsize' clause. OMPClause *ActOnOpenMPGrainsizeClause(OpenMPGrainsizeClauseModifier Modifier, Expr *Size, SourceLocation StartLoc, diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp index 3a35e17aff40b..db4d5519acb38 100644 --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -254,6 +254,8 @@ const OMPClauseWithPostUpdate *OMPClauseWithPostUpdate::get(const OMPClause *C) case OMPC_thread_limit: case OMPC_priority: case OMPC_grainsize: + case OMPC_graph_id: + case OMPC_graph_reset: case OMPC_nogroup: case OMPC_num_tasks: case OMPC_hint: @@ -328,6 +330,12 @@ OMPClause::child_range OMPGrainsizeClause::used_children() { return child_range(&Grainsize, &Grainsize + 1); } +OMPClause::child_range OMPGraphResetClause::used_children() { + if (Condition) + return child_range(&Condition, &Condition + 1); + return children(); +} + OMPClause::child_range OMPNumTasksClause::used_children() { if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt())) return child_range(C, C + 1); @@ -2369,6 +2377,24 @@ void OMPClausePrinter::VisitOMPGrainsizeClause(OMPGrainsizeClause *Node) { OS << ")"; } +void OMPClausePrinter::VisitOMPGraphIdClause(OMPGraphIdClause *Node) { + OS << "graph_id"; + if (Expr *E = Node->getId()) { + OS << "("; + E->printPretty(OS, nullptr, Policy, 0); + OS << ")"; + } +} + +void OMPClausePrinter::VisitOMPGraphResetClause(OMPGraphResetClause *Node) { + OS << "graph_reset"; + if (Expr *E = Node->getCondition()) { + OS << "("; + E->printPretty(OS, nullptr, Policy, 0); + OS << ")"; + } +} + void OMPClausePrinter::VisitOMPNumTasksClause(OMPNumTasksClause *Node) { OS << "num_tasks("; OpenMPNumTasksClauseModifier Modifier = Node->getModifier(); diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index 8219e57644be6..2b991ff8e7b19 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -916,6 +916,14 @@ void OMPClauseProfiler::VisitOMPGrainsizeClause(const OMPGrainsizeClause *C) { if (C->getGrainsize()) Profiler->VisitStmt(C->getGrainsize()); } +void OMPClauseProfiler::VisitOMPGraphIdClause(const OMPGraphIdClause *C) { + if (C->getId()) + Profiler->VisitStmt(C->getId()); +} +void OMPClauseProfiler::VisitOMPGraphResetClause(const OMPGraphResetClause *C) { + if (C->getCondition()) + Profiler->VisitStmt(C->getCondition()); +} void OMPClauseProfiler::VisitOMPNumTasksClause(const OMPNumTasksClause *C) { VisitOMPClauseWithPreInit(C); if (C->getNumTasks()) diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp index 287eb217ba458..72015a7224275 100644 --- a/clang/lib/Basic/OpenMPKinds.cpp +++ b/clang/lib/Basic/OpenMPKinds.cpp @@ -312,6 +312,8 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, StringRef Str, case OMPC_when: case OMPC_append_args: case OMPC_looprange: + case OMPC_graph_id: + case OMPC_graph_reset: break; default: break; @@ -692,6 +694,8 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind, case OMPC_when: case OMPC_append_args: case OMPC_looprange: + case OMPC_graph_id: + case OMPC_graph_reset: break; default: break; diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp index 45a47ec797f01..5be99550f1ef6 100644 --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -3241,6 +3241,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind, case OMPC_partial: case OMPC_align: case OMPC_message: + case OMPC_graph_id: case OMPC_ompx_dyn_cgroup_mem: case OMPC_dyn_groupprivate: case OMPC_transparent: @@ -3391,6 +3392,20 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind, else Clause = ParseOpenMPClause(CKind, WrongDirective); break; + case OMPC_graph_reset: + if (!FirstClause) { + Diag(Tok, diag::err_omp_more_one_clause) + << getOpenMPDirectiveName(DKind, OMPVersion) + << getOpenMPClauseName(CKind) << 0; + ErrorFound = true; + } + + if (PP.LookAhead(/*N=*/0).is(tok::l_paren)) { + Clause = ParseOpenMPSingleExprClause(CKind, WrongDirective); + } else { + Clause = ParseOpenMPClause(CKind, WrongDirective); + } + break; case OMPC_self_maps: // OpenMP [6.0, self_maps clause] if (getLangOpts().OpenMP < 60) { diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 53ded7a5e177e..22ce0190a0907 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -6816,6 +6816,8 @@ StmtResult SemaOpenMP::ActOnOpenMPExecutableDirective( case OMPC_final: case OMPC_priority: case OMPC_novariants: + case OMPC_graph_id: + case OMPC_graph_reset: case OMPC_nocontext: // Do not analyze if no parent parallel directive. if (isOpenMPParallelDirective(Kind)) @@ -16858,6 +16860,12 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, case OMPC_detach: Res = ActOnOpenMPDetachClause(Expr, StartLoc, LParenLoc, EndLoc); break; + case OMPC_graph_id: + Res = ActOnOpenMPGraphIdClause(Expr, StartLoc, LParenLoc, EndLoc); + break; + case OMPC_graph_reset: + Res = ActOnOpenMPGraphResetClause(Expr, StartLoc, LParenLoc, EndLoc); + break; case OMPC_novariants: Res = ActOnOpenMPNovariantsClause(Expr, StartLoc, LParenLoc, EndLoc); break; @@ -17646,6 +17654,8 @@ OMPClause *SemaOpenMP::ActOnOpenMPSimpleClause( case OMPC_match: case OMPC_nontemporal: case OMPC_destroy: + case OMPC_graph_id: + case OMPC_graph_reset: case OMPC_novariants: case OMPC_nocontext: case OMPC_detach: @@ -18389,6 +18399,8 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprWithArgClause( case OMPC_severity: case OMPC_message: case OMPC_destroy: + case OMPC_graph_id: + case OMPC_graph_reset: case OMPC_novariants: case OMPC_nocontext: case OMPC_detach: @@ -18608,6 +18620,10 @@ OMPClause *SemaOpenMP::ActOnOpenMPClause(OpenMPClauseKind Kind, case OMPC_ompx_bare: Res = ActOnOpenMPXBareClause(StartLoc, EndLoc); break; + case OMPC_graph_reset: + Res = ActOnOpenMPGraphResetClause(/*Condition=*/nullptr, StartLoc, + SourceLocation(), EndLoc); + break; case OMPC_if: case OMPC_final: case OMPC_num_threads: @@ -18662,6 +18678,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPClause(OpenMPClauseKind Kind, case OMPC_at: case OMPC_severity: case OMPC_message: + case OMPC_graph_id: case OMPC_novariants: case OMPC_nocontext: case OMPC_detach: @@ -19019,6 +19036,47 @@ OMPClause *SemaOpenMP::ActOnOpenMPDestroyClause(Expr *InteropVar, OMPDestroyClause(InteropVar, StartLoc, LParenLoc, VarLoc, EndLoc); } +OMPClause *SemaOpenMP::ActOnOpenMPGraphIdClause(Expr *Id, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + Expr *ValExpr = Id; + + if (!Id->isValueDependent() && !Id->isTypeDependent() && + !Id->isInstantiationDependent() && + !Id->containsUnexpandedParameterPack()) { + ExprResult Val = PerformOpenMPImplicitIntegerConversion(LParenLoc, Id); + if (Val.isInvalid()) + return nullptr; + + ValExpr = Val.get(); + } + + return new (getASTContext()) + OMPGraphIdClause(ValExpr, StartLoc, LParenLoc, EndLoc); +} + +OMPClause *SemaOpenMP::ActOnOpenMPGraphResetClause(Expr *Condition, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + Expr *ValExpr = Condition; + if (Condition && LParenLoc.isValid()) { + if (!Condition->isValueDependent() && !Condition->isTypeDependent() && + !Condition->isInstantiationDependent() && + !Condition->containsUnexpandedParameterPack()) { + ExprResult Val = SemaRef.CheckBooleanCondition(StartLoc, Condition); + if (Val.isInvalid()) + return nullptr; + + ValExpr = Val.get(); + } + } + + return new (getASTContext()) + OMPGraphResetClause(ValExpr, StartLoc, LParenLoc, EndLoc); +} + OMPClause *SemaOpenMP::ActOnOpenMPNovariantsClause(Expr *Condition, SourceLocation StartLoc, SourceLocation LParenLoc, @@ -19319,6 +19377,8 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, case OMPC_severity: case OMPC_message: case OMPC_destroy: + case OMPC_graph_id: + case OMPC_graph_reset: case OMPC_novariants: case OMPC_nocontext: case OMPC_detach: diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 40187f71231bd..f49b3eceace46 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -2174,6 +2174,29 @@ class TreeTransform { LParenLoc, EndLoc); } + /// Build a new OpenMP 'graph_id' clause. + /// + /// By default, performs semantic analysis to build the new OpenMP clause. + /// Subclasses may override this routine to provide different behavior. + OMPClause *RebuildOMPGraphIdClause(Expr *Condition, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + return getSema().OpenMP().ActOnOpenMPGraphIdClause(Condition, StartLoc, + LParenLoc, EndLoc); + } + + /// Build a new OpenMP 'graph_reset' clause. + /// + /// By default, performs semantic analysis to build the new OpenMP clause. + /// Subclasses may override this routine to provide different behavior. + OMPClause *RebuildOMPGraphResetClause(Expr *Condition, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + return getSema().OpenMP().ActOnOpenMPGraphResetClause(Condition, StartLoc, + LParenLoc, EndLoc); + } + /// Build a new OpenMP 'grainsize' clause. /// /// By default, performs semantic analysis to build the new statement. @@ -11610,6 +11633,26 @@ TreeTransform::TransformOMPPriorityClause(OMPPriorityClause *C) { E.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc()); } +template +OMPClause * +TreeTransform::TransformOMPGraphIdClause(OMPGraphIdClause *C) { + ExprResult Cond = getDerived().TransformExpr(C->getId()); + if (Cond.isInvalid()) + return nullptr; + return getDerived().RebuildOMPGraphIdClause( + Cond.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc()); +} + +template +OMPClause * +TreeTransform::TransformOMPGraphResetClause(OMPGraphResetClause *C) { + ExprResult Cond = getDerived().TransformExpr(C->getCondition()); + if (Cond.isInvalid()) + return nullptr; + return getDerived().RebuildOMPGraphResetClause( + Cond.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc()); +} + template OMPClause * TreeTransform::TransformOMPGrainsizeClause(OMPGrainsizeClause *C) { diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp index 7e8bb6509e84b..d98a5a3f240ef 100644 --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -11700,6 +11700,12 @@ OMPClause *OMPClauseReader::readClause() { case llvm::omp::OMPC_grainsize: C = new (Context) OMPGrainsizeClause(); break; + case llvm::omp::OMPC_graph_id: + C = new (Context) OMPGraphIdClause(); + break; + case llvm::omp::OMPC_graph_reset: + C = new (Context) OMPGraphResetClause(); + break; case llvm::omp::OMPC_num_tasks: C = new (Context) OMPNumTasksClause(); break; @@ -12603,6 +12609,16 @@ void OMPClauseReader::VisitOMPPriorityClause(OMPPriorityClause *C) { C->setLParenLoc(Record.readSourceLocation()); } +void OMPClauseReader::VisitOMPGraphIdClause(OMPGraphIdClause *C) { + C->setId(Record.readSubExpr()); + C->setLParenLoc(Record.readSourceLocation()); +} + +void OMPClauseReader::VisitOMPGraphResetClause(OMPGraphResetClause *C) { + C->setCondition(Record.readSubExpr()); + C->setLParenLoc(Record.readSourceLocation()); +} + void OMPClauseReader::VisitOMPGrainsizeClause(OMPGrainsizeClause *C) { VisitOMPClauseWithPreInit(C); C->setModifier(Record.readEnum()); diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp index ba644fefc109a..df0d1a35ea715 100644 --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -8581,6 +8581,16 @@ void OMPClauseWriter::VisitOMPPriorityClause(OMPPriorityClause *C) { Record.AddSourceLocation(C->getLParenLoc()); } +void OMPClauseWriter::VisitOMPGraphIdClause(OMPGraphIdClause *C) { + Record.AddStmt(C->getId()); + Record.AddSourceLocation(C->getLParenLoc()); +} + +void OMPClauseWriter::VisitOMPGraphResetClause(OMPGraphResetClause *C) { + Record.AddStmt(C->getCondition()); + Record.AddSourceLocation(C->getLParenLoc()); +} + void OMPClauseWriter::VisitOMPGrainsizeClause(OMPGrainsizeClause *C) { VisitOMPClauseWithPreInit(C); Record.writeEnum(C->getModifier()); diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp index 350cd2135657d..1a92542c9595c 100644 --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2749,6 +2749,12 @@ void OMPClauseEnqueue::VisitOMPIsDevicePtrClause( const OMPIsDevicePtrClause *C) { VisitOMPClauseList(C); } +void OMPClauseEnqueue::VisitOMPGraphIdClause(const OMPGraphIdClause *C) { + Visitor->AddStmt(C->getId()); +} +void OMPClauseEnqueue::VisitOMPGraphResetClause(const OMPGraphResetClause *C) { + Visitor->AddStmt(C->getCondition()); +} void OMPClauseEnqueue::VisitOMPHasDeviceAddrClause( const OMPHasDeviceAddrClause *C) { VisitOMPClauseList(C); diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index e1e66df72dfc5..91ce4eac9e370 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -247,9 +247,11 @@ def OMPC_GrainSize : Clause<[Spelling<"grainsize">]> { ]; } def OMPC_GraphId : Clause<[Spelling<"graph_id">]> { + let clangClass = "OMPGraphIdClause"; let flangClass = "OmpGraphIdClause"; } def OMPC_GraphReset : Clause<[Spelling<"graph_reset">]> { + let clangClass = "OMPGraphResetClause"; let flangClass = "OmpGraphResetClause"; let isValueOptional = true; }