-
Notifications
You must be signed in to change notification settings - Fork 10.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Clang][OpenMP] Add permutation clause #92030
base: users/meinersbur/clang_openmp_interchange
Are you sure you want to change the base?
[Clang][OpenMP] Add permutation clause #92030
Conversation
Buildkite failure is unrelated: |
@llvm/pr-subscribers-flang-parser @llvm/pr-subscribers-clang Author: Michael Kruse (Meinersbur) ChangesAdd the reverse and interchange directives which will be introduced in the upcoming OpenMP 6.0 specification. A preview has been published in Technical Report 12. The boilerplate code of the new directives largely overlaps. Having both in a single PR should be easier to review and avoids confects when inevitably one is committed before the other. Patch is 561.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92030.diff 47 Files Affected:
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index 365b607c74117..a79aafbf20222 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2146,6 +2146,14 @@ enum CXCursorKind {
*/
CXCursor_OMPScopeDirective = 306,
+ /** OpenMP reverse directive.
+ */
+ CXCursor_OMPReverseDirective = 307,
+
+ /** OpenMP interchange directive.
+ */
+ CXCursor_OMPInterchangeDirective = 308,
+
/** OpenACC Compute Construct.
*/
CXCursor_OpenACCComputeConstruct = 320,
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 325a1baa44614..c381bbde9eb11 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -870,6 +870,106 @@ class OMPSizesClause final
}
};
+/// This class represents the 'permutation' clause in the
+/// '#pragma omp interchange' directive.
+///
+/// \code{c}
+/// #pragma omp interchange permutation(2,1)
+/// for (int i = 0; i < 64; ++i)
+/// for (int j = 0; j < 64; ++j)
+/// \endcode
+class OMPPermutationClause final
+ : public OMPClause,
+ private llvm::TrailingObjects<OMPSizesClause, Expr *> {
+ friend class OMPClauseReader;
+ friend class llvm::TrailingObjects<OMPSizesClause, Expr *>;
+
+ /// Location of '('.
+ SourceLocation LParenLoc;
+
+ /// Number of arguments in the clause, and hence also the number of loops to
+ /// be permuted.
+ unsigned NumLoops;
+
+ /// Build an empty clause.
+ explicit OMPPermutationClause(int NumLoops)
+ : OMPClause(llvm::omp::OMPC_permutation, SourceLocation(),
+ SourceLocation()),
+ NumLoops(NumLoops) {}
+
+public:
+ /// Build a 'permutation' clause AST node.
+ ///
+ /// \param C Context of the AST.
+ /// \param StartLoc Location of the 'permutation' identifier.
+ /// \param LParenLoc Location of '('.
+ /// \param EndLoc Location of ')'.
+ /// \param Args Content of the clause.
+ static OMPPermutationClause *
+ Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc,
+ SourceLocation EndLoc, ArrayRef<Expr *> Args);
+
+ /// Build an empty 'permutation' AST node for deserialization.
+ ///
+ /// \param C Context of the AST.
+ /// \param NumLoops Number of arguments in the clause.
+ static OMPPermutationClause *CreateEmpty(const ASTContext &C,
+ unsigned NumLoops);
+
+ /// Sets the location of '('.
+ void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+ /// Returns the location of '('.
+ SourceLocation getLParenLoc() const { return LParenLoc; }
+
+ /// Returns the number of list items.
+ unsigned getNumLoops() const { return NumLoops; }
+
+ /// Returns the permutation index expressions.
+ ///@{
+ MutableArrayRef<Expr *> getArgsRefs() {
+ return MutableArrayRef<Expr *>(static_cast<OMPPermutationClause *>(this)
+ ->template getTrailingObjects<Expr *>(),
+ NumLoops);
+ }
+ ArrayRef<Expr *> getArgsRefs() const {
+ return ArrayRef<Expr *>(static_cast<const OMPPermutationClause *>(this)
+ ->template getTrailingObjects<Expr *>(),
+ NumLoops);
+ }
+ ///@}
+
+ /// Sets the permutation index expressions.
+ void setArgRefs(ArrayRef<Expr *> VL) {
+ assert(VL.size() == NumLoops);
+ std::copy(VL.begin(), VL.end(),
+ static_cast<OMPPermutationClause *>(this)
+ ->template getTrailingObjects<Expr *>());
+ }
+
+ child_range children() {
+ MutableArrayRef<Expr *> Args = getArgsRefs();
+ return child_range(reinterpret_cast<Stmt **>(Args.begin()),
+ reinterpret_cast<Stmt **>(Args.end()));
+ }
+ const_child_range children() const {
+ ArrayRef<Expr *> Args = getArgsRefs();
+ return const_child_range(reinterpret_cast<Stmt *const *>(Args.begin()),
+ reinterpret_cast<Stmt *const *>(Args.end()));
+ }
+
+ child_range used_children() {
+ return child_range(child_iterator(), child_iterator());
+ }
+ const_child_range used_children() const {
+ return const_child_range(const_child_iterator(), const_child_iterator());
+ }
+
+ static bool classof(const OMPClause *T) {
+ return T->getClauseKind() == llvm::omp::OMPC_permutation;
+ }
+};
+
/// Representation of the 'full' clause of the '#pragma omp unroll' directive.
///
/// \code
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index f9b145b4e86a5..7596b78c94f1c 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3004,6 +3004,12 @@ DEF_TRAVERSE_STMT(OMPTileDirective,
DEF_TRAVERSE_STMT(OMPUnrollDirective,
{ TRY_TO(TraverseOMPExecutableDirective(S)); })
+DEF_TRAVERSE_STMT(OMPReverseDirective,
+ { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
+DEF_TRAVERSE_STMT(OMPInterchangeDirective,
+ { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
DEF_TRAVERSE_STMT(OMPForDirective,
{ TRY_TO(TraverseOMPExecutableDirective(S)); })
@@ -3302,6 +3308,14 @@ bool RecursiveASTVisitor<Derived>::VisitOMPSizesClause(OMPSizesClause *C) {
return true;
}
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPPermutationClause(
+ OMPPermutationClause *C) {
+ for (Expr *E : C->getArgsRefs())
+ TRY_TO(TraverseStmt(E));
+ return true;
+}
+
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPFullClause(OMPFullClause *C) {
return true;
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index f735fa5643aec..df6c485117c80 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -1007,8 +1007,9 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
Stmt *getPreInits() const;
static bool classof(const Stmt *T) {
- return T->getStmtClass() == OMPTileDirectiveClass ||
- T->getStmtClass() == OMPUnrollDirectiveClass;
+ Stmt::StmtClass C = T->getStmtClass();
+ return C == OMPTileDirectiveClass || C == OMPUnrollDirectiveClass ||
+ C == OMPReverseDirectiveClass || C == OMPInterchangeDirectiveClass;
}
};
@@ -5638,6 +5639,147 @@ class OMPTileDirective final : public OMPLoopTransformationDirective {
}
};
+/// Represents the '#pragma omp reverse' loop transformation directive.
+///
+/// \code
+/// #pragma omp reverse
+/// for (int i = 0; i < n; ++i)
+/// ...
+/// \endcode
+class OMPReverseDirective final : public OMPLoopTransformationDirective {
+ friend class ASTStmtReader;
+ friend class OMPExecutableDirective;
+
+ /// Offsets of child members.
+ enum {
+ PreInitsOffset = 0,
+ TransformedStmtOffset,
+ };
+
+ explicit OMPReverseDirective(SourceLocation StartLoc, SourceLocation EndLoc)
+ : OMPLoopTransformationDirective(OMPReverseDirectiveClass,
+ llvm::omp::OMPD_reverse, StartLoc,
+ EndLoc, 1) {}
+
+ void setPreInits(Stmt *PreInits) {
+ Data->getChildren()[PreInitsOffset] = PreInits;
+ }
+
+ void setTransformedStmt(Stmt *S) {
+ Data->getChildren()[TransformedStmtOffset] = S;
+ }
+
+public:
+ /// Create a new AST node representation for '#pragma omp reverse'.
+ ///
+ /// \param C Context of the AST.
+ /// \param StartLoc Location of the introducer (e.g. the 'omp' token).
+ /// \param EndLoc Location of the directive's end (e.g. the tok::eod).
+ /// \param Clauses The directive's clauses.
+ /// \param AssociatedStmt The outermost associated loop.
+ /// \param TransformedStmt The loop nest after tiling, or nullptr in
+ /// dependent contexts.
+ /// \param PreInits Helper preinits statements for the loop nest.
+ static OMPReverseDirective *
+ Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+ ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
+ Stmt *TransformedStmt, Stmt *PreInits);
+
+ /// Build an empty '#pragma omp reverse' AST node for deserialization.
+ ///
+ /// \param C Context of the AST.
+ /// \param NumClauses Number of clauses to allocate.
+ static OMPReverseDirective *CreateEmpty(const ASTContext &C,
+ unsigned NumClauses);
+
+ /// Gets/sets the associated loops after the transformation, i.e. after
+ /// de-sugaring.
+ Stmt *getTransformedStmt() const {
+ return Data->getChildren()[TransformedStmtOffset];
+ }
+
+ /// Return preinits statement.
+ Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; }
+
+ static bool classof(const Stmt *T) {
+ return T->getStmtClass() == OMPReverseDirectiveClass;
+ }
+};
+
+/// Represents the '#pragma omp interchange' loop transformation directive.
+///
+/// \code{c}
+/// #pragma omp interchange
+/// for (int i = 0; i < m; ++i)
+/// for (int j = 0; j < n; ++j)
+/// ..
+/// \endcode
+class OMPInterchangeDirective final : public OMPLoopTransformationDirective {
+ friend class ASTStmtReader;
+ friend class OMPExecutableDirective;
+
+ /// Offsets of child members.
+ enum {
+ PreInitsOffset = 0,
+ TransformedStmtOffset,
+ };
+
+ explicit OMPInterchangeDirective(SourceLocation StartLoc,
+ SourceLocation EndLoc, unsigned NumLoops)
+ : OMPLoopTransformationDirective(OMPInterchangeDirectiveClass,
+ llvm::omp::OMPD_interchange, StartLoc,
+ EndLoc, NumLoops) {
+ setNumGeneratedLoops(3 * NumLoops);
+ }
+
+ void setPreInits(Stmt *PreInits) {
+ Data->getChildren()[PreInitsOffset] = PreInits;
+ }
+
+ void setTransformedStmt(Stmt *S) {
+ Data->getChildren()[TransformedStmtOffset] = S;
+ }
+
+public:
+ /// Create a new AST node representation for '#pragma omp interchange'.
+ ///
+ /// \param C Context of the AST.
+ /// \param StartLoc Location of the introducer (e.g. the 'omp' token).
+ /// \param EndLoc Location of the directive's end (e.g. the tok::eod).
+ /// \param Clauses The directive's clauses.
+ /// \param NumLoops Number of affected loops
+ /// (number of items in the 'permutation' clause if present).
+ /// \param AssociatedStmt The outermost associated loop.
+ /// \param TransformedStmt The loop nest after tiling, or nullptr in
+ /// dependent contexts.
+ /// \param PreInits Helper preinits statements for the loop nest.
+ static OMPInterchangeDirective *
+ Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+ ArrayRef<OMPClause *> Clauses, unsigned NumLoops, Stmt *AssociatedStmt,
+ Stmt *TransformedStmt, Stmt *PreInits);
+
+ /// Build an empty '#pragma omp interchange' AST node for deserialization.
+ ///
+ /// \param C Context of the AST.
+ /// \param NumClauses Number of clauses to allocate.
+ /// \param NumLoops Number of associated loops to allocate.
+ static OMPInterchangeDirective *
+ CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned NumLoops);
+
+ /// Gets the associated loops after the transformation. This is the de-sugared
+ /// replacement or nullptr in dependent contexts.
+ Stmt *getTransformedStmt() const {
+ return Data->getChildren()[TransformedStmtOffset];
+ }
+
+ /// Return preinits statement.
+ Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; }
+
+ static bool classof(const Stmt *T) {
+ return T->getStmtClass() == OMPInterchangeDirectiveClass;
+ }
+};
+
/// This represents the '#pragma omp unroll' loop transformation directive.
///
/// \code
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 9e82130c93609..690912060f723 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11494,6 +11494,10 @@ def err_omp_dispatch_statement_call
" to a target function or an assignment to one">;
def err_omp_unroll_full_variable_trip_count : Error<
"loop to be fully unrolled must have a constant trip count">;
+def err_omp_interchange_permutation_value_range : Error<
+ "permutation index must be at least 1 and at most %0">;
+def err_omp_interchange_permutation_value_repeated : Error<
+ "index %0 must appear exactly once in the permutation clause">;
def note_omp_directive_here : Note<"'%0' directive found here">;
def err_omp_instantiation_not_supported
: Error<"instantiation of '%0' not supported yet">;
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index 305f19daa4a92..b445ea225eac5 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -229,6 +229,8 @@ def OMPSimdDirective : StmtNode<OMPLoopDirective>;
def OMPLoopTransformationDirective : StmtNode<OMPLoopBasedDirective, 1>;
def OMPTileDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>;
+def OMPReverseDirective : StmtNode<OMPLoopTransformationDirective>;
+def OMPInterchangeDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPForDirective : StmtNode<OMPLoopDirective>;
def OMPForSimdDirective : StmtNode<OMPLoopDirective>;
def OMPSectionsDirective : StmtNode<OMPExecutableDirective>;
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 61589fb7766f4..aea53673721bd 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3537,6 +3537,9 @@ class Parser : public CodeCompletionHandler {
/// Parses the 'sizes' clause of a '#pragma omp tile' directive.
OMPClause *ParseOpenMPSizesClause();
+ /// Parses the 'permutation' clause of a '#pragma omp interchange' directive.
+ OMPClause *ParseOpenMPPermutationClause();
+
/// Parses clause without any additional arguments.
///
/// \param Kind Kind of current clause.
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 51981e1c9a8b9..fe9d702b7af03 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -422,6 +422,17 @@ class SemaOpenMP : public SemaBase {
StmtResult ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
Stmt *AStmt, SourceLocation StartLoc,
SourceLocation EndLoc);
+ /// Called on well-formed '#pragma omp reverse' after parsing of its clauses
+ /// and the associated statement.
+ StmtResult ActOnOpenMPReverseDirective(ArrayRef<OMPClause *> Clauses,
+ Stmt *AStmt, SourceLocation StartLoc,
+ SourceLocation EndLoc);
+ /// Called on well-formed '#pragma omp interchange' after parsing of its
+ /// clauses and the associated statement.
+ StmtResult ActOnOpenMPInterchangeDirective(ArrayRef<OMPClause *> Clauses,
+ Stmt *AStmt,
+ SourceLocation StartLoc,
+ SourceLocation EndLoc);
/// Called on well-formed '\#pragma omp for' after parsing
/// of the associated statement.
StmtResult
@@ -859,6 +870,11 @@ class SemaOpenMP : public SemaBase {
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
+ /// Called on well-form 'permutation' clause after parsing its arguments.
+ OMPClause *ActOnOpenMPPermutationClause(ArrayRef<Expr *> PermExprs,
+ SourceLocation StartLoc,
+ SourceLocation LParenLoc,
+ SourceLocation EndLoc);
/// Called on well-form 'full' clauses.
OMPClause *ActOnOpenMPFullClause(SourceLocation StartLoc,
SourceLocation EndLoc);
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index d3538e43d3d78..f488d1eca35c2 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1856,6 +1856,8 @@ enum StmtCode {
STMT_OMP_SIMD_DIRECTIVE,
STMT_OMP_TILE_DIRECTIVE,
STMT_OMP_UNROLL_DIRECTIVE,
+ STMT_OMP_REVERSE_DIRECTIVE,
+ STMT_OMP_INTERCHANGE_DIRECTIVE,
STMT_OMP_FOR_DIRECTIVE,
STMT_OMP_FOR_SIMD_DIRECTIVE,
STMT_OMP_SECTIONS_DIRECTIVE,
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 042a5df5906ca..c94310769ed9e 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -971,6 +971,25 @@ OMPSizesClause *OMPSizesClause::CreateEmpty(const ASTContext &C,
return new (Mem) OMPSizesClause(NumSizes);
}
+OMPPermutationClause *OMPPermutationClause::Create(const ASTContext &C,
+ SourceLocation StartLoc,
+ SourceLocation LParenLoc,
+ SourceLocation EndLoc,
+ ArrayRef<Expr *> Args) {
+ OMPPermutationClause *Clause = CreateEmpty(C, Args.size());
+ Clause->setLocStart(StartLoc);
+ Clause->setLParenLoc(LParenLoc);
+ Clause->setLocEnd(EndLoc);
+ Clause->setArgRefs(Args);
+ return Clause;
+}
+
+OMPPermutationClause *OMPPermutationClause::CreateEmpty(const ASTContext &C,
+ unsigned NumLoops) {
+ void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(NumLoops));
+ return new (Mem) OMPPermutationClause(NumLoops);
+}
+
OMPFullClause *OMPFullClause::Create(const ASTContext &C,
SourceLocation StartLoc,
SourceLocation EndLoc) {
@@ -1774,6 +1793,18 @@ void OMPClausePrinter::VisitOMPSizesClause(OMPSizesClause *Node) {
OS << ")";
}
+void OMPClausePrinter::VisitOMPPermutationClause(OMPPermutationClause *Node) {
+ OS << "permutation(";
+ bool First = true;
+ for (Expr *Size : Node->getArgsRefs()) {
+ if (!First)
+ OS << ", ";
+ Size->printPretty(OS, nullptr, Policy, 0);
+ First = false;
+ }
+ OS << ")";
+}
+
void OMPClausePrinter::VisitOMPFullClause(OMPFullClause *Node) { OS << "full"; }
void OMPClausePrinter::VisitOMPPartialClause(OMPPartialClause *Node) {
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index d8519b2071e6d..8a4a736263920 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -449,6 +449,45 @@ OMPUnrollDirective *OMPUnrollDirective::CreateEmpty(const ASTContext &C,
SourceLocation(), SourceLocation());
}
+OMPReverseDirective *
+OMPReverseDirective::Create(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation EndLoc,
+ ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
+ Stmt *TransformedStmt, Stmt *PreInits) {
+ OMPReverseDirective *Dir = createDirective<OMPReverseDirective>(
+ C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc);
+ Dir->setTransformedStmt(TransformedStmt);
+ Dir->setPreInits(PreInits);
+ return Dir;
+}
+
+OMPReverseDirective *OMPReverseDirective::CreateEmpty(const ASTContext &C,
+ unsigned NumClauses) {
+ return createEmptyDirective<OMPReverseDirective>(
+ C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1,
+ SourceLocation(), SourceLocation());
+}
+
+OMPInterchangeDirective *OMPInterchangeDirective::Create(
+ const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+ ArrayRef<OMPClause *> Clauses, unsigned NumLoops, Stmt *AssociatedStmt,
+ Stmt *TransformedStmt, Stmt *PreInits) {
+ OMPInterchangeDirective *Dir = createDirective<OMPInterchangeDirective>(
+ C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLo...
[truncated]
|
b8b1260
to
d6057c4
Compare
ffc8547
to
5ac2385
Compare
Extracted out the reverse part into #92916 |
5ac2385
to
fa15956
Compare
Add the permutation clause for the interchange directive which will be introduced in the upcoming OpenMP 6.0 specification. A preview has been published in Technical Report 12.