diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h index 048008a8d80c7..b2c3d92909375 100644 --- a/flang/include/flang/Parser/dump-parse-tree.h +++ b/flang/include/flang/Parser/dump-parse-tree.h @@ -233,6 +233,7 @@ class ParseTreeDumper { NODE(parser, CriticalStmt) NODE(parser, CUDAAttributesStmt) NODE(parser, CUFKernelDoConstruct) + NODE(CUFKernelDoConstruct, StarOrExpr) NODE(CUFKernelDoConstruct, Directive) NODE(parser, CycleStmt) NODE(parser, DataComponentDefStmt) diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h index f7b72c3af0916..c96abfba491d4 100644 --- a/flang/include/flang/Parser/parse-tree.h +++ b/flang/include/flang/Parser/parse-tree.h @@ -4297,16 +4297,18 @@ struct OpenACCConstruct { // CUF-kernel-do-construct -> // !$CUF KERNEL DO [ (scalar-int-constant-expr) ] <<< grid, block [, stream] // >>> do-construct -// grid -> * | scalar-int-expr | ( scalar-int-expr-list ) -// block -> * | scalar-int-expr | ( scalar-int-expr-list ) +// star-or-expr -> * | scalar-int-expr +// grid -> * | scalar-int-expr | ( star-or-expr-list ) +// block -> * | scalar-int-expr | ( star-or-expr-list ) // stream -> 0, scalar-int-expr | STREAM = scalar-int-expr struct CUFKernelDoConstruct { TUPLE_CLASS_BOILERPLATE(CUFKernelDoConstruct); + WRAPPER_CLASS(StarOrExpr, std::optional); struct Directive { TUPLE_CLASS_BOILERPLATE(Directive); CharBlock source; - std::tuple, std::list, - std::list, std::optional> + std::tuple, std::list, + std::list, std::optional> t; }; std::tuple> t; diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index a668ba4116faa..e6511e0c61c8c 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -2508,19 +2508,32 @@ class FirConverter : public Fortran::lower::AbstractConverter { if (nestedLoops > 1) n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops); - const std::list &grid = std::get<1>(dir.t); - const std::list &block = std::get<2>(dir.t); + const std::list &grid = + std::get<1>(dir.t); + const std::list &block = + std::get<2>(dir.t); const std::optional &stream = std::get<3>(dir.t); llvm::SmallVector gridValues; - for (const Fortran::parser::ScalarIntExpr &expr : grid) - gridValues.push_back(fir::getBase( - genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx))); + for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr : grid) { + if (expr.v) { + gridValues.push_back(fir::getBase( + genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx))); + } else { + // TODO: '*' + } + } llvm::SmallVector blockValues; - for (const Fortran::parser::ScalarIntExpr &expr : block) - blockValues.push_back(fir::getBase( - genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx))); + for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr : + block) { + if (expr.v) { + blockValues.push_back(fir::getBase( + genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx))); + } else { + // TODO: '*' + } + } mlir::Value streamValue; if (stream) streamValue = fir::getBase( diff --git a/flang/lib/Parser/executable-parsers.cpp b/flang/lib/Parser/executable-parsers.cpp index de2be017508c3..07a570bd61e99 100644 --- a/flang/lib/Parser/executable-parsers.cpp +++ b/flang/lib/Parser/executable-parsers.cpp @@ -542,19 +542,19 @@ TYPE_CONTEXT_PARSER("UNLOCK statement"_en_US, // CUF-kernel-do-directive -> // !$CUF KERNEL DO [ (scalar-int-constant-expr) ] <<< grid, block [, stream] // >>> do-construct -// grid -> * | scalar-int-expr | ( scalar-int-expr-list ) -// block -> * | scalar-int-expr | ( scalar-int-expr-list ) +// star-or-expr -> * | scalar-int-expr +// grid -> * | scalar-int-expr | ( star-or-expr-list ) +// block -> * | scalar-int-expr | ( star-or-expr-list ) // stream -> ( 0, | STREAM = ) scalar-int-expr +constexpr auto starOrExpr{construct( + "*" >> pure>() || + applyFunction(presentOptional, scalarIntExpr))}; +constexpr auto gridOrBlock{parenthesized(nonemptyList(starOrExpr)) || + applyFunction(singletonList, starOrExpr)}; TYPE_PARSER(sourced(beginDirective >> "$CUF KERNEL DO"_tok >> construct( - maybe(parenthesized(scalarIntConstantExpr)), - "<<<" >> - ("*" >> pure>() || - parenthesized(nonemptyList(scalarIntExpr)) || - applyFunction(singletonList, scalarIntExpr)), - "," >> ("*" >> pure>() || - parenthesized(nonemptyList(scalarIntExpr)) || - applyFunction(singletonList, scalarIntExpr)), + maybe(parenthesized(scalarIntConstantExpr)), "<<<" >> gridOrBlock, + "," >> gridOrBlock, maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>" / endDirective))) TYPE_CONTEXT_PARSER("!$CUF KERNEL DO construct"_en_US, diff --git a/flang/lib/Parser/misc-parsers.h b/flang/lib/Parser/misc-parsers.h index e9b52b7d0fcd0..4a318e05bb4b8 100644 --- a/flang/lib/Parser/misc-parsers.h +++ b/flang/lib/Parser/misc-parsers.h @@ -57,5 +57,10 @@ template common::IfNoLvalue, A> singletonList(A &&x) { result.emplace_back(std::move(x)); return result; } + +template +common::IfNoLvalue, A> presentOptional(A &&x) { + return std::make_optional(std::move(x)); +} } // namespace Fortran::parser #endif diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index 600aa01999dab..baba4863f5775 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2729,6 +2729,13 @@ class UnparseVisitor { WALK_NESTED_ENUM(OmpOrderModifier, Kind) // OMP order-modifier #undef WALK_NESTED_ENUM + void Unparse(const CUFKernelDoConstruct::StarOrExpr &x) { + if (x.v) { + Walk(*x.v); + } else { + Word("*"); + } + } void Unparse(const CUFKernelDoConstruct::Directive &x) { Word("!$CUF KERNEL DO"); Walk(" (", std::get>(x.t), ")"); diff --git a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf index db628fe756b95..c017561447f85 100644 --- a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf +++ b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf @@ -42,10 +42,7 @@ subroutine sub1() ! CHECK: fir.cuda_kernel<<<%c1{{.*}}, (%c256{{.*}}, %c1{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index) ! CHECK: {n = 2 : i64} -! TODO: currently these trigger error in the parser +! TODO: lowering for these cases ! !$cuf kernel do(2) <<< (1,*), (256,1) >>> ! !$cuf kernel do(2) <<< (*,*), (32,4) >>> end - - - diff --git a/flang/test/Parser/cuf-sanity-tree.CUF b/flang/test/Parser/cuf-sanity-tree.CUF index f6cf9bbdd6b0c..dc12759d3ce52 100644 --- a/flang/test/Parser/cuf-sanity-tree.CUF +++ b/flang/test/Parser/cuf-sanity-tree.CUF @@ -144,11 +144,11 @@ include "cuf-sanity-common" !CHECK: | | | | | | EndDoStmt -> !CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> CUFKernelDoConstruct !CHECK: | | | | | Directive -!CHECK: | | | | | | Scalar -> Integer -> Expr = '1_4' +!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '1_4' !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1' -!CHECK: | | | | | | Scalar -> Integer -> Expr = '2_4' +!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '2_4' !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2' -!CHECK: | | | | | | Scalar -> Integer -> Expr = '3_4' +!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '3_4' !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '3' !CHECK: | | | | | | Scalar -> Integer -> Expr = '1_4' !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1' diff --git a/flang/test/Semantics/cuf09.cuf b/flang/test/Semantics/cuf09.cuf index dd70c3b1ff5ef..4bc93132044fd 100644 --- a/flang/test/Semantics/cuf09.cuf +++ b/flang/test/Semantics/cuf09.cuf @@ -10,6 +10,15 @@ module m end program main + !$cuf kernel do <<< *, * >>> ! ok + do j = 1, 0 + end do + !$cuf kernel do <<< (*), (*) >>> ! ok + do j = 1, 0 + end do + !$cuf kernel do <<< (1,*), (2,*) >>> ! ok + do j = 1, 0 + end do !ERROR: !$CUF KERNEL DO (1) must be followed by a DO construct with tightly nested outer levels of counted DO loops !$cuf kernel do <<< 1, 2 >>> do while (.false.)