Skip to content

Conversation

SunilKuravinakop
Copy link
Contributor

Support for dispatch construct (Sema & Codegen) support. Support for clauses: depend, novariants & nocontext.

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. flang:openmp clang:openmp OpenMP related changes to Clang labels Nov 27, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2024

@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-flang-openmp

Author: None (SunilKuravinakop)

Changes

Support for dispatch construct (Sema & Codegen) support. Support for clauses: depend, novariants & nocontext.


Patch is 30.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117904.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2)
  • (modified) clang/include/clang/Basic/OpenMPKinds.h (+6)
  • (modified) clang/include/clang/Sema/SemaOpenMP.h (+7)
  • (modified) clang/lib/Basic/OpenMPKinds.cpp (+5)
  • (modified) clang/lib/CodeGen/CGStmt.cpp (+2-1)
  • (modified) clang/lib/CodeGen/CGStmtOpenMP.cpp (+5)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+1)
  • (modified) clang/lib/Sema/SemaOpenMP.cpp (+272-3)
  • (modified) clang/test/Misc/warning-flags.c (+2-1)
  • (added) clang/test/OpenMP/dispatch_codegen.cpp (+359)
  • (removed) clang/test/OpenMP/dispatch_unsupported.c (-7)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPContext.h (+6)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index c1cdd811db446d..61d684461d785f 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11769,6 +11769,8 @@ def err_omp_clause_requires_dispatch_construct : Error<
   "'%0' clause requires 'dispatch' context selector">;
 def err_omp_append_args_with_varargs : Error<
   "'append_args' is not allowed with varargs functions">;
+def warn_omp_dispatch_clause_novariants_nocontext : Warning<
+  "only 'novariants' clause is supported when 'novariants' & 'nocontext' clauses occur on the same dispatch construct">;
 def err_openmp_vla_in_task_untied : Error<
   "variable length arrays are not supported in OpenMP tasking regions with 'untied' clause">;
 def warn_omp_unterminated_declare_target : Warning<
diff --git a/clang/include/clang/Basic/OpenMPKinds.h b/clang/include/clang/Basic/OpenMPKinds.h
index 900ad6ca6d66f6..7579fab43dbb19 100644
--- a/clang/include/clang/Basic/OpenMPKinds.h
+++ b/clang/include/clang/Basic/OpenMPKinds.h
@@ -269,6 +269,12 @@ bool isOpenMPTaskLoopDirective(OpenMPDirectiveKind DKind);
 /// parallel', otherwise - false.
 bool isOpenMPParallelDirective(OpenMPDirectiveKind DKind);
 
+/// Checks if the specified directive is a dispatch-kind directive.
+/// \param DKind Specified directive.
+/// \return true - the directive is a dispatch-like directive like 'omp
+/// dispatch', otherwise - false.
+bool isOpenMPDispatchDirective(OpenMPDirectiveKind DKind);
+
 /// Checks if the specified directive is a target code offload directive.
 /// \param DKind Specified directive.
 /// \return true - the directive is a target code offload directive like
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 3d1cc4fab1c10f..80cee9e7583051 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -1462,6 +1462,13 @@ class SemaOpenMP : public SemaBase {
                                            : OMPDeclareVariantScopes.back().TI;
   }
 
+  StmtResult transformDispatchDirective(OpenMPDirectiveKind Kind,
+                                        const DeclarationNameInfo &DirName,
+                                        OpenMPDirectiveKind CancelRegion,
+                                        ArrayRef<OMPClause *> Clauses,
+                                        Stmt *AStmt, SourceLocation StartLoc,
+                                        SourceLocation EndLoc);
+
   /// The current `omp begin/end declare variant` scopes.
   SmallVector<OMPDeclareVariantScope, 4> OMPDeclareVariantScopes;
 
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 62a13f01481b28..44ee63df46adb5 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -621,6 +621,11 @@ bool clang::isOpenMPParallelDirective(OpenMPDirectiveKind DKind) {
          llvm::is_contained(getLeafConstructs(DKind), OMPD_parallel);
 }
 
+bool clang::isOpenMPDispatchDirective(OpenMPDirectiveKind DKind) {
+  return DKind == OMPD_dispatch ||
+         llvm::is_contained(getLeafConstructs(DKind), OMPD_target);
+}
+
 bool clang::isOpenMPTargetExecutionDirective(OpenMPDirectiveKind DKind) {
   return DKind == OMPD_target ||
          llvm::is_contained(getLeafConstructs(DKind), OMPD_target);
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 698baf853507f4..2cd2a650151c87 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -417,7 +417,8 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
     EmitOMPInteropDirective(cast<OMPInteropDirective>(*S));
     break;
   case Stmt::OMPDispatchDirectiveClass:
-    CGM.ErrorUnsupported(S, "OpenMP dispatch directive");
+    // CGM.ErrorUnsupported(S, "OpenMP dispatch directive");
+    EmitOMPDispatchDirective(cast<OMPDispatchDirective>(*S));
     break;
   case Stmt::OMPScopeDirectiveClass:
     EmitOMPScopeDirective(cast<OMPScopeDirective>(*S));
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 390516fea38498..7d77db0ba2bb12 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -4452,6 +4452,11 @@ void CodeGenFunction::EmitOMPMasterDirective(const OMPMasterDirective &S) {
   emitMaster(*this, S);
 }
 
+void CodeGenFunction::EmitOMPDispatchDirective(const OMPDispatchDirective &S) {
+
+  EmitStmt(S.getAssociatedStmt());
+}
+
 static void emitMasked(CodeGenFunction &CGF, const OMPExecutableDirective &S) {
   auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
     Action.Enter(CGF);
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 5c4d76c2267a77..0adf51f53aceaf 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3830,6 +3830,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   void EmitOMPSectionDirective(const OMPSectionDirective &S);
   void EmitOMPSingleDirective(const OMPSingleDirective &S);
   void EmitOMPMasterDirective(const OMPMasterDirective &S);
+  void EmitOMPDispatchDirective(const OMPDispatchDirective &S);
   void EmitOMPMaskedDirective(const OMPMaskedDirective &S);
   void EmitOMPCriticalDirective(const OMPCriticalDirective &S);
   void EmitOMPParallelForDirective(const OMPParallelForDirective &S);
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 976d48e1249136..943e47e2d96b25 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -4205,6 +4205,8 @@ static void handleDeclareVariantConstructTrait(DSAStackTy *Stack,
   SmallVector<llvm::omp::TraitProperty, 8> Traits;
   if (isOpenMPTargetExecutionDirective(DKind))
     Traits.emplace_back(llvm::omp::TraitProperty::construct_target_target);
+  if (isOpenMPDispatchDirective(DKind))
+    Traits.emplace_back(llvm::omp::TraitProperty::construct_dispatch_dispatch);
   if (isOpenMPTeamsDirective(DKind))
     Traits.emplace_back(llvm::omp::TraitProperty::construct_teams_teams);
   if (isOpenMPParallelDirective(DKind))
@@ -5965,6 +5967,244 @@ static bool teamsLoopCanBeParallelFor(Stmt *AStmt, Sema &SemaRef) {
   return Checker.teamsLoopCanBeParallelFor();
 }
 
+static Expr *getInitialExprFromCapturedExpr(Expr *Cond) {
+
+  Expr *SubExpr = Cond->IgnoreParenImpCasts();
+
+  if (auto *DeclRef = dyn_cast<DeclRefExpr>(SubExpr)) {
+    if (auto *CapturedExprDecl =
+            dyn_cast<OMPCapturedExprDecl>(DeclRef->getDecl())) {
+
+      // Retrieve the initial expression from the captured expression
+      return CapturedExprDecl->getInit();
+    }
+  }
+  return nullptr;
+}
+
+static Expr *copy_removePseudoObjectExpr(const ASTContext &Context, Expr *E,
+                                         SemaOpenMP *SemaPtr, bool NoContext) {
+
+  BinaryOperator *BinaryCopyOpr = NULL;
+  bool BinaryOp = false;
+  if (E->getStmtClass() == Stmt::BinaryOperatorClass) {
+    BinaryOp = true;
+    BinaryOperator *E_BinOpr = static_cast<BinaryOperator *>(E);
+    BinaryCopyOpr = BinaryOperator::Create(
+        Context, E_BinOpr->getLHS(), E_BinOpr->getRHS(), E_BinOpr->getOpcode(),
+        E_BinOpr->getType(), E_BinOpr->getValueKind(),
+        E_BinOpr->getObjectKind(), E_BinOpr->getOperatorLoc(),
+        FPOptionsOverride());
+    E = BinaryCopyOpr->getRHS();
+  }
+
+  // Change PseudoObjectExpr to a direct call
+  if (PseudoObjectExpr *PO = dyn_cast<PseudoObjectExpr>(E))
+    E = *((PO->semantics_begin()) - 1);
+
+  // Add new Traits to direct call to convert it to new PseudoObjectExpr
+  // This converts Traits for the function call from under "dispatch" to traits
+  // of direct function call not under "dispatch".
+  if (NoContext) {
+    // Convert StmtResult to a CallExpr before calling ActOnOpenMPCall()
+    CallExpr *CallExprWithinStmt = dyn_cast<CallExpr>(E);
+    int NumArgs = CallExprWithinStmt->getNumArgs();
+    clang::Expr **Args = CallExprWithinStmt->getArgs();
+    ExprResult er = SemaPtr->ActOnOpenMPCall(
+        CallExprWithinStmt, SemaPtr->SemaRef.getCurScope(),
+        CallExprWithinStmt->getBeginLoc(), MultiExprArg(Args, NumArgs),
+        CallExprWithinStmt->getRParenLoc(), static_cast<Expr *>(nullptr));
+    E = er.get();
+  }
+
+  if (BinaryOp) {
+    BinaryCopyOpr->setRHS(E);
+    return BinaryCopyOpr;
+  }
+
+  return E;
+}
+
+static StmtResult combine2Stmts(ASTContext &context, Stmt *first,
+                                Stmt *second) {
+
+  llvm::SmallVector<Stmt *, 2> newCombinedStmts;
+  newCombinedStmts.push_back(first);
+  newCombinedStmts.push_back(second); // Adding foo();
+  llvm::ArrayRef<Stmt *> ar(newCombinedStmts);
+  CompoundStmt *CombinedStmt = CompoundStmt::Create(
+      context, ar, FPOptionsOverride(), SourceLocation(), SourceLocation());
+  StmtResult FinalStmts(CombinedStmt);
+  return FinalStmts;
+}
+
+template <typename SpecificClause>
+static bool hasClausesOfKind(ArrayRef<OMPClause *> Clauses) {
+  auto ClausesOfKind =
+      OMPExecutableDirective::getClausesOfKind<SpecificClause>(Clauses);
+  return ClausesOfKind.begin() != ClausesOfKind.end();
+}
+
+// Get a CapturedStmt with direct call to function.
+// If there is a PseudoObjectExpr under the CapturedDecl
+// choose the first call under it for the direct call to function
+static StmtResult CloneNewCapturedStmtForDirectCall(const ASTContext &Context,
+                                                    Stmt *StmtP,
+                                                    SemaOpenMP *SemaPtr,
+                                                    bool NoContext) {
+  if (StmtP->getStmtClass() == Stmt::CapturedStmtClass) {
+    CapturedStmt *AStmt = static_cast<CapturedStmt *>(StmtP);
+    CapturedDecl *CDecl = AStmt->getCapturedDecl();
+    Stmt *S = cast<CapturedStmt>(AStmt)->getCapturedStmt();
+    auto *E = dyn_cast<Expr>(S);
+    E = copy_removePseudoObjectExpr(Context, E, SemaPtr, NoContext);
+
+    // Copy Current Captured Decl to a New Captured Decl for noting the
+    // Annotation
+    CapturedDecl *NewDecl =
+        CapturedDecl::Create(const_cast<ASTContext &>(Context),
+                             CDecl->getDeclContext(), CDecl->getNumParams());
+    NewDecl->setBody(static_cast<Stmt *>(E));
+    for (unsigned i = 0; i < CDecl->getNumParams(); ++i) {
+      if (i != CDecl->getContextParamPosition())
+        NewDecl->setParam(i, CDecl->getParam(i));
+      else
+        NewDecl->setContextParam(i, CDecl->getContextParam());
+    }
+
+    // Create a New Captured Stmt containing the New Captured Decl
+    SmallVector<CapturedStmt::Capture, 4> Captures;
+    SmallVector<Expr *, 4> CaptureInits;
+    for (auto capture : AStmt->captures())
+      Captures.push_back(capture);
+    for (auto capture_init : AStmt->capture_inits())
+      CaptureInits.push_back(capture_init);
+    CapturedStmt *NewStmt = CapturedStmt::Create(
+        Context, AStmt->getCapturedStmt(), AStmt->getCapturedRegionKind(),
+        Captures, CaptureInits, NewDecl,
+        const_cast<RecordDecl *>(AStmt->getCapturedRecordDecl()));
+
+    return NewStmt;
+  }
+  return static_cast<Stmt *>(NULL);
+}
+
+StmtResult SemaOpenMP::transformDispatchDirective(
+    OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
+    OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
+    Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc) {
+
+  StmtResult RetValue;
+  std::vector<OMPClause *> DependVector;
+  for (auto C :
+       OMPExecutableDirective::getClausesOfKind<OMPDependClause>(Clauses)) {
+    auto C1 = const_cast<OMPDependClause *>(C);
+    if (OMPDependClause *DependClause = dyn_cast<OMPDependClause>(C1)) {
+      DependVector.push_back(DependClause);
+    }
+  }
+  llvm::ArrayRef<OMPClause *> DependCArray(DependVector);
+
+  // #pragma omp dispatch depend() is changed to #pragma omp taskwait depend()
+  // This is done by calling ActOnOpenMPExecutableDirective() for the
+  // new taskwait directive.
+  StmtResult DispatchDepend2taskwait =
+      ActOnOpenMPExecutableDirective(OMPD_taskwait, DirName, CancelRegion,
+                                     DependCArray, NULL, StartLoc, EndLoc);
+
+  if (OMPExecutableDirective::getSingleClause<OMPNovariantsClause>(Clauses)) {
+
+    if (OMPExecutableDirective::getSingleClause<OMPNocontextClause>(Clauses)) {
+      Diag(StartLoc, diag::warn_omp_dispatch_clause_novariants_nocontext);
+    }
+
+    const OMPNovariantsClause *NoVariantsC =
+        OMPExecutableDirective::getSingleClause<OMPNovariantsClause>(Clauses);
+    // #pragma omp dispatch novariants(c2) depend(out: x)
+    // foo();
+    // becomes:
+    // #pragma omp taskwait depend(out: x)
+    // if (c2) {
+    //    foo();
+    // } else {
+    //    #pragma omp dispatch
+    //    foo(); <--- foo() is replaced with foo_variant() in CodeGen
+    // }
+    Expr *Cond = getInitialExprFromCapturedExpr(NoVariantsC->getCondition());
+    StmtResult ThenStmt =
+        CloneNewCapturedStmtForDirectCall(getASTContext(), AStmt, this, false);
+    SmallVector<OMPClause *, 5> DependClauses;
+    StmtResult ElseStmt = ActOnOpenMPExecutableDirective(
+        Kind, DirName, CancelRegion, DependClauses, AStmt, StartLoc, EndLoc);
+    IfStmt *IfElseStmt =
+        IfStmt::Create(getASTContext(), StartLoc, IfStatementKind::Ordinary,
+                       nullptr, // Init
+                       nullptr, // Condition Var Declaration
+                       Cond,
+                       StartLoc, // Source Location Left Paranthesis
+                       StartLoc, // Source Location Right Paranthesis
+                       ThenStmt.get(), StartLoc, ElseStmt.get());
+    if (hasClausesOfKind<OMPDependClause>(Clauses)) {
+      StmtResult FinalStmts = combine2Stmts(
+          getASTContext(), DispatchDepend2taskwait.get(), IfElseStmt);
+      RetValue = FinalStmts;
+    } else {
+      RetValue = IfElseStmt;
+    }
+  } else if (OMPExecutableDirective::getSingleClause<OMPNocontextClause>(
+                 Clauses)) {
+
+    const OMPNocontextClause *NoContextC =
+        OMPExecutableDirective::getSingleClause<OMPNocontextClause>(Clauses);
+    Expr *Cond = getInitialExprFromCapturedExpr(NoContextC->getCondition());
+    // #pragma omp dispatch depend(out: x) nocontext(c2)
+    // foo();
+    // becomes:
+    // #pragma omp taskwait depend(out: x)
+    // if (c2) {
+    //    foo();
+    // } else {
+    //    #pragma omp dispatch
+    //    foo();
+    // }
+
+    StmtResult ThenStmt =
+        CloneNewCapturedStmtForDirectCall(getASTContext(), AStmt, this, true);
+
+    SmallVector<OMPClause *, 5> DependClauses;
+    StmtResult ElseStmt = ActOnOpenMPExecutableDirective(
+        Kind, DirName, CancelRegion, DependClauses, AStmt, StartLoc, EndLoc);
+    IfStmt *IfElseStmt =
+        IfStmt::Create(getASTContext(), StartLoc, IfStatementKind::Ordinary,
+                       nullptr, // Init
+                       nullptr, // Condition Var Declaration
+                       Cond,
+                       StartLoc, // Source Location Left Paranthesis
+                       StartLoc, // Source Location Right Paranthesis
+                       ThenStmt.get(), StartLoc, ElseStmt.get());
+    if (hasClausesOfKind<OMPDependClause>(Clauses)) {
+      StmtResult FinalStmts = combine2Stmts(
+          getASTContext(), DispatchDepend2taskwait.get(), IfElseStmt);
+      RetValue = FinalStmts;
+    } else {
+      RetValue = IfElseStmt;
+    }
+  } else if (hasClausesOfKind<OMPDependClause>(Clauses)) {
+    // Only:
+    // #pragma omp dispatch depend(out: x)
+    // foo();
+    // to
+    // #pragma omp taskwait depend(out: x)
+    // foo();
+    // StmtResult NewAStmt =
+    // CloneNewCapturedStmtForDirectCall(getASTContext(), AStmt, false);
+    StmtResult FinalStmts =
+        combine2Stmts(getASTContext(), DispatchDepend2taskwait.get(), AStmt);
+    RetValue = FinalStmts;
+  }
+  return RetValue;
+}
+
 StmtResult SemaOpenMP::ActOnOpenMPExecutableDirective(
     OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
     OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
@@ -5979,6 +6219,20 @@ StmtResult SemaOpenMP::ActOnOpenMPExecutableDirective(
           OMPExecutableDirective::getSingleClause<OMPBindClause>(Clauses))
     BindKind = BC->getBindKind();
 
+  if ((Kind == OMPD_dispatch) && (Clauses.size() > 0)) {
+
+    bool UnSupportedClause = false;
+    for (OMPClause *C : Clauses) {
+      if (!((C->getClauseKind() == OMPC_novariants) ||
+            (C->getClauseKind() == OMPC_nocontext) ||
+            (C->getClauseKind() == OMPC_depend)))
+        UnSupportedClause = true;
+    }
+    if (!UnSupportedClause)
+      return transformDispatchDirective(Kind, DirName, CancelRegion, Clauses,
+                                        AStmt, StartLoc, EndLoc);
+  }
+
   if (Kind == OMPD_loop && BindKind == OMPC_BIND_unknown) {
     const OpenMPDirectiveKind ParentDirective = DSAStack->getParentDirective();
 
@@ -7209,8 +7463,10 @@ ExprResult SemaOpenMP::ActOnOpenMPCall(ExprResult Call, Scope *Scope,
     Exprs.erase(Exprs.begin() + BestIdx);
   } while (!VMIs.empty());
 
-  if (!NewCall.isUsable())
+  if (!NewCall.isUsable()) {
+    fprintf(stdout, "Returning Call, NewCall is not usable\n");
     return Call;
+  }
   return PseudoObjectExpr::Create(getASTContext(), CE, {NewCall.get()}, 0);
 }
 
@@ -10520,8 +10776,17 @@ StmtResult SemaOpenMP::ActOnOpenMPSectionDirective(Stmt *AStmt,
                                      DSAStack->isCancelRegion());
 }
 
+static Expr *removePseudoObjectExpr(Expr *E) {
+  // PseudoObjectExpr is a Trait for dispatch containing the
+  // function name and its variant. Returning the function name.
+  if (auto *PO = dyn_cast<PseudoObjectExpr>(E))
+    E = *((PO->semantics_begin()) - 1);
+  return E;
+}
+
 static Expr *getDirectCallExpr(Expr *E) {
   E = E->IgnoreParenCasts()->IgnoreImplicit();
+  E = removePseudoObjectExpr(E);
   if (auto *CE = dyn_cast<CallExpr>(E))
     if (CE->getDirectCallee())
       return E;
@@ -10556,15 +10821,19 @@ SemaOpenMP::ActOnOpenMPDispatchDirective(ArrayRef<OMPClause *> Clauses,
     E = E->IgnoreParenCasts()->IgnoreImplicit();
 
     if (auto *BO = dyn_cast<BinaryOperator>(E)) {
-      if (BO->getOpcode() == BO_Assign)
+      if (BO->getOpcode() == BO_Assign) {
         TargetCall = getDirectCallExpr(BO->getRHS());
+      }
     } else {
       if (auto *COCE = dyn_cast<CXXOperatorCallExpr>(E))
         if (COCE->getOperator() == OO_Equal)
           TargetCall = getDirectCallExpr(COCE->getArg(1));
-      if (!TargetCall)
+      if (!TargetCall) {
+        E = removePseudoObjectExpr(E);
         TargetCall = getDirectCallExpr(E);
+      }
     }
+
     if (!TargetCall) {
       Diag(E->getBeginLoc(), diag::err_omp_dispatch_statement_call);
       return StmtError();
diff --git a/clang/test/Misc/warning-flags.c b/clang/test/Misc/warning-flags.c
index 1fd02440833359..5f3b54d7be8e29 100644
--- a/clang/test/Misc/warning-flags.c
+++ b/clang/test/Misc/warning-flags.c
@@ -18,7 +18,7 @@ This test serves two purposes:
 
 The list of warnings below should NEVER grow.  It should gradually shrink to 0.
 
-CHECK: Warnings without flags (61):
+CHECK: Warnings without flags (62):
 
 CHECK-NEXT:   ext_expected_semi_decl_list
 CHECK-NEXT:   ext_missing_whitespace_after_macro_name
@@ -65,6 +65,7 @@ CHECK-NEXT:   warn_no_constructor_for_refconst
 CHECK-NEXT:   warn_not_compound_assign
 CHECK-NEXT:   warn_objc_property_copy_missing_on_block
 CHECK-NEXT:   warn_objc_protocol_qualifier_missing_id
+CHECK-NEXT:   warn_omp_dispatch_clause_novariants_nocontext
 CHECK-NEXT:   warn_on_superclass_use
 CHECK-NEXT:   warn_pp_convert_to_positive
 CHECK-NEXT:   warn_pp_expr_overflow
diff --git a/clang/test/OpenMP/dispatch_codegen.cpp b/clang/test/OpenMP/dispatch_codegen.cpp
new file mode 100644
index 00000000000000..3a50ccbb99873a
--- /dev/null
+++ b/clang/test/OpenMP/dispatch_codegen.cpp
@@ -0,0 +1,359 @@
+// expected-no-diagnostics
+// RUN: %clang_cc1 -DCK1 -verify -fopenmp -x c++ %s -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -verify -fopenmp -x c++ -triple x86_64-unknown-unknown -emit-llvm %s -o - | FileCheck %s
+
+int foo_variant_dispatch(int x, int y) {
+   return x+2;
+}
+
+int foo_variant_allCond(int x, int y) {
+   return x+3;
+}
+
+#pragma omp declare variant(foo_v...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2024

@llvm/pr-subscribers-clang

Author: None (SunilKuravinakop)

Changes

Support for dispatch construct (Sema & Codegen) support. Support for clauses: depend, novariants & nocontext.


Patch is 30.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117904.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2)
  • (modified) clang/include/clang/Basic/OpenMPKinds.h (+6)
  • (modified) clang/include/clang/Sema/SemaOpenMP.h (+7)
  • (modified) clang/lib/Basic/OpenMPKinds.cpp (+5)
  • (modified) clang/lib/CodeGen/CGStmt.cpp (+2-1)
  • (modified) clang/lib/CodeGen/CGStmtOpenMP.cpp (+5)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+1)
  • (modified) clang/lib/Sema/SemaOpenMP.cpp (+272-3)
  • (modified) clang/test/Misc/warning-flags.c (+2-1)
  • (added) clang/test/OpenMP/dispatch_codegen.cpp (+359)
  • (removed) clang/test/OpenMP/dispatch_unsupported.c (-7)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPContext.h (+6)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index c1cdd811db446d..61d684461d785f 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11769,6 +11769,8 @@ def err_omp_clause_requires_dispatch_construct : Error<
   "'%0' clause requires 'dispatch' context selector">;
 def err_omp_append_args_with_varargs : Error<
   "'append_args' is not allowed with varargs functions">;
+def warn_omp_dispatch_clause_novariants_nocontext : Warning<
+  "only 'novariants' clause is supported when 'novariants' & 'nocontext' clauses occur on the same dispatch construct">;
 def err_openmp_vla_in_task_untied : Error<
   "variable length arrays are not supported in OpenMP tasking regions with 'untied' clause">;
 def warn_omp_unterminated_declare_target : Warning<
diff --git a/clang/include/clang/Basic/OpenMPKinds.h b/clang/include/clang/Basic/OpenMPKinds.h
index 900ad6ca6d66f6..7579fab43dbb19 100644
--- a/clang/include/clang/Basic/OpenMPKinds.h
+++ b/clang/include/clang/Basic/OpenMPKinds.h
@@ -269,6 +269,12 @@ bool isOpenMPTaskLoopDirective(OpenMPDirectiveKind DKind);
 /// parallel', otherwise - false.
 bool isOpenMPParallelDirective(OpenMPDirectiveKind DKind);
 
+/// Checks if the specified directive is a dispatch-kind directive.
+/// \param DKind Specified directive.
+/// \return true - the directive is a dispatch-like directive like 'omp
+/// dispatch', otherwise - false.
+bool isOpenMPDispatchDirective(OpenMPDirectiveKind DKind);
+
 /// Checks if the specified directive is a target code offload directive.
 /// \param DKind Specified directive.
 /// \return true - the directive is a target code offload directive like
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 3d1cc4fab1c10f..80cee9e7583051 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -1462,6 +1462,13 @@ class SemaOpenMP : public SemaBase {
                                            : OMPDeclareVariantScopes.back().TI;
   }
 
+  StmtResult transformDispatchDirective(OpenMPDirectiveKind Kind,
+                                        const DeclarationNameInfo &DirName,
+                                        OpenMPDirectiveKind CancelRegion,
+                                        ArrayRef<OMPClause *> Clauses,
+                                        Stmt *AStmt, SourceLocation StartLoc,
+                                        SourceLocation EndLoc);
+
   /// The current `omp begin/end declare variant` scopes.
   SmallVector<OMPDeclareVariantScope, 4> OMPDeclareVariantScopes;
 
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 62a13f01481b28..44ee63df46adb5 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -621,6 +621,11 @@ bool clang::isOpenMPParallelDirective(OpenMPDirectiveKind DKind) {
          llvm::is_contained(getLeafConstructs(DKind), OMPD_parallel);
 }
 
+bool clang::isOpenMPDispatchDirective(OpenMPDirectiveKind DKind) {
+  return DKind == OMPD_dispatch ||
+         llvm::is_contained(getLeafConstructs(DKind), OMPD_target);
+}
+
 bool clang::isOpenMPTargetExecutionDirective(OpenMPDirectiveKind DKind) {
   return DKind == OMPD_target ||
          llvm::is_contained(getLeafConstructs(DKind), OMPD_target);
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 698baf853507f4..2cd2a650151c87 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -417,7 +417,8 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
     EmitOMPInteropDirective(cast<OMPInteropDirective>(*S));
     break;
   case Stmt::OMPDispatchDirectiveClass:
-    CGM.ErrorUnsupported(S, "OpenMP dispatch directive");
+    // CGM.ErrorUnsupported(S, "OpenMP dispatch directive");
+    EmitOMPDispatchDirective(cast<OMPDispatchDirective>(*S));
     break;
   case Stmt::OMPScopeDirectiveClass:
     EmitOMPScopeDirective(cast<OMPScopeDirective>(*S));
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 390516fea38498..7d77db0ba2bb12 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -4452,6 +4452,11 @@ void CodeGenFunction::EmitOMPMasterDirective(const OMPMasterDirective &S) {
   emitMaster(*this, S);
 }
 
+void CodeGenFunction::EmitOMPDispatchDirective(const OMPDispatchDirective &S) {
+
+  EmitStmt(S.getAssociatedStmt());
+}
+
 static void emitMasked(CodeGenFunction &CGF, const OMPExecutableDirective &S) {
   auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
     Action.Enter(CGF);
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 5c4d76c2267a77..0adf51f53aceaf 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3830,6 +3830,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   void EmitOMPSectionDirective(const OMPSectionDirective &S);
   void EmitOMPSingleDirective(const OMPSingleDirective &S);
   void EmitOMPMasterDirective(const OMPMasterDirective &S);
+  void EmitOMPDispatchDirective(const OMPDispatchDirective &S);
   void EmitOMPMaskedDirective(const OMPMaskedDirective &S);
   void EmitOMPCriticalDirective(const OMPCriticalDirective &S);
   void EmitOMPParallelForDirective(const OMPParallelForDirective &S);
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 976d48e1249136..943e47e2d96b25 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -4205,6 +4205,8 @@ static void handleDeclareVariantConstructTrait(DSAStackTy *Stack,
   SmallVector<llvm::omp::TraitProperty, 8> Traits;
   if (isOpenMPTargetExecutionDirective(DKind))
     Traits.emplace_back(llvm::omp::TraitProperty::construct_target_target);
+  if (isOpenMPDispatchDirective(DKind))
+    Traits.emplace_back(llvm::omp::TraitProperty::construct_dispatch_dispatch);
   if (isOpenMPTeamsDirective(DKind))
     Traits.emplace_back(llvm::omp::TraitProperty::construct_teams_teams);
   if (isOpenMPParallelDirective(DKind))
@@ -5965,6 +5967,244 @@ static bool teamsLoopCanBeParallelFor(Stmt *AStmt, Sema &SemaRef) {
   return Checker.teamsLoopCanBeParallelFor();
 }
 
+static Expr *getInitialExprFromCapturedExpr(Expr *Cond) {
+
+  Expr *SubExpr = Cond->IgnoreParenImpCasts();
+
+  if (auto *DeclRef = dyn_cast<DeclRefExpr>(SubExpr)) {
+    if (auto *CapturedExprDecl =
+            dyn_cast<OMPCapturedExprDecl>(DeclRef->getDecl())) {
+
+      // Retrieve the initial expression from the captured expression
+      return CapturedExprDecl->getInit();
+    }
+  }
+  return nullptr;
+}
+
+static Expr *copy_removePseudoObjectExpr(const ASTContext &Context, Expr *E,
+                                         SemaOpenMP *SemaPtr, bool NoContext) {
+
+  BinaryOperator *BinaryCopyOpr = NULL;
+  bool BinaryOp = false;
+  if (E->getStmtClass() == Stmt::BinaryOperatorClass) {
+    BinaryOp = true;
+    BinaryOperator *E_BinOpr = static_cast<BinaryOperator *>(E);
+    BinaryCopyOpr = BinaryOperator::Create(
+        Context, E_BinOpr->getLHS(), E_BinOpr->getRHS(), E_BinOpr->getOpcode(),
+        E_BinOpr->getType(), E_BinOpr->getValueKind(),
+        E_BinOpr->getObjectKind(), E_BinOpr->getOperatorLoc(),
+        FPOptionsOverride());
+    E = BinaryCopyOpr->getRHS();
+  }
+
+  // Change PseudoObjectExpr to a direct call
+  if (PseudoObjectExpr *PO = dyn_cast<PseudoObjectExpr>(E))
+    E = *((PO->semantics_begin()) - 1);
+
+  // Add new Traits to direct call to convert it to new PseudoObjectExpr
+  // This converts Traits for the function call from under "dispatch" to traits
+  // of direct function call not under "dispatch".
+  if (NoContext) {
+    // Convert StmtResult to a CallExpr before calling ActOnOpenMPCall()
+    CallExpr *CallExprWithinStmt = dyn_cast<CallExpr>(E);
+    int NumArgs = CallExprWithinStmt->getNumArgs();
+    clang::Expr **Args = CallExprWithinStmt->getArgs();
+    ExprResult er = SemaPtr->ActOnOpenMPCall(
+        CallExprWithinStmt, SemaPtr->SemaRef.getCurScope(),
+        CallExprWithinStmt->getBeginLoc(), MultiExprArg(Args, NumArgs),
+        CallExprWithinStmt->getRParenLoc(), static_cast<Expr *>(nullptr));
+    E = er.get();
+  }
+
+  if (BinaryOp) {
+    BinaryCopyOpr->setRHS(E);
+    return BinaryCopyOpr;
+  }
+
+  return E;
+}
+
+static StmtResult combine2Stmts(ASTContext &context, Stmt *first,
+                                Stmt *second) {
+
+  llvm::SmallVector<Stmt *, 2> newCombinedStmts;
+  newCombinedStmts.push_back(first);
+  newCombinedStmts.push_back(second); // Adding foo();
+  llvm::ArrayRef<Stmt *> ar(newCombinedStmts);
+  CompoundStmt *CombinedStmt = CompoundStmt::Create(
+      context, ar, FPOptionsOverride(), SourceLocation(), SourceLocation());
+  StmtResult FinalStmts(CombinedStmt);
+  return FinalStmts;
+}
+
+template <typename SpecificClause>
+static bool hasClausesOfKind(ArrayRef<OMPClause *> Clauses) {
+  auto ClausesOfKind =
+      OMPExecutableDirective::getClausesOfKind<SpecificClause>(Clauses);
+  return ClausesOfKind.begin() != ClausesOfKind.end();
+}
+
+// Get a CapturedStmt with direct call to function.
+// If there is a PseudoObjectExpr under the CapturedDecl
+// choose the first call under it for the direct call to function
+static StmtResult CloneNewCapturedStmtForDirectCall(const ASTContext &Context,
+                                                    Stmt *StmtP,
+                                                    SemaOpenMP *SemaPtr,
+                                                    bool NoContext) {
+  if (StmtP->getStmtClass() == Stmt::CapturedStmtClass) {
+    CapturedStmt *AStmt = static_cast<CapturedStmt *>(StmtP);
+    CapturedDecl *CDecl = AStmt->getCapturedDecl();
+    Stmt *S = cast<CapturedStmt>(AStmt)->getCapturedStmt();
+    auto *E = dyn_cast<Expr>(S);
+    E = copy_removePseudoObjectExpr(Context, E, SemaPtr, NoContext);
+
+    // Copy Current Captured Decl to a New Captured Decl for noting the
+    // Annotation
+    CapturedDecl *NewDecl =
+        CapturedDecl::Create(const_cast<ASTContext &>(Context),
+                             CDecl->getDeclContext(), CDecl->getNumParams());
+    NewDecl->setBody(static_cast<Stmt *>(E));
+    for (unsigned i = 0; i < CDecl->getNumParams(); ++i) {
+      if (i != CDecl->getContextParamPosition())
+        NewDecl->setParam(i, CDecl->getParam(i));
+      else
+        NewDecl->setContextParam(i, CDecl->getContextParam());
+    }
+
+    // Create a New Captured Stmt containing the New Captured Decl
+    SmallVector<CapturedStmt::Capture, 4> Captures;
+    SmallVector<Expr *, 4> CaptureInits;
+    for (auto capture : AStmt->captures())
+      Captures.push_back(capture);
+    for (auto capture_init : AStmt->capture_inits())
+      CaptureInits.push_back(capture_init);
+    CapturedStmt *NewStmt = CapturedStmt::Create(
+        Context, AStmt->getCapturedStmt(), AStmt->getCapturedRegionKind(),
+        Captures, CaptureInits, NewDecl,
+        const_cast<RecordDecl *>(AStmt->getCapturedRecordDecl()));
+
+    return NewStmt;
+  }
+  return static_cast<Stmt *>(NULL);
+}
+
+StmtResult SemaOpenMP::transformDispatchDirective(
+    OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
+    OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
+    Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc) {
+
+  StmtResult RetValue;
+  std::vector<OMPClause *> DependVector;
+  for (auto C :
+       OMPExecutableDirective::getClausesOfKind<OMPDependClause>(Clauses)) {
+    auto C1 = const_cast<OMPDependClause *>(C);
+    if (OMPDependClause *DependClause = dyn_cast<OMPDependClause>(C1)) {
+      DependVector.push_back(DependClause);
+    }
+  }
+  llvm::ArrayRef<OMPClause *> DependCArray(DependVector);
+
+  // #pragma omp dispatch depend() is changed to #pragma omp taskwait depend()
+  // This is done by calling ActOnOpenMPExecutableDirective() for the
+  // new taskwait directive.
+  StmtResult DispatchDepend2taskwait =
+      ActOnOpenMPExecutableDirective(OMPD_taskwait, DirName, CancelRegion,
+                                     DependCArray, NULL, StartLoc, EndLoc);
+
+  if (OMPExecutableDirective::getSingleClause<OMPNovariantsClause>(Clauses)) {
+
+    if (OMPExecutableDirective::getSingleClause<OMPNocontextClause>(Clauses)) {
+      Diag(StartLoc, diag::warn_omp_dispatch_clause_novariants_nocontext);
+    }
+
+    const OMPNovariantsClause *NoVariantsC =
+        OMPExecutableDirective::getSingleClause<OMPNovariantsClause>(Clauses);
+    // #pragma omp dispatch novariants(c2) depend(out: x)
+    // foo();
+    // becomes:
+    // #pragma omp taskwait depend(out: x)
+    // if (c2) {
+    //    foo();
+    // } else {
+    //    #pragma omp dispatch
+    //    foo(); <--- foo() is replaced with foo_variant() in CodeGen
+    // }
+    Expr *Cond = getInitialExprFromCapturedExpr(NoVariantsC->getCondition());
+    StmtResult ThenStmt =
+        CloneNewCapturedStmtForDirectCall(getASTContext(), AStmt, this, false);
+    SmallVector<OMPClause *, 5> DependClauses;
+    StmtResult ElseStmt = ActOnOpenMPExecutableDirective(
+        Kind, DirName, CancelRegion, DependClauses, AStmt, StartLoc, EndLoc);
+    IfStmt *IfElseStmt =
+        IfStmt::Create(getASTContext(), StartLoc, IfStatementKind::Ordinary,
+                       nullptr, // Init
+                       nullptr, // Condition Var Declaration
+                       Cond,
+                       StartLoc, // Source Location Left Paranthesis
+                       StartLoc, // Source Location Right Paranthesis
+                       ThenStmt.get(), StartLoc, ElseStmt.get());
+    if (hasClausesOfKind<OMPDependClause>(Clauses)) {
+      StmtResult FinalStmts = combine2Stmts(
+          getASTContext(), DispatchDepend2taskwait.get(), IfElseStmt);
+      RetValue = FinalStmts;
+    } else {
+      RetValue = IfElseStmt;
+    }
+  } else if (OMPExecutableDirective::getSingleClause<OMPNocontextClause>(
+                 Clauses)) {
+
+    const OMPNocontextClause *NoContextC =
+        OMPExecutableDirective::getSingleClause<OMPNocontextClause>(Clauses);
+    Expr *Cond = getInitialExprFromCapturedExpr(NoContextC->getCondition());
+    // #pragma omp dispatch depend(out: x) nocontext(c2)
+    // foo();
+    // becomes:
+    // #pragma omp taskwait depend(out: x)
+    // if (c2) {
+    //    foo();
+    // } else {
+    //    #pragma omp dispatch
+    //    foo();
+    // }
+
+    StmtResult ThenStmt =
+        CloneNewCapturedStmtForDirectCall(getASTContext(), AStmt, this, true);
+
+    SmallVector<OMPClause *, 5> DependClauses;
+    StmtResult ElseStmt = ActOnOpenMPExecutableDirective(
+        Kind, DirName, CancelRegion, DependClauses, AStmt, StartLoc, EndLoc);
+    IfStmt *IfElseStmt =
+        IfStmt::Create(getASTContext(), StartLoc, IfStatementKind::Ordinary,
+                       nullptr, // Init
+                       nullptr, // Condition Var Declaration
+                       Cond,
+                       StartLoc, // Source Location Left Paranthesis
+                       StartLoc, // Source Location Right Paranthesis
+                       ThenStmt.get(), StartLoc, ElseStmt.get());
+    if (hasClausesOfKind<OMPDependClause>(Clauses)) {
+      StmtResult FinalStmts = combine2Stmts(
+          getASTContext(), DispatchDepend2taskwait.get(), IfElseStmt);
+      RetValue = FinalStmts;
+    } else {
+      RetValue = IfElseStmt;
+    }
+  } else if (hasClausesOfKind<OMPDependClause>(Clauses)) {
+    // Only:
+    // #pragma omp dispatch depend(out: x)
+    // foo();
+    // to
+    // #pragma omp taskwait depend(out: x)
+    // foo();
+    // StmtResult NewAStmt =
+    // CloneNewCapturedStmtForDirectCall(getASTContext(), AStmt, false);
+    StmtResult FinalStmts =
+        combine2Stmts(getASTContext(), DispatchDepend2taskwait.get(), AStmt);
+    RetValue = FinalStmts;
+  }
+  return RetValue;
+}
+
 StmtResult SemaOpenMP::ActOnOpenMPExecutableDirective(
     OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
     OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
@@ -5979,6 +6219,20 @@ StmtResult SemaOpenMP::ActOnOpenMPExecutableDirective(
           OMPExecutableDirective::getSingleClause<OMPBindClause>(Clauses))
     BindKind = BC->getBindKind();
 
+  if ((Kind == OMPD_dispatch) && (Clauses.size() > 0)) {
+
+    bool UnSupportedClause = false;
+    for (OMPClause *C : Clauses) {
+      if (!((C->getClauseKind() == OMPC_novariants) ||
+            (C->getClauseKind() == OMPC_nocontext) ||
+            (C->getClauseKind() == OMPC_depend)))
+        UnSupportedClause = true;
+    }
+    if (!UnSupportedClause)
+      return transformDispatchDirective(Kind, DirName, CancelRegion, Clauses,
+                                        AStmt, StartLoc, EndLoc);
+  }
+
   if (Kind == OMPD_loop && BindKind == OMPC_BIND_unknown) {
     const OpenMPDirectiveKind ParentDirective = DSAStack->getParentDirective();
 
@@ -7209,8 +7463,10 @@ ExprResult SemaOpenMP::ActOnOpenMPCall(ExprResult Call, Scope *Scope,
     Exprs.erase(Exprs.begin() + BestIdx);
   } while (!VMIs.empty());
 
-  if (!NewCall.isUsable())
+  if (!NewCall.isUsable()) {
+    fprintf(stdout, "Returning Call, NewCall is not usable\n");
     return Call;
+  }
   return PseudoObjectExpr::Create(getASTContext(), CE, {NewCall.get()}, 0);
 }
 
@@ -10520,8 +10776,17 @@ StmtResult SemaOpenMP::ActOnOpenMPSectionDirective(Stmt *AStmt,
                                      DSAStack->isCancelRegion());
 }
 
+static Expr *removePseudoObjectExpr(Expr *E) {
+  // PseudoObjectExpr is a Trait for dispatch containing the
+  // function name and its variant. Returning the function name.
+  if (auto *PO = dyn_cast<PseudoObjectExpr>(E))
+    E = *((PO->semantics_begin()) - 1);
+  return E;
+}
+
 static Expr *getDirectCallExpr(Expr *E) {
   E = E->IgnoreParenCasts()->IgnoreImplicit();
+  E = removePseudoObjectExpr(E);
   if (auto *CE = dyn_cast<CallExpr>(E))
     if (CE->getDirectCallee())
       return E;
@@ -10556,15 +10821,19 @@ SemaOpenMP::ActOnOpenMPDispatchDirective(ArrayRef<OMPClause *> Clauses,
     E = E->IgnoreParenCasts()->IgnoreImplicit();
 
     if (auto *BO = dyn_cast<BinaryOperator>(E)) {
-      if (BO->getOpcode() == BO_Assign)
+      if (BO->getOpcode() == BO_Assign) {
         TargetCall = getDirectCallExpr(BO->getRHS());
+      }
     } else {
       if (auto *COCE = dyn_cast<CXXOperatorCallExpr>(E))
         if (COCE->getOperator() == OO_Equal)
           TargetCall = getDirectCallExpr(COCE->getArg(1));
-      if (!TargetCall)
+      if (!TargetCall) {
+        E = removePseudoObjectExpr(E);
         TargetCall = getDirectCallExpr(E);
+      }
     }
+
     if (!TargetCall) {
       Diag(E->getBeginLoc(), diag::err_omp_dispatch_statement_call);
       return StmtError();
diff --git a/clang/test/Misc/warning-flags.c b/clang/test/Misc/warning-flags.c
index 1fd02440833359..5f3b54d7be8e29 100644
--- a/clang/test/Misc/warning-flags.c
+++ b/clang/test/Misc/warning-flags.c
@@ -18,7 +18,7 @@ This test serves two purposes:
 
 The list of warnings below should NEVER grow.  It should gradually shrink to 0.
 
-CHECK: Warnings without flags (61):
+CHECK: Warnings without flags (62):
 
 CHECK-NEXT:   ext_expected_semi_decl_list
 CHECK-NEXT:   ext_missing_whitespace_after_macro_name
@@ -65,6 +65,7 @@ CHECK-NEXT:   warn_no_constructor_for_refconst
 CHECK-NEXT:   warn_not_compound_assign
 CHECK-NEXT:   warn_objc_property_copy_missing_on_block
 CHECK-NEXT:   warn_objc_protocol_qualifier_missing_id
+CHECK-NEXT:   warn_omp_dispatch_clause_novariants_nocontext
 CHECK-NEXT:   warn_on_superclass_use
 CHECK-NEXT:   warn_pp_convert_to_positive
 CHECK-NEXT:   warn_pp_expr_overflow
diff --git a/clang/test/OpenMP/dispatch_codegen.cpp b/clang/test/OpenMP/dispatch_codegen.cpp
new file mode 100644
index 00000000000000..3a50ccbb99873a
--- /dev/null
+++ b/clang/test/OpenMP/dispatch_codegen.cpp
@@ -0,0 +1,359 @@
+// expected-no-diagnostics
+// RUN: %clang_cc1 -DCK1 -verify -fopenmp -x c++ %s -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -verify -fopenmp -x c++ -triple x86_64-unknown-unknown -emit-llvm %s -o - | FileCheck %s
+
+int foo_variant_dispatch(int x, int y) {
+   return x+2;
+}
+
+int foo_variant_allCond(int x, int y) {
+   return x+3;
+}
+
+#pragma omp declare variant(foo_v...
[truncated]

Copy link
Member

@alexey-bataev alexey-bataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, reformat your changes, follow coding standard, use nullptr instead of NULL, etc.

@SunilKuravinakop
Copy link
Contributor Author

I have addressed all the previous comments. Can you please take a look?

Copy link
Member

@alexey-bataev alexey-bataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use force-update, must use stacked PRs

@shiltian shiltian changed the title Support for dispatch construct (Sema & Codegen) support. [Clang][OpenMP] Support for dispatch construct (Sema & Codegen) support Dec 10, 2024
@SunilKuravinakop
Copy link
Contributor Author

SunilKuravinakop commented Dec 12, 2024

I ran ./install/bin/clang-tidy -p=build clang/lib/Sema/SemaOpenMP.cpp -checks=llvm-* and did not find any more changes being suggested for the code changes which I have done. Is there any other tool other than clang-tidy which I can run before uploading? I have been using git-clang-format for quite sometime now.

Copy link
Contributor

@shiltian shiltian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update the openmp doc and clang release note as well?

Copy link
Contributor

@alexey-bataev What do you think? I'm not very familiar with some front end concepts used in this PR.

@SunilKuravinakop
Copy link
Contributor Author

update the openmp doc and clang release note as well?

Done.

Comment on lines 6070 to 6075
// if (cond_true) {
// foo(i,j) // with traits: CodeGen call to foo_variant_allCond(i,j)
// } else {
// #pragma omp dispatch
// foo(i,j) // with traits: CodeGen call to foo_variant_dispatch(i,j)
// }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you cannot do this directly in codegen?

Copy link
Contributor Author

@SunilKuravinakop SunilKuravinakop Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am sorry for the detailed explanation below:

For a code like

#pragma omp dispatch nocontext(cond) depend(in:i)
foo(i,j)

foo(i,j) is present in PseudoObjectExpr, containing call to foo(i,j) and foo_variant_dispatch(i,j).
The word traits used below is in relation to PseudoObjectExpr.

The Sema transformation can be seen as:

#pragma omp taskwait depend(in:i)
foo(i,j) --> with OMPDispatchNoContextAttr

In the transformation in Sema, instead of adding OMPDispatchNoContextAttr we change it to:

if (cond) {
  foo(i,j) --> with traits containing call to foo_variant_allCond(i,j)
} else {
  #pragma omp dispatch
  foo(i,j) --> without any changes i.e. with traits containing 
                   call to foo_variant_dispatch(i,j)
}

in the then part, to change the function call from

    foo(i,j) --> with traits containing call to foo_variant_dispatch(i,j)
to foo(i,j) --> with traits containing call to foo_variant_allCond(i,j)
  1. With, foo(i,j) (i.e. direct call without traits)
  2. call ActOnOpenMPCall(). The ActOnOpenMPCall() will add the traits for call to foo_variant_allCond(i,j).
  3. In the CodeGen, the call to foo_variant_allCond(i,j) is emit-ed without any changes.

To do the same steps in CodeGen:

  1. For every function call I need to check if it has OMPDispatchNoContextAttr.
  2. Create a then part,
    a) With foo(i,j) (i.e. direct call) Either call ActOnOpenMPCall() to get
    call to foo_variant_dispatch(i,j) or replicate the steps from 7439 to 7486.
    b) Create if-else statement as mentioned above for Sema, without #pragma omp dispatch in else-part.
    c) Emit the if-else statement.

Doing it in Sema is easier because functions, like ActOnOpenMPCall(), generating of PseudoObjectExpr are present in Sema itself.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sema should generate only required statements/expressions. If something can be emitted in codegen, better to emit in codegen, to avoid extra stuff in AST

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With your suggestion, Source code statement of the form:

#pragma omp dispatch depend(out: x) nocontext(c2)
foo(); --> with traits call to foo_variant_dispatch(i,j)

Will result in
Changes in Sema:

#pragma omp taskwait depend(out:x)
#pragma omp dispatch nocontext(c2)
foo(); --> with traits call to foo_variant_dispatch(i,j)

Changes in CodeGen method, CodeGenFunction::EmitOMPDispatchDirective:

  1. Create an IfElseStmt containing:
      if (c2) {
        foo(); --> with traits call to foo_variant_allCond(i,j)
      } else {
        foo(); --> with traits call to foo_variant_dispatch(i,j)
      }
  1. EmitStmt(IfElseStmt);

Is there a better way of doing it in CodeGen?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume all these helpers can be directly codegened instead of building them in sema

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some extra additions. Sema should not produce extra OpenMP constructs. You can generate required CaturedStmt (one for outer taskwait and one for inner dispatch) and then use them to generate in codegen. Sema should be as close to the source code as possible.

Copy link
Contributor Author

@SunilKuravinakop SunilKuravinakop Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume all these helpers can be directly codegened instead of building them in sema

I have moved the helpers to codegen. I have done this by adding AnnotateAttr in Sema and made use of AnnotateAttr in the helpers in Codegen. Can you please review and give me the feedback?

There are some extra additions. Sema should not produce extra OpenMP constructs. You can generate required CaturedStmt (one for outer taskwait and one for inner dispatch) and then use them to generate in codegen. Sema should be as close to the source code as possible.

Do you want me to revert back the helpers from CodeGen back to Sema? Do you want me to generate the CapturedStmts in Sema itself? The feedback that you have provided in the above 2 comments are confusing to me. Can you please review the latest push to this PR and give the feedback?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if dispatch transformed to

#pragma omp taskwait depend(out:x)
#pragma omp dispatch nocontext(c2)
foo(); --> with traits call to foo_variant_dispatch(i,j)

you should have 2 captured regions - one for taskwait and one for dispatch. It does not mean, you need to transform the AST node, there should be single AST node with 2 captured regions (one caotures another), just like it is done for combined constructs.

@jtb20
Copy link
Contributor

jtb20 commented Feb 12, 2025

Hi!

FYI -- I have a partial implementation of semantic & codegen support for "dispatch" directives also. I've created a draft PR for it here: #126914

Obviously there's no point in duplicating effort on this feature, but perhaps bits from my implementation might be helpful. I noticed a few OpenMP_VV tests fail with this patch that succeed with mine -- I'll update with further details shortly.

HTH!

@jtb20
Copy link
Contributor

jtb20 commented Feb 12, 2025

The failures with this patch and not mine are just:

PASS->FAIL: tests/5.1/dispatch/test_dispatch_is_device_ptr.c run
PASS->FAIL: tests/5.1/dispatch/test_dispatch_nowait.c run

that's actually probably neither here nor there, the clauses in question aren't implemented in my patch yet so the tests are likely succeeding "by accident".

@jtb20
Copy link
Contributor

jtb20 commented Feb 12, 2025

Here's a test using templates that segfaults with this patchset as-is:

https://gist.github.com/jtb20/dfa95abf2658da758797803d4b949f30

This resembles an OpenMP_VV test, and is using the ompvv.h header from the official tests, but is just a hacked-up thing I created myself, so YMMV.

IIRC that one isn't quite working with my patch either, yet. Again, just FYI.

@SunilKuravinakop
Copy link
Contributor Author

Hi!

FYI -- I have a partial implementation of semantic & codegen support for "dispatch" directives also. I've created a draft PR for it here: #126914

Obviously there's no point in duplicating effort on this feature, but perhaps bits from my implementation might be helpful. I noticed a few OpenMP_VV tests fail with this patch that succeed with mine -- I'll update with further details shortly.

HTH!

Hello Julian,
Thank you for the help. I am going through your implementation to see how I can use it into mine. Thanks a lot.
I am also waiting for Alexey to reply back to my changes.
--Sunil

@@ -4444,6 +4444,105 @@ void CodeGenFunction::EmitOMPMasterDirective(const OMPMasterDirective &S) {
emitMaster(*this, S);
}

static Expr *getInitialExprFromCapturedExpr(Expr *Cond) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code has changed. I do not have this any more in SemaOpenMP.cpp. The aim in SemaOpenMP.cpp is to add an AnnotateAttr. Now, I have written
void SemaOpenMP::annotateAStmt()
static Expr *getNewTraitsOrDirectCall()

Comment on lines 6070 to 6075
// if (cond_true) {
// foo(i,j) // with traits: CodeGen call to foo_variant_allCond(i,j)
// } else {
// #pragma omp dispatch
// foo(i,j) // with traits: CodeGen call to foo_variant_dispatch(i,j)
// }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if dispatch transformed to

#pragma omp taskwait depend(out:x)
#pragma omp dispatch nocontext(c2)
foo(); --> with traits call to foo_variant_dispatch(i,j)

you should have 2 captured regions - one for taskwait and one for dispatch. It does not mean, you need to transform the AST node, there should be single AST node with 2 captured regions (one caotures another), just like it is done for combined constructs.

@SunilKuravinakop
Copy link
Contributor Author

SunilKuravinakop commented Feb 12, 2025

#if dispatch transformed to

#pragma omp taskwait depend(out:x)
#pragma omp dispatch nocontext(c2)
foo(); --> with traits call to foo_variant_dispatch(i,j)

you should have 2 captured regions - one for taskwait and one for dispatch. It does not mean, >you need to transform the AST node, there should be single AST node with 2 captured regions >(one caotures another), just like it is done for combined constructs.

I do not have this code anymore in SemaOpenMP.cpp. I am having code for adding AnnotateAttr. In CodeGen this exact comment has been taken care of.

@alexey-bataev
Copy link
Member

#if dispatch transformed to
#pragma omp taskwait depend(out:x)
#pragma omp dispatch nocontext(c2)
foo(); --> with traits call to foo_variant_dispatch(i,j)
you should have 2 captured regions - one for taskwait and one for dispatch. It does not mean, >you need to transform the AST node, there should be single AST node with 2 captured regions >(one caotures another), just like it is done for combined constructs.

I do not have this code anymore in SemaOpenMP.cpp. I am having code for adding AnnotateAttr. In CodeGen this exact comment has been taken care of.

I'm saying how it should be implemented. Otherwise, it may crash in some cases or work incorrectly.

@SunilKuravinakop
Copy link
Contributor Author

#if dispatch transformed to
#pragma omp taskwait depend(out:x)
#pragma omp dispatch nocontext(c2)
foo(); --> with traits call to foo_variant_dispatch(i,j)
you should have 2 captured regions - one for taskwait and one for dispatch. It does not mean, >you need to transform the AST node, there should be single AST node with 2 captured regions >(one caotures another), just like it is done for combined constructs.

I do not have this code anymore in SemaOpenMP.cpp. I am having code for adding AnnotateAttr. In CodeGen this exact comment has been taken care of.

I'm saying how it should be implemented. Otherwise, it may crash in some cases or work incorrectly.

Do you want me to move the changes from CodeGen back into SemaOpenMP.cpp and avoid AnnotateAttr? With changes that I have now, I was looking at:

  1. With current changes there was a crash occurring with constant values e.g. nocontext(true) This was solved easily.
  2. The changes for handling template was difficult. I could not get the variant method in LLVM IR (-emit-llvm). Probably I am missing something in VisitCXXMethodDecl().

@SunilKuravinakop
Copy link
Contributor Author

This PR was closed probably because I did a push my changes into another wrong branch of my own. A new PR has been created in #131838.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:openmp OpenMP related changes to Clang clang Clang issues not falling into any other category flang:openmp
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants