Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenACC] Implement 'num_gangs' sema for compute constructs #89460

Merged
merged 2 commits into from
Apr 22, 2024

Conversation

erichkeane
Copy link
Collaborator

num_gangs takes an 'int-expr-list', for 'parallel', and an 'int-expr' for 'kernels'. This patch changes the parsing to always parse it as an 'int-expr-list', then correct the expression count during Sema. It also implements the rest of the semantic analysis changes for this clause.

num_gangs takes an 'int-expr-list', for 'parallel', and an 'int-expr' for
'kernels'.  This patch changes the parsing to always parse it as an
'int-expr-list', then correct the expression count during Sema.  It also
implements the rest of the semantic analysis changes for this clause.
@erichkeane erichkeane added the clang:openacc Clang OpenACC Implementation label Apr 19, 2024
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules labels Apr 19, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 19, 2024

@llvm/pr-subscribers-clang-modules

Author: Erich Keane (erichkeane)

Changes

num_gangs takes an 'int-expr-list', for 'parallel', and an 'int-expr' for 'kernels'. This patch changes the parsing to always parse it as an 'int-expr-list', then correct the expression count during Sema. It also implements the rest of the semantic analysis changes for this clause.


Patch is 37.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89460.diff

18 Files Affected:

  • (modified) clang/include/clang/AST/OpenACCClause.h (+71-16)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+5)
  • (modified) clang/include/clang/Basic/OpenACCClauses.def (+1)
  • (modified) clang/include/clang/Parse/Parser.h (+14-2)
  • (modified) clang/include/clang/Sema/SemaOpenACC.h (+13-3)
  • (modified) clang/lib/AST/OpenACCClause.cpp (+16)
  • (modified) clang/lib/AST/StmtProfile.cpp (+6)
  • (modified) clang/lib/AST/TextNodeDumper.cpp (+1)
  • (modified) clang/lib/Parse/ParseOpenACC.cpp (+70-17)
  • (modified) clang/lib/Sema/SemaOpenACC.cpp (+35)
  • (modified) clang/lib/Sema/TreeTransform.h (+26)
  • (modified) clang/lib/Serialization/ASTReader.cpp (+9-1)
  • (modified) clang/lib/Serialization/ASTWriter.cpp (+8-1)
  • (modified) clang/test/ParserOpenACC/parse-clauses.c (-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-intexpr-clause-ast.cpp (+79)
  • (added) clang/test/SemaOpenACC/compute-construct-num_gangs-clause.c (+54)
  • (added) clang/test/SemaOpenACC/compute-construct-num_gangs-clause.cpp (+160)
  • (modified) clang/tools/libclang/CIndex.cpp (+4)
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index 8b6d3221aa066b..277a351c49fcb8 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -156,33 +156,88 @@ class OpenACCSelfClause : public OpenACCClauseWithCondition {
                                    Expr *ConditionExpr, SourceLocation EndLoc);
 };
 
-/// Represents oen of a handful of classes that have a single integer
+/// Represents a clause that has one or more IntExprs.  It does not own the
+/// IntExprs, but provides 'children' and other accessors.
+class OpenACCClauseWithIntExprs : public OpenACCClauseWithParams {
+  MutableArrayRef<Expr *> IntExprs;
+
+protected:
+  OpenACCClauseWithIntExprs(OpenACCClauseKind K, SourceLocation BeginLoc,
+                            SourceLocation LParenLoc, SourceLocation EndLoc)
+      : OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc) {}
+
+  /// Used only for initialization, the leaf class can initialize this to
+  /// trailing storage.
+  void setIntExprs(MutableArrayRef<Expr *> NewIntExprs) {
+    assert(IntExprs.empty() && "Cannot change IntExprs list");
+    IntExprs = NewIntExprs;
+  }
+
+  /// Gets the entire list of integer expressions, but leave it to the
+  /// individual clauses to expose this how they'd like.
+  llvm::ArrayRef<Expr *> getIntExprs() const { return IntExprs; }
+
+public:
+  child_range children() {
+    return child_range(reinterpret_cast<Stmt **>(IntExprs.begin()),
+                       reinterpret_cast<Stmt **>(IntExprs.end()));
+  }
+
+  const_child_range children() const {
+    child_range Children =
+        const_cast<OpenACCClauseWithIntExprs *>(this)->children();
+    return const_child_range(Children.begin(), Children.end());
+  }
+};
+
+class OpenACCNumGangsClause final
+    : public OpenACCClauseWithIntExprs,
+      public llvm::TrailingObjects<OpenACCNumGangsClause, Expr *> {
+
+  OpenACCNumGangsClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+                        ArrayRef<Expr *> IntExprs, SourceLocation EndLoc)
+      : OpenACCClauseWithIntExprs(OpenACCClauseKind::NumGangs, BeginLoc,
+                                  LParenLoc, EndLoc) {
+    std::uninitialized_copy(IntExprs.begin(), IntExprs.end(),
+                            getTrailingObjects<Expr *>());
+    setIntExprs(MutableArrayRef(getTrailingObjects<Expr *>(), IntExprs.size()));
+  }
+
+public:
+  static OpenACCNumGangsClause *
+  Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
+         ArrayRef<Expr *> IntExprs, SourceLocation EndLoc);
+
+  llvm::ArrayRef<Expr *> getIntExprs() {
+    return OpenACCClauseWithIntExprs::getIntExprs();
+  }
+
+  llvm::ArrayRef<Expr *> getIntExprs() const {
+    return OpenACCClauseWithIntExprs::getIntExprs();
+  }
+};
+
+/// Represents one of a handful of clauses that have a single integer
 /// expression.
-class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithParams {
+class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithIntExprs {
   Expr *IntExpr;
 
 protected:
   OpenACCClauseWithSingleIntExpr(OpenACCClauseKind K, SourceLocation BeginLoc,
                                  SourceLocation LParenLoc, Expr *IntExpr,
                                  SourceLocation EndLoc)
-      : OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
-        IntExpr(IntExpr) {}
+      : OpenACCClauseWithIntExprs(K, BeginLoc, LParenLoc, EndLoc),
+        IntExpr(IntExpr) {
+    setIntExprs(MutableArrayRef<Expr *>{&this->IntExpr, 1});
+  }
 
 public:
-  bool hasIntExpr() const { return IntExpr; }
-  const Expr *getIntExpr() const { return IntExpr; }
-
-  Expr *getIntExpr() { return IntExpr; };
-
-  child_range children() {
-    return child_range(reinterpret_cast<Stmt **>(&IntExpr),
-                       reinterpret_cast<Stmt **>(&IntExpr + 1));
+  bool hasIntExpr() const { return !getIntExprs().empty(); }
+  const Expr *getIntExpr() const {
+    return hasIntExpr() ? getIntExprs()[0] : nullptr;
   }
 
-  const_child_range children() const {
-    return const_child_range(reinterpret_cast<Stmt *const *>(&IntExpr),
-                             reinterpret_cast<Stmt *const *>(&IntExpr + 1));
-  }
+  Expr *getIntExpr() { return hasIntExpr() ? getIntExprs()[0] : nullptr; };
 };
 
 class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 1bbe76ff6bd2ac..a95424862e63f4 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12293,4 +12293,9 @@ def note_acc_int_expr_conversion
     : Note<"conversion to %select{integral|enumeration}0 type %1">;
 def err_acc_int_expr_multiple_conversions
     : Error<"multiple conversions from expression type %0 to an integral type">;
+def err_acc_num_gangs_num_args
+    : Error<"%select{no|too many}0 integer expression arguments provided to "
+            "OpenACC 'num_gangs' "
+            "%select{|clause: '%1' directive expects maximum of %2, %3 were "
+            "provided}0">;
 } // end of sema component.
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
index 520e068f4ffd40..dd5792e7ca8c39 100644
--- a/clang/include/clang/Basic/OpenACCClauses.def
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -18,6 +18,7 @@
 VISIT_CLAUSE(Default)
 VISIT_CLAUSE(If)
 VISIT_CLAUSE(Self)
+VISIT_CLAUSE(NumGangs)
 VISIT_CLAUSE(NumWorkers)
 VISIT_CLAUSE(VectorLength)
 
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 72b2f958a5e622..7e95a52e6f34d5 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3644,10 +3644,22 @@ class Parser : public CodeCompletionHandler {
   /// Parses the clause of the 'bind' argument, which can be a string literal or
   /// an ID expression.
   ExprResult ParseOpenACCBindClauseArgument();
+
+  /// A type to represent the state of parsing after an attempt to parse an
+  /// OpenACC int-expr. This is useful to determine whether an int-expr list can
+  /// continue parsing after a failed int-expr.
+  using OpenACCIntExprParseResult =
+      std::pair<ExprResult, OpenACCParseCanContinue>;
   /// Parses the clause kind of 'int-expr', which can be any integral
   /// expression.
-  ExprResult ParseOpenACCIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
-                                 SourceLocation Loc);
+  OpenACCIntExprParseResult ParseOpenACCIntExpr(OpenACCDirectiveKind DK,
+                                                OpenACCClauseKind CK,
+                                                SourceLocation Loc);
+  /// Parses the argument list for 'num_gangs', which allows up to 3
+  /// 'int-expr's.
+  bool ParseOpenACCIntExprList(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
+                               SourceLocation Loc,
+                               llvm::SmallVector<Expr *> &IntExprs);
   /// Parses the 'device-type-list', which is a list of identifiers.
   bool ParseOpenACCDeviceTypeList();
   /// Parses the 'async-argument', which is an integral value with two
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index 023722049732af..ea28617f79b81b 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -93,14 +93,16 @@ class SemaOpenACC : public SemaBase {
     }
 
     unsigned getNumIntExprs() const {
-      assert((ClauseKind == OpenACCClauseKind::NumWorkers ||
+      assert((ClauseKind == OpenACCClauseKind::NumGangs ||
+              ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
       return std::get<IntExprDetails>(Details).IntExprs.size();
     }
 
     ArrayRef<Expr *> getIntExprs() {
-      assert((ClauseKind == OpenACCClauseKind::NumWorkers ||
+      assert((ClauseKind == OpenACCClauseKind::NumGangs ||
+              ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
       return std::get<IntExprDetails>(Details).IntExprs;
@@ -134,11 +136,19 @@ class SemaOpenACC : public SemaBase {
     }
 
     void setIntExprDetails(ArrayRef<Expr *> IntExprs) {
-      assert((ClauseKind == OpenACCClauseKind::NumWorkers ||
+      assert((ClauseKind == OpenACCClauseKind::NumGangs ||
+              ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
       Details = IntExprDetails{{IntExprs.begin(), IntExprs.end()}};
     }
+    void setIntExprDetails(llvm::SmallVector<Expr *> &&IntExprs) {
+      assert((ClauseKind == OpenACCClauseKind::NumGangs ||
+              ClauseKind == OpenACCClauseKind::NumWorkers ||
+              ClauseKind == OpenACCClauseKind::VectorLength) &&
+             "Parsed clause kind does not have a int exprs");
+      Details = IntExprDetails{IntExprs};
+    }
   };
 
   SemaOpenACC(Sema &S);
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index 75334223e073c3..6cd5b28802187d 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -124,6 +124,16 @@ OpenACCVectorLengthClause::Create(const ASTContext &C, SourceLocation BeginLoc,
       OpenACCVectorLengthClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
 }
 
+OpenACCNumGangsClause *OpenACCNumGangsClause::Create(const ASTContext &C,
+                                                     SourceLocation BeginLoc,
+                                                     SourceLocation LParenLoc,
+                                                     ArrayRef<Expr *> IntExprs,
+                                                     SourceLocation EndLoc) {
+  void *Mem = C.Allocate(
+      OpenACCNumGangsClause::totalSizeToAlloc<Expr *>(IntExprs.size()));
+  return new (Mem) OpenACCNumGangsClause(BeginLoc, LParenLoc, IntExprs, EndLoc);
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenACC clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -141,6 +151,12 @@ void OpenACCClausePrinter::VisitSelfClause(const OpenACCSelfClause &C) {
     OS << "(" << CondExpr << ")";
 }
 
+void OpenACCClausePrinter::VisitNumGangsClause(const OpenACCNumGangsClause &C) {
+  OS << "num_gangs(";
+  llvm::interleaveComma(C.getIntExprs(), OS);
+  OS << ")";
+}
+
 void OpenACCClausePrinter::VisitNumWorkersClause(
     const OpenACCNumWorkersClause &C) {
   OS << "num_workers(" << C.getIntExpr() << ")";
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 8138ae3a244a8b..c81724f84dd9ce 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2497,6 +2497,12 @@ void OpenACCClauseProfiler::VisitSelfClause(const OpenACCSelfClause &Clause) {
     Profiler.VisitStmt(Clause.getConditionExpr());
 }
 
+void OpenACCClauseProfiler::VisitNumGangsClause(
+    const OpenACCNumGangsClause &Clause) {
+  for (auto *E : Clause.getIntExprs())
+    Profiler.VisitStmt(E);
+}
+
 void OpenACCClauseProfiler::VisitNumWorkersClause(
     const OpenACCNumWorkersClause &Clause) {
   assert(Clause.hasIntExpr() && "num_workers clause requires a valid int expr");
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index e5a8b285715b30..8f0a9a9b0ed0bc 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -399,6 +399,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
       break;
     case OpenACCClauseKind::If:
     case OpenACCClauseKind::Self:
+    case OpenACCClauseKind::NumGangs:
     case OpenACCClauseKind::NumWorkers:
     case OpenACCClauseKind::VectorLength:
       // The condition expression will be printed as a part of the 'children',
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index 757417f75c9636..9b7cd1668f7c17 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -632,16 +632,54 @@ Parser::ParseOpenACCClauseList(OpenACCDirectiveKind DirKind) {
   return Clauses;
 }
 
-ExprResult Parser::ParseOpenACCIntExpr(OpenACCDirectiveKind DK,
-                                       OpenACCClauseKind CK,
-                                       SourceLocation Loc) {
-  ExprResult ER =
-      getActions().CorrectDelayedTyposInExpr(ParseAssignmentExpression());
+Parser::OpenACCIntExprParseResult
+Parser::ParseOpenACCIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
+                            SourceLocation Loc) {
+  ExprResult ER = ParseAssignmentExpression();
 
+  // If the actual parsing failed, we don't know the state of the parse, so
+  // don't try to continue.
   if (!ER.isUsable())
-    return ER;
+    return {ER, OpenACCParseCanContinue::Cannot};
+
+  // Parsing can continue after the initial assignment expression parsing, so
+  // even if there was a typo, we can continue.
+  ER = getActions().CorrectDelayedTyposInExpr(ER);
+  if (!ER.isUsable())
+    return {ER, OpenACCParseCanContinue::Can};
+
+  return {getActions().OpenACC().ActOnIntExpr(DK, CK, Loc, ER.get()),
+          OpenACCParseCanContinue::Can};
+}
+
+bool Parser::ParseOpenACCIntExprList(OpenACCDirectiveKind DK,
+                                     OpenACCClauseKind CK, SourceLocation Loc,
+                                     llvm::SmallVector<Expr *> &IntExprs) {
+  OpenACCIntExprParseResult CurResult = ParseOpenACCIntExpr(DK, CK, Loc);
+
+  if (!CurResult.first.isUsable() &&
+      CurResult.second == OpenACCParseCanContinue::Cannot) {
+    SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
+              Parser::StopBeforeMatch);
+    return true;
+  }
+
+  IntExprs.push_back(CurResult.first.get());
+
+  while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
+    ExpectAndConsume(tok::comma);
+
+    CurResult = ParseOpenACCIntExpr(DK, CK, Loc);
 
-  return getActions().OpenACC().ActOnIntExpr(DK, CK, Loc, ER.get());
+    if (!CurResult.first.isUsable() &&
+        CurResult.second == OpenACCParseCanContinue::Cannot) {
+      SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
+                Parser::StopBeforeMatch);
+      return true;
+    }
+    IntExprs.push_back(CurResult.first.get());
+  }
+  return false;
 }
 
 bool Parser::ParseOpenACCClauseVarList(OpenACCClauseKind Kind) {
@@ -761,7 +799,7 @@ bool Parser::ParseOpenACCGangArg(SourceLocation GangLoc) {
     ConsumeToken();
     return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
                                OpenACCClauseKind::Gang, GangLoc)
-        .isInvalid();
+        .first.isInvalid();
   }
 
   if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::Num, getCurToken()) &&
@@ -773,7 +811,7 @@ bool Parser::ParseOpenACCGangArg(SourceLocation GangLoc) {
   // This is just the 'num' case where 'num' is optional.
   return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
                              OpenACCClauseKind::Gang, GangLoc)
-      .isInvalid();
+      .first.isInvalid();
 }
 
 bool Parser::ParseOpenACCGangArgList(SourceLocation GangLoc) {
@@ -946,13 +984,25 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
       }
       break;
     }
-    case OpenACCClauseKind::NumGangs:
+    case OpenACCClauseKind::NumGangs: {
+      llvm::SmallVector<Expr *> IntExprs;
+
+      if (ParseOpenACCIntExprList(OpenACCDirectiveKind::Invalid,
+                                  OpenACCClauseKind::NumGangs, ClauseLoc,
+                                  IntExprs)) {
+        Parens.skipToEnd();
+        return OpenACCCanContinue();
+      }
+      ParsedClause.setIntExprDetails(std::move(IntExprs));
+      break;
+    }
     case OpenACCClauseKind::NumWorkers:
     case OpenACCClauseKind::DeviceNum:
     case OpenACCClauseKind::DefaultAsync:
     case OpenACCClauseKind::VectorLength: {
       ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
-                                               ClauseKind, ClauseLoc);
+                                               ClauseKind, ClauseLoc)
+                               .first;
       if (IntExpr.isInvalid()) {
         Parens.skipToEnd();
         return OpenACCCanContinue();
@@ -1017,7 +1067,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
                                                : OpenACCSpecialTokenKind::Num,
                                            ClauseKind);
         ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
-                                                 ClauseKind, ClauseLoc);
+                                                 ClauseKind, ClauseLoc)
+                                 .first;
         if (IntExpr.isInvalid()) {
           Parens.skipToEnd();
           return OpenACCCanContinue();
@@ -1081,11 +1132,13 @@ bool Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
     // Consume colon.
     ConsumeToken();
 
-    ExprResult IntExpr = ParseOpenACCIntExpr(
-        IsDirective ? OpenACCDirectiveKind::Wait
-                    : OpenACCDirectiveKind::Invalid,
-        IsDirective ? OpenACCClauseKind::Invalid : OpenACCClauseKind::Wait,
-        Loc);
+    ExprResult IntExpr =
+        ParseOpenACCIntExpr(IsDirective ? OpenACCDirectiveKind::Wait
+                                        : OpenACCDirectiveKind::Invalid,
+                            IsDirective ? OpenACCClauseKind::Invalid
+                                        : OpenACCClauseKind::Wait,
+                            Loc)
+            .first;
     if (IntExpr.isInvalid())
       return true;
 
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index 190739fa02e932..ba69e71e30a181 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -91,6 +91,7 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
     default:
       return false;
     }
+  case OpenACCClauseKind::NumGangs:
   case OpenACCClauseKind::NumWorkers:
   case OpenACCClauseKind::VectorLength:
     switch (DirectiveKind) {
@@ -230,6 +231,40 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
         getASTContext(), Clause.getBeginLoc(), Clause.getLParenLoc(),
         Clause.getConditionExpr(), Clause.getEndLoc());
   }
+  case OpenACCClauseKind::NumGangs: {
+    // Restrictions only properly implemented on 'compute' constructs, and
+    // 'compute' constructs are the only construct that can do anything with
+    // this yet, so skip/treat as unimplemented in this case.
+    if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
+      break;
+
+    // There is no prose in the standard that says duplicates aren't allowed,
+    // but this diagnostic is present in other compilers, as well as makes
+    // sense.
+    if (checkAlreadyHasClauseOfKind(*this, ExistingClauses, Clause))
+      return nullptr;
+
+    if (Clause.getIntExprs().empty())
+      Diag(Clause.getBeginLoc(), diag::err_acc_num_gangs_num_args)
+          << /*NoArgs=*/0;
+
+    unsigned MaxArgs =
+        (Clause.getDirectiveKind() == OpenACCDirectiveKind::Parallel ||
+         Clause.getDirectiveKind() == OpenACCDirectiveKind::ParallelLoop)
+            ? 3
+            : 1;
+    if (Clause.getIntExprs().size() > MaxArgs)
+      Diag(Clause.getBeginLoc(), diag::err_acc_num_gangs_num_args)
+          << /*NoArgs=*/1 << Clause.getDirectiveKind() << MaxArgs
+          << Clause.getIntExprs().size();
+
+    // Create the AST node for the clause even if the number of expressions is
+    // incorrect.
+    return OpenACCNumGangsClause::Create(
+        getASTContext(), Clause.getBeginLoc(), Clause.getLParenLoc(),
+        Clause.getIntExprs(), Clause.getEndLoc());
+    break;
+  }
   case OpenACCClauseKind::NumWorkers: {
     // Restrictions only properly implemented on 'compute' constructs, and
     // 'compute' constructs are the only construct that can do anything with
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index ade33ec65038fd..10ce226...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 19, 2024

@llvm/pr-subscribers-clang

Author: Erich Keane (erichkeane)

Changes

num_gangs takes an 'int-expr-list', for 'parallel', and an 'int-expr' for 'kernels'. This patch changes the parsing to always parse it as an 'int-expr-list', then correct the expression count during Sema. It also implements the rest of the semantic analysis changes for this clause.


Patch is 37.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89460.diff

18 Files Affected:

  • (modified) clang/include/clang/AST/OpenACCClause.h (+71-16)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+5)
  • (modified) clang/include/clang/Basic/OpenACCClauses.def (+1)
  • (modified) clang/include/clang/Parse/Parser.h (+14-2)
  • (modified) clang/include/clang/Sema/SemaOpenACC.h (+13-3)
  • (modified) clang/lib/AST/OpenACCClause.cpp (+16)
  • (modified) clang/lib/AST/StmtProfile.cpp (+6)
  • (modified) clang/lib/AST/TextNodeDumper.cpp (+1)
  • (modified) clang/lib/Parse/ParseOpenACC.cpp (+70-17)
  • (modified) clang/lib/Sema/SemaOpenACC.cpp (+35)
  • (modified) clang/lib/Sema/TreeTransform.h (+26)
  • (modified) clang/lib/Serialization/ASTReader.cpp (+9-1)
  • (modified) clang/lib/Serialization/ASTWriter.cpp (+8-1)
  • (modified) clang/test/ParserOpenACC/parse-clauses.c (-4)
  • (modified) clang/test/SemaOpenACC/compute-construct-intexpr-clause-ast.cpp (+79)
  • (added) clang/test/SemaOpenACC/compute-construct-num_gangs-clause.c (+54)
  • (added) clang/test/SemaOpenACC/compute-construct-num_gangs-clause.cpp (+160)
  • (modified) clang/tools/libclang/CIndex.cpp (+4)
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index 8b6d3221aa066b..277a351c49fcb8 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -156,33 +156,88 @@ class OpenACCSelfClause : public OpenACCClauseWithCondition {
                                    Expr *ConditionExpr, SourceLocation EndLoc);
 };
 
-/// Represents oen of a handful of classes that have a single integer
+/// Represents a clause that has one or more IntExprs.  It does not own the
+/// IntExprs, but provides 'children' and other accessors.
+class OpenACCClauseWithIntExprs : public OpenACCClauseWithParams {
+  MutableArrayRef<Expr *> IntExprs;
+
+protected:
+  OpenACCClauseWithIntExprs(OpenACCClauseKind K, SourceLocation BeginLoc,
+                            SourceLocation LParenLoc, SourceLocation EndLoc)
+      : OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc) {}
+
+  /// Used only for initialization, the leaf class can initialize this to
+  /// trailing storage.
+  void setIntExprs(MutableArrayRef<Expr *> NewIntExprs) {
+    assert(IntExprs.empty() && "Cannot change IntExprs list");
+    IntExprs = NewIntExprs;
+  }
+
+  /// Gets the entire list of integer expressions, but leave it to the
+  /// individual clauses to expose this how they'd like.
+  llvm::ArrayRef<Expr *> getIntExprs() const { return IntExprs; }
+
+public:
+  child_range children() {
+    return child_range(reinterpret_cast<Stmt **>(IntExprs.begin()),
+                       reinterpret_cast<Stmt **>(IntExprs.end()));
+  }
+
+  const_child_range children() const {
+    child_range Children =
+        const_cast<OpenACCClauseWithIntExprs *>(this)->children();
+    return const_child_range(Children.begin(), Children.end());
+  }
+};
+
+class OpenACCNumGangsClause final
+    : public OpenACCClauseWithIntExprs,
+      public llvm::TrailingObjects<OpenACCNumGangsClause, Expr *> {
+
+  OpenACCNumGangsClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+                        ArrayRef<Expr *> IntExprs, SourceLocation EndLoc)
+      : OpenACCClauseWithIntExprs(OpenACCClauseKind::NumGangs, BeginLoc,
+                                  LParenLoc, EndLoc) {
+    std::uninitialized_copy(IntExprs.begin(), IntExprs.end(),
+                            getTrailingObjects<Expr *>());
+    setIntExprs(MutableArrayRef(getTrailingObjects<Expr *>(), IntExprs.size()));
+  }
+
+public:
+  static OpenACCNumGangsClause *
+  Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
+         ArrayRef<Expr *> IntExprs, SourceLocation EndLoc);
+
+  llvm::ArrayRef<Expr *> getIntExprs() {
+    return OpenACCClauseWithIntExprs::getIntExprs();
+  }
+
+  llvm::ArrayRef<Expr *> getIntExprs() const {
+    return OpenACCClauseWithIntExprs::getIntExprs();
+  }
+};
+
+/// Represents one of a handful of clauses that have a single integer
 /// expression.
-class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithParams {
+class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithIntExprs {
   Expr *IntExpr;
 
 protected:
   OpenACCClauseWithSingleIntExpr(OpenACCClauseKind K, SourceLocation BeginLoc,
                                  SourceLocation LParenLoc, Expr *IntExpr,
                                  SourceLocation EndLoc)
-      : OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
-        IntExpr(IntExpr) {}
+      : OpenACCClauseWithIntExprs(K, BeginLoc, LParenLoc, EndLoc),
+        IntExpr(IntExpr) {
+    setIntExprs(MutableArrayRef<Expr *>{&this->IntExpr, 1});
+  }
 
 public:
-  bool hasIntExpr() const { return IntExpr; }
-  const Expr *getIntExpr() const { return IntExpr; }
-
-  Expr *getIntExpr() { return IntExpr; };
-
-  child_range children() {
-    return child_range(reinterpret_cast<Stmt **>(&IntExpr),
-                       reinterpret_cast<Stmt **>(&IntExpr + 1));
+  bool hasIntExpr() const { return !getIntExprs().empty(); }
+  const Expr *getIntExpr() const {
+    return hasIntExpr() ? getIntExprs()[0] : nullptr;
   }
 
-  const_child_range children() const {
-    return const_child_range(reinterpret_cast<Stmt *const *>(&IntExpr),
-                             reinterpret_cast<Stmt *const *>(&IntExpr + 1));
-  }
+  Expr *getIntExpr() { return hasIntExpr() ? getIntExprs()[0] : nullptr; };
 };
 
 class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 1bbe76ff6bd2ac..a95424862e63f4 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12293,4 +12293,9 @@ def note_acc_int_expr_conversion
     : Note<"conversion to %select{integral|enumeration}0 type %1">;
 def err_acc_int_expr_multiple_conversions
     : Error<"multiple conversions from expression type %0 to an integral type">;
+def err_acc_num_gangs_num_args
+    : Error<"%select{no|too many}0 integer expression arguments provided to "
+            "OpenACC 'num_gangs' "
+            "%select{|clause: '%1' directive expects maximum of %2, %3 were "
+            "provided}0">;
 } // end of sema component.
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
index 520e068f4ffd40..dd5792e7ca8c39 100644
--- a/clang/include/clang/Basic/OpenACCClauses.def
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -18,6 +18,7 @@
 VISIT_CLAUSE(Default)
 VISIT_CLAUSE(If)
 VISIT_CLAUSE(Self)
+VISIT_CLAUSE(NumGangs)
 VISIT_CLAUSE(NumWorkers)
 VISIT_CLAUSE(VectorLength)
 
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 72b2f958a5e622..7e95a52e6f34d5 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3644,10 +3644,22 @@ class Parser : public CodeCompletionHandler {
   /// Parses the clause of the 'bind' argument, which can be a string literal or
   /// an ID expression.
   ExprResult ParseOpenACCBindClauseArgument();
+
+  /// A type to represent the state of parsing after an attempt to parse an
+  /// OpenACC int-expr. This is useful to determine whether an int-expr list can
+  /// continue parsing after a failed int-expr.
+  using OpenACCIntExprParseResult =
+      std::pair<ExprResult, OpenACCParseCanContinue>;
   /// Parses the clause kind of 'int-expr', which can be any integral
   /// expression.
-  ExprResult ParseOpenACCIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
-                                 SourceLocation Loc);
+  OpenACCIntExprParseResult ParseOpenACCIntExpr(OpenACCDirectiveKind DK,
+                                                OpenACCClauseKind CK,
+                                                SourceLocation Loc);
+  /// Parses the argument list for 'num_gangs', which allows up to 3
+  /// 'int-expr's.
+  bool ParseOpenACCIntExprList(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
+                               SourceLocation Loc,
+                               llvm::SmallVector<Expr *> &IntExprs);
   /// Parses the 'device-type-list', which is a list of identifiers.
   bool ParseOpenACCDeviceTypeList();
   /// Parses the 'async-argument', which is an integral value with two
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index 023722049732af..ea28617f79b81b 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -93,14 +93,16 @@ class SemaOpenACC : public SemaBase {
     }
 
     unsigned getNumIntExprs() const {
-      assert((ClauseKind == OpenACCClauseKind::NumWorkers ||
+      assert((ClauseKind == OpenACCClauseKind::NumGangs ||
+              ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
       return std::get<IntExprDetails>(Details).IntExprs.size();
     }
 
     ArrayRef<Expr *> getIntExprs() {
-      assert((ClauseKind == OpenACCClauseKind::NumWorkers ||
+      assert((ClauseKind == OpenACCClauseKind::NumGangs ||
+              ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
       return std::get<IntExprDetails>(Details).IntExprs;
@@ -134,11 +136,19 @@ class SemaOpenACC : public SemaBase {
     }
 
     void setIntExprDetails(ArrayRef<Expr *> IntExprs) {
-      assert((ClauseKind == OpenACCClauseKind::NumWorkers ||
+      assert((ClauseKind == OpenACCClauseKind::NumGangs ||
+              ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
       Details = IntExprDetails{{IntExprs.begin(), IntExprs.end()}};
     }
+    void setIntExprDetails(llvm::SmallVector<Expr *> &&IntExprs) {
+      assert((ClauseKind == OpenACCClauseKind::NumGangs ||
+              ClauseKind == OpenACCClauseKind::NumWorkers ||
+              ClauseKind == OpenACCClauseKind::VectorLength) &&
+             "Parsed clause kind does not have a int exprs");
+      Details = IntExprDetails{IntExprs};
+    }
   };
 
   SemaOpenACC(Sema &S);
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index 75334223e073c3..6cd5b28802187d 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -124,6 +124,16 @@ OpenACCVectorLengthClause::Create(const ASTContext &C, SourceLocation BeginLoc,
       OpenACCVectorLengthClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
 }
 
+OpenACCNumGangsClause *OpenACCNumGangsClause::Create(const ASTContext &C,
+                                                     SourceLocation BeginLoc,
+                                                     SourceLocation LParenLoc,
+                                                     ArrayRef<Expr *> IntExprs,
+                                                     SourceLocation EndLoc) {
+  void *Mem = C.Allocate(
+      OpenACCNumGangsClause::totalSizeToAlloc<Expr *>(IntExprs.size()));
+  return new (Mem) OpenACCNumGangsClause(BeginLoc, LParenLoc, IntExprs, EndLoc);
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenACC clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -141,6 +151,12 @@ void OpenACCClausePrinter::VisitSelfClause(const OpenACCSelfClause &C) {
     OS << "(" << CondExpr << ")";
 }
 
+void OpenACCClausePrinter::VisitNumGangsClause(const OpenACCNumGangsClause &C) {
+  OS << "num_gangs(";
+  llvm::interleaveComma(C.getIntExprs(), OS);
+  OS << ")";
+}
+
 void OpenACCClausePrinter::VisitNumWorkersClause(
     const OpenACCNumWorkersClause &C) {
   OS << "num_workers(" << C.getIntExpr() << ")";
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 8138ae3a244a8b..c81724f84dd9ce 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2497,6 +2497,12 @@ void OpenACCClauseProfiler::VisitSelfClause(const OpenACCSelfClause &Clause) {
     Profiler.VisitStmt(Clause.getConditionExpr());
 }
 
+void OpenACCClauseProfiler::VisitNumGangsClause(
+    const OpenACCNumGangsClause &Clause) {
+  for (auto *E : Clause.getIntExprs())
+    Profiler.VisitStmt(E);
+}
+
 void OpenACCClauseProfiler::VisitNumWorkersClause(
     const OpenACCNumWorkersClause &Clause) {
   assert(Clause.hasIntExpr() && "num_workers clause requires a valid int expr");
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index e5a8b285715b30..8f0a9a9b0ed0bc 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -399,6 +399,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
       break;
     case OpenACCClauseKind::If:
     case OpenACCClauseKind::Self:
+    case OpenACCClauseKind::NumGangs:
     case OpenACCClauseKind::NumWorkers:
     case OpenACCClauseKind::VectorLength:
       // The condition expression will be printed as a part of the 'children',
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index 757417f75c9636..9b7cd1668f7c17 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -632,16 +632,54 @@ Parser::ParseOpenACCClauseList(OpenACCDirectiveKind DirKind) {
   return Clauses;
 }
 
-ExprResult Parser::ParseOpenACCIntExpr(OpenACCDirectiveKind DK,
-                                       OpenACCClauseKind CK,
-                                       SourceLocation Loc) {
-  ExprResult ER =
-      getActions().CorrectDelayedTyposInExpr(ParseAssignmentExpression());
+Parser::OpenACCIntExprParseResult
+Parser::ParseOpenACCIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
+                            SourceLocation Loc) {
+  ExprResult ER = ParseAssignmentExpression();
 
+  // If the actual parsing failed, we don't know the state of the parse, so
+  // don't try to continue.
   if (!ER.isUsable())
-    return ER;
+    return {ER, OpenACCParseCanContinue::Cannot};
+
+  // Parsing can continue after the initial assignment expression parsing, so
+  // even if there was a typo, we can continue.
+  ER = getActions().CorrectDelayedTyposInExpr(ER);
+  if (!ER.isUsable())
+    return {ER, OpenACCParseCanContinue::Can};
+
+  return {getActions().OpenACC().ActOnIntExpr(DK, CK, Loc, ER.get()),
+          OpenACCParseCanContinue::Can};
+}
+
+bool Parser::ParseOpenACCIntExprList(OpenACCDirectiveKind DK,
+                                     OpenACCClauseKind CK, SourceLocation Loc,
+                                     llvm::SmallVector<Expr *> &IntExprs) {
+  OpenACCIntExprParseResult CurResult = ParseOpenACCIntExpr(DK, CK, Loc);
+
+  if (!CurResult.first.isUsable() &&
+      CurResult.second == OpenACCParseCanContinue::Cannot) {
+    SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
+              Parser::StopBeforeMatch);
+    return true;
+  }
+
+  IntExprs.push_back(CurResult.first.get());
+
+  while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
+    ExpectAndConsume(tok::comma);
+
+    CurResult = ParseOpenACCIntExpr(DK, CK, Loc);
 
-  return getActions().OpenACC().ActOnIntExpr(DK, CK, Loc, ER.get());
+    if (!CurResult.first.isUsable() &&
+        CurResult.second == OpenACCParseCanContinue::Cannot) {
+      SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
+                Parser::StopBeforeMatch);
+      return true;
+    }
+    IntExprs.push_back(CurResult.first.get());
+  }
+  return false;
 }
 
 bool Parser::ParseOpenACCClauseVarList(OpenACCClauseKind Kind) {
@@ -761,7 +799,7 @@ bool Parser::ParseOpenACCGangArg(SourceLocation GangLoc) {
     ConsumeToken();
     return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
                                OpenACCClauseKind::Gang, GangLoc)
-        .isInvalid();
+        .first.isInvalid();
   }
 
   if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::Num, getCurToken()) &&
@@ -773,7 +811,7 @@ bool Parser::ParseOpenACCGangArg(SourceLocation GangLoc) {
   // This is just the 'num' case where 'num' is optional.
   return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
                              OpenACCClauseKind::Gang, GangLoc)
-      .isInvalid();
+      .first.isInvalid();
 }
 
 bool Parser::ParseOpenACCGangArgList(SourceLocation GangLoc) {
@@ -946,13 +984,25 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
       }
       break;
     }
-    case OpenACCClauseKind::NumGangs:
+    case OpenACCClauseKind::NumGangs: {
+      llvm::SmallVector<Expr *> IntExprs;
+
+      if (ParseOpenACCIntExprList(OpenACCDirectiveKind::Invalid,
+                                  OpenACCClauseKind::NumGangs, ClauseLoc,
+                                  IntExprs)) {
+        Parens.skipToEnd();
+        return OpenACCCanContinue();
+      }
+      ParsedClause.setIntExprDetails(std::move(IntExprs));
+      break;
+    }
     case OpenACCClauseKind::NumWorkers:
     case OpenACCClauseKind::DeviceNum:
     case OpenACCClauseKind::DefaultAsync:
     case OpenACCClauseKind::VectorLength: {
       ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
-                                               ClauseKind, ClauseLoc);
+                                               ClauseKind, ClauseLoc)
+                               .first;
       if (IntExpr.isInvalid()) {
         Parens.skipToEnd();
         return OpenACCCanContinue();
@@ -1017,7 +1067,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
                                                : OpenACCSpecialTokenKind::Num,
                                            ClauseKind);
         ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
-                                                 ClauseKind, ClauseLoc);
+                                                 ClauseKind, ClauseLoc)
+                                 .first;
         if (IntExpr.isInvalid()) {
           Parens.skipToEnd();
           return OpenACCCanContinue();
@@ -1081,11 +1132,13 @@ bool Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
     // Consume colon.
     ConsumeToken();
 
-    ExprResult IntExpr = ParseOpenACCIntExpr(
-        IsDirective ? OpenACCDirectiveKind::Wait
-                    : OpenACCDirectiveKind::Invalid,
-        IsDirective ? OpenACCClauseKind::Invalid : OpenACCClauseKind::Wait,
-        Loc);
+    ExprResult IntExpr =
+        ParseOpenACCIntExpr(IsDirective ? OpenACCDirectiveKind::Wait
+                                        : OpenACCDirectiveKind::Invalid,
+                            IsDirective ? OpenACCClauseKind::Invalid
+                                        : OpenACCClauseKind::Wait,
+                            Loc)
+            .first;
     if (IntExpr.isInvalid())
       return true;
 
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index 190739fa02e932..ba69e71e30a181 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -91,6 +91,7 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
     default:
       return false;
     }
+  case OpenACCClauseKind::NumGangs:
   case OpenACCClauseKind::NumWorkers:
   case OpenACCClauseKind::VectorLength:
     switch (DirectiveKind) {
@@ -230,6 +231,40 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
         getASTContext(), Clause.getBeginLoc(), Clause.getLParenLoc(),
         Clause.getConditionExpr(), Clause.getEndLoc());
   }
+  case OpenACCClauseKind::NumGangs: {
+    // Restrictions only properly implemented on 'compute' constructs, and
+    // 'compute' constructs are the only construct that can do anything with
+    // this yet, so skip/treat as unimplemented in this case.
+    if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()))
+      break;
+
+    // There is no prose in the standard that says duplicates aren't allowed,
+    // but this diagnostic is present in other compilers, as well as makes
+    // sense.
+    if (checkAlreadyHasClauseOfKind(*this, ExistingClauses, Clause))
+      return nullptr;
+
+    if (Clause.getIntExprs().empty())
+      Diag(Clause.getBeginLoc(), diag::err_acc_num_gangs_num_args)
+          << /*NoArgs=*/0;
+
+    unsigned MaxArgs =
+        (Clause.getDirectiveKind() == OpenACCDirectiveKind::Parallel ||
+         Clause.getDirectiveKind() == OpenACCDirectiveKind::ParallelLoop)
+            ? 3
+            : 1;
+    if (Clause.getIntExprs().size() > MaxArgs)
+      Diag(Clause.getBeginLoc(), diag::err_acc_num_gangs_num_args)
+          << /*NoArgs=*/1 << Clause.getDirectiveKind() << MaxArgs
+          << Clause.getIntExprs().size();
+
+    // Create the AST node for the clause even if the number of expressions is
+    // incorrect.
+    return OpenACCNumGangsClause::Create(
+        getASTContext(), Clause.getBeginLoc(), Clause.getLParenLoc(),
+        Clause.getIntExprs(), Clause.getEndLoc());
+    break;
+  }
   case OpenACCClauseKind::NumWorkers: {
     // Restrictions only properly implemented on 'compute' constructs, and
     // 'compute' constructs are the only construct that can do anything with
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index ade33ec65038fd..10ce226...
[truncated]

Copy link
Member

@alexey-bataev alexey-bataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG with a nit

/// 'int-expr's.
bool ParseOpenACCIntExprList(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
SourceLocation Loc,
llvm::SmallVector<Expr *> &IntExprs);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
llvm::SmallVector<Expr *> &IntExprs);
llvm::SmallVectorImpl<Expr *> &IntExprs);

@erichkeane erichkeane merged commit dc20a0e into llvm:main Apr 22, 2024
4 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:openacc Clang OpenACC Implementation clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants