Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] Parse !$CUF KERNEL DO <<< (*) #85338

Merged
merged 1 commit into from
Mar 15, 2024
Merged

[flang] Parse !$CUF KERNEL DO <<< (*) #85338

merged 1 commit into from
Mar 15, 2024

Conversation

klausler
Copy link
Contributor

Accept and represent asterisks within the parenthesized grid and block specification lists.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:semantics flang:parser labels Mar 14, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 14, 2024

@llvm/pr-subscribers-flang-parser
@llvm/pr-subscribers-flang-semantics

@llvm/pr-subscribers-flang-fir-hlfir

Author: Peter Klausler (klausler)

Changes

Accept and represent asterisks within the parenthesized grid and block specification lists.


Full diff: https://github.com/llvm/llvm-project/pull/85338.diff

9 Files Affected:

  • (modified) flang/include/flang/Parser/dump-parse-tree.h (+1)
  • (modified) flang/include/flang/Parser/parse-tree.h (+6-4)
  • (modified) flang/lib/Lower/Bridge.cpp (+21-8)
  • (modified) flang/lib/Parser/executable-parsers.cpp (+10-10)
  • (modified) flang/lib/Parser/misc-parsers.h (+5)
  • (modified) flang/lib/Parser/unparse.cpp (+7)
  • (modified) flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf (+1-4)
  • (modified) flang/test/Parser/cuf-sanity-tree.CUF (+3-3)
  • (modified) flang/test/Semantics/cuf09.cuf (+9)
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 048008a8d80c79..b2c3d92909375c 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -233,6 +233,7 @@ class ParseTreeDumper {
   NODE(parser, CriticalStmt)
   NODE(parser, CUDAAttributesStmt)
   NODE(parser, CUFKernelDoConstruct)
+  NODE(CUFKernelDoConstruct, StarOrExpr)
   NODE(CUFKernelDoConstruct, Directive)
   NODE(parser, CycleStmt)
   NODE(parser, DataComponentDefStmt)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index f7b72c3af09164..c96abfba491d4b 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4297,16 +4297,18 @@ struct OpenACCConstruct {
 // CUF-kernel-do-construct ->
 //     !$CUF KERNEL DO [ (scalar-int-constant-expr) ] <<< grid, block [, stream]
 //     >>> do-construct
-// grid -> * | scalar-int-expr | ( scalar-int-expr-list )
-// block -> * | scalar-int-expr | ( scalar-int-expr-list )
+// star-or-expr -> * | scalar-int-expr
+// grid -> * | scalar-int-expr | ( star-or-expr-list )
+// block -> * | scalar-int-expr | ( star-or-expr-list )
 // stream -> 0, scalar-int-expr | STREAM = scalar-int-expr
 struct CUFKernelDoConstruct {
   TUPLE_CLASS_BOILERPLATE(CUFKernelDoConstruct);
+  WRAPPER_CLASS(StarOrExpr, std::optional<ScalarIntExpr>);
   struct Directive {
     TUPLE_CLASS_BOILERPLATE(Directive);
     CharBlock source;
-    std::tuple<std::optional<ScalarIntConstantExpr>, std::list<ScalarIntExpr>,
-        std::list<ScalarIntExpr>, std::optional<ScalarIntExpr>>
+    std::tuple<std::optional<ScalarIntConstantExpr>, std::list<StarOrExpr>,
+        std::list<StarOrExpr>, std::optional<ScalarIntExpr>>
         t;
   };
   std::tuple<Directive, std::optional<DoConstruct>> t;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index a668ba4116faab..e6511e0c61c8c3 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2508,19 +2508,32 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     if (nestedLoops > 1)
       n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);
 
-    const std::list<Fortran::parser::ScalarIntExpr> &grid = std::get<1>(dir.t);
-    const std::list<Fortran::parser::ScalarIntExpr> &block = std::get<2>(dir.t);
+    const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
+        std::get<1>(dir.t);
+    const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &block =
+        std::get<2>(dir.t);
     const std::optional<Fortran::parser::ScalarIntExpr> &stream =
         std::get<3>(dir.t);
 
     llvm::SmallVector<mlir::Value> gridValues;
-    for (const Fortran::parser::ScalarIntExpr &expr : grid)
-      gridValues.push_back(fir::getBase(
-          genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+    for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr : grid) {
+      if (expr.v) {
+        gridValues.push_back(fir::getBase(
+            genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
+      } else {
+        // TODO: '*'
+      }
+    }
     llvm::SmallVector<mlir::Value> blockValues;
-    for (const Fortran::parser::ScalarIntExpr &expr : block)
-      blockValues.push_back(fir::getBase(
-          genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+    for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
+         block) {
+      if (expr.v) {
+        blockValues.push_back(fir::getBase(
+            genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
+      } else {
+        // TODO: '*'
+      }
+    }
     mlir::Value streamValue;
     if (stream)
       streamValue = fir::getBase(
diff --git a/flang/lib/Parser/executable-parsers.cpp b/flang/lib/Parser/executable-parsers.cpp
index de2be017508c37..6f258f131c3c22 100644
--- a/flang/lib/Parser/executable-parsers.cpp
+++ b/flang/lib/Parser/executable-parsers.cpp
@@ -542,19 +542,19 @@ TYPE_CONTEXT_PARSER("UNLOCK statement"_en_US,
 // CUF-kernel-do-directive ->
 //     !$CUF KERNEL DO [ (scalar-int-constant-expr) ] <<< grid, block [, stream]
 //     >>> do-construct
-// grid -> * | scalar-int-expr | ( scalar-int-expr-list )
-// block -> * | scalar-int-expr | ( scalar-int-expr-list )
+// star-or-expr -> * | scalar-int-expr
+// grid -> * | scalar-int-expr | ( star-or-expr-list )
+// block -> * | scalar-int-expr | ( star-or-expr-list )
 // stream -> ( 0, | STREAM = ) scalar-int-expr
+const auto starOrExpr{construct<CUFKernelDoConstruct::StarOrExpr>(
+    "*" >> pure<std::optional<ScalarIntExpr>>() ||
+    applyFunction(presentOptional<ScalarIntExpr>, scalarIntExpr))};
+constexpr auto gridOrBlock{parenthesized(nonemptyList(starOrExpr)) ||
+    applyFunction(singletonList<CUFKernelDoConstruct::StarOrExpr>, starOrExpr)};
 TYPE_PARSER(sourced(beginDirective >> "$CUF KERNEL DO"_tok >>
     construct<CUFKernelDoConstruct::Directive>(
-        maybe(parenthesized(scalarIntConstantExpr)),
-        "<<<" >>
-            ("*" >> pure<std::list<ScalarIntExpr>>() ||
-                parenthesized(nonemptyList(scalarIntExpr)) ||
-                applyFunction(singletonList<ScalarIntExpr>, scalarIntExpr)),
-        "," >> ("*" >> pure<std::list<ScalarIntExpr>>() ||
-                   parenthesized(nonemptyList(scalarIntExpr)) ||
-                   applyFunction(singletonList<ScalarIntExpr>, scalarIntExpr)),
+        maybe(parenthesized(scalarIntConstantExpr)), "<<<" >> gridOrBlock,
+        "," >> gridOrBlock,
         maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>" /
             endDirective)))
 TYPE_CONTEXT_PARSER("!$CUF KERNEL DO construct"_en_US,
diff --git a/flang/lib/Parser/misc-parsers.h b/flang/lib/Parser/misc-parsers.h
index e9b52b7d0fcd0f..4a318e05bb4b8c 100644
--- a/flang/lib/Parser/misc-parsers.h
+++ b/flang/lib/Parser/misc-parsers.h
@@ -57,5 +57,10 @@ template <typename A> common::IfNoLvalue<std::list<A>, A> singletonList(A &&x) {
   result.emplace_back(std::move(x));
   return result;
 }
+
+template <typename A>
+common::IfNoLvalue<std::optional<A>, A> presentOptional(A &&x) {
+  return std::make_optional(std::move(x));
+}
 } // namespace Fortran::parser
 #endif
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 600aa01999dab7..baba4863f5775f 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2729,6 +2729,13 @@ class UnparseVisitor {
   WALK_NESTED_ENUM(OmpOrderModifier, Kind) // OMP order-modifier
 #undef WALK_NESTED_ENUM
 
+  void Unparse(const CUFKernelDoConstruct::StarOrExpr &x) {
+    if (x.v) {
+      Walk(*x.v);
+    } else {
+      Word("*");
+    }
+  }
   void Unparse(const CUFKernelDoConstruct::Directive &x) {
     Word("!$CUF KERNEL DO");
     Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
diff --git a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
index db628fe756b952..c017561447f85d 100644
--- a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
+++ b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
@@ -42,10 +42,7 @@ subroutine sub1()
 ! CHECK: fir.cuda_kernel<<<%c1{{.*}}, (%c256{{.*}}, %c1{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index)
 ! CHECK: {n = 2 : i64}
 
-! TODO: currently these trigger error in the parser
+! TODO: lowering for these cases
 ! !$cuf kernel do(2) <<< (1,*), (256,1) >>>
 ! !$cuf kernel do(2) <<< (*,*), (32,4) >>>
 end
-
-
-
diff --git a/flang/test/Parser/cuf-sanity-tree.CUF b/flang/test/Parser/cuf-sanity-tree.CUF
index f6cf9bbdd6b0cc..dc12759d3ce52f 100644
--- a/flang/test/Parser/cuf-sanity-tree.CUF
+++ b/flang/test/Parser/cuf-sanity-tree.CUF
@@ -144,11 +144,11 @@ include "cuf-sanity-common"
 !CHECK: | | | | | | EndDoStmt -> 
 !CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> CUFKernelDoConstruct
 !CHECK: | | | | | Directive
-!CHECK: | | | | | | Scalar -> Integer -> Expr = '1_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '1_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
-!CHECK: | | | | | | Scalar -> Integer -> Expr = '2_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '2_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
-!CHECK: | | | | | | Scalar -> Integer -> Expr = '3_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '3_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '3'
 !CHECK: | | | | | | Scalar -> Integer -> Expr = '1_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
diff --git a/flang/test/Semantics/cuf09.cuf b/flang/test/Semantics/cuf09.cuf
index dd70c3b1ff5efd..4bc93132044fdd 100644
--- a/flang/test/Semantics/cuf09.cuf
+++ b/flang/test/Semantics/cuf09.cuf
@@ -10,6 +10,15 @@ module m
 end
 
 program main
+  !$cuf kernel do <<< *, * >>> ! ok
+  do j = 1, 0
+  end do
+  !$cuf kernel do <<< (*), (*) >>> ! ok
+  do j = 1, 0
+  end do
+  !$cuf kernel do <<< (1,*), (2,*) >>> ! ok
+  do j = 1, 0
+  end do
   !ERROR: !$CUF KERNEL DO (1) must be followed by a DO construct with tightly nested outer levels of counted DO loops
   !$cuf kernel do <<< 1, 2 >>>
   do while (.false.)

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update on the parser. I'll take care of the update in the lowering part.

Accept and represent asterisks within the parenthesized grid
and block specification lists.
@klausler klausler merged commit 60fa2b0 into llvm:main Mar 15, 2024
3 of 4 checks passed
@klausler klausler deleted the bug1538 branch March 15, 2024 20:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:parser flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants