Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] CUF kernel loop directive #82836

Merged
merged 2 commits into from
Feb 27, 2024
Merged

Conversation

clementval
Copy link
Contributor

This patch introduces a new operation to represent the CUDA Fortran kernel loop directive. This operation is modeled as a LoopLikeOp operation in a similar way to acc.loop.

The CUFKernelDoConstruct parse tree node is also placed correctly in the PFTBuilder to be available in PFT evaluations.

Lowering from the flang parse-tree to MLIR is also done.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Feb 23, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 23, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This patch introduces a new operation to represent the CUDA Fortran kernel loop directive. This operation is modeled as a LoopLikeOp operation in a similar way to acc.loop.

The CUFKernelDoConstruct parse tree node is also placed correctly in the PFTBuilder to be available in PFT evaluations.

Lowering from the flang parse-tree to MLIR is also done.


Full diff: https://github.com/llvm/llvm-project/pull/82836.diff

5 Files Affected:

  • (modified) flang/include/flang/Lower/PFTBuilder.h (+3-2)
  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+27)
  • (modified) flang/lib/Lower/Bridge.cpp (+121)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+98)
  • (added) flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf (+51)
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index c2b0fdbf357cde..9913f584133faa 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -138,7 +138,8 @@ using Directives =
     std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
                parser::OpenACCRoutineConstruct,
                parser::OpenACCDeclarativeConstruct, parser::OpenMPConstruct,
-               parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective>;
+               parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective,
+               parser::CUFKernelDoConstruct>;
 
 using DeclConstructs = std::tuple<parser::OpenMPDeclarativeConstruct,
                                   parser::OpenACCDeclarativeConstruct>;
@@ -178,7 +179,7 @@ static constexpr bool isNopConstructStmt{common::HasMember<
 template <typename A>
 static constexpr bool isExecutableDirective{common::HasMember<
     A, std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
-                  parser::OpenMPConstruct>>};
+                  parser::OpenMPConstruct, parser::CUFKernelDoConstruct>>};
 
 template <typename A>
 static constexpr bool isFunctionLike{common::HasMember<
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 08239230f793f1..db5e5f4bc682e6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3127,4 +3127,31 @@ def fir_BoxOffsetOp : fir_Op<"box_offset", [NoMemoryEffect]> {
   ];
 }
 
+def fir_CUDAKernelOp : fir_Op<"cuda_kernel", [AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
+
+  let arguments = (ins
+    Variadic<I32>:$grid, // empty means `*`
+    Variadic<I32>:$block, // empty means `*`
+    Optional<I32>:$stream,
+    Variadic<Index>:$lowerbound,
+    Variadic<Index>:$upperbound,
+    Variadic<Index>:$step,
+    OptionalAttr<I64Attr>:$n
+  );
+
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    `<` `<` `<` custom<CUFKernelValues>($grid, type($grid)) `,` 
+                custom<CUFKernelValues>($block, type($block))
+        ( `,` `stream` `=` $stream^ )? `>` `>` `>`
+        custom<CUFKernelLoopControl>($region, $lowerbound, type($lowerbound),
+            $upperbound, type($upperbound), $step, type($step))
+        attr-dict
+  }];
+
+  let hasVerifier = 1;
+}
+
 #endif
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 2d7f748cefa2d8..2c4825fafdbee4 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2453,6 +2453,127 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     // Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &)
   }
 
+  void genFIR(const Fortran::parser::CUFKernelDoConstruct &kernel) {
+    localSymbols.pushScope();
+    const Fortran::parser::CUFKernelDoConstruct::Directive &dir =
+        std::get<Fortran::parser::CUFKernelDoConstruct::Directive>(kernel.t);
+
+    mlir::Location loc = genLocation(dir.source);
+
+    Fortran::lower::StatementContext stmtCtx;
+
+    unsigned nestedLoops = 1;
+    
+    const auto &nLoops =
+        std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(dir.t);
+    if (nLoops)
+      nestedLoops = *Fortran::semantics::GetIntValue(*nLoops);
+
+    mlir::IntegerAttr n;
+    if (nestedLoops > 1)
+      n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);
+
+    const std::list<Fortran::parser::ScalarIntExpr> &grid = std::get<1>(dir.t);
+    const std::list<Fortran::parser::ScalarIntExpr> &block = std::get<2>(dir.t);
+    const std::optional<Fortran::parser::ScalarIntExpr> &stream = std::get<3>(dir.t);
+
+    llvm::SmallVector<mlir::Value> gridValues;
+    for (const Fortran::parser::ScalarIntExpr &expr : grid)
+      gridValues.push_back(fir::getBase(genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+    llvm::SmallVector<mlir::Value> blockValues;
+    for (const Fortran::parser::ScalarIntExpr &expr : block)
+      blockValues.push_back(fir::getBase(genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+    mlir::Value streamValue;
+    if (stream)
+      streamValue = fir::getBase(genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx));
+
+    const auto &outerDoConstruct =
+        std::get<std::optional<Fortran::parser::DoConstruct>>(kernel.t);
+
+    llvm::SmallVector<mlir::Location> locs;
+    locs.push_back(loc);
+    llvm::SmallVector<mlir::Value> lbs, ubs, steps;
+
+    mlir::Type idxTy = builder->getIndexType();
+
+    llvm::SmallVector<mlir::Type> ivTypes;
+    llvm::SmallVector<mlir::Location> ivLocs;
+    llvm::SmallVector<mlir::Value> ivValues;
+    for (unsigned i = 0; i < nestedLoops; ++i) {
+      const Fortran::parser::LoopControl *loopControl;
+      Fortran::lower::pft::Evaluation *loopEval = &getEval().getFirstNestedEvaluation();
+
+      mlir::Location crtLoc = loc;
+      if (i == 0) {
+        loopControl = &*outerDoConstruct->GetLoopControl();
+        crtLoc = genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
+      } else {
+        auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
+        assert(doCons && "expect do construct");
+        loopControl = &*doCons->GetLoopControl();
+        crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
+      }
+
+      locs.push_back(crtLoc);
+
+      const Fortran::parser::LoopControl::Bounds *bounds =
+          std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+      assert(bounds && "Expected bounds on the loop construct");
+
+      Fortran::semantics::Symbol &ivSym =
+          bounds->name.thing.symbol->GetUltimate();
+      ivValues.push_back(getSymbolAddress(ivSym));
+
+      lbs.push_back(builder->createConvert(crtLoc, idxTy,
+          fir::getBase(genExprValue(
+          *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))));
+      ubs.push_back(builder->createConvert(crtLoc, idxTy,
+          fir::getBase(genExprValue(
+          *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))));
+      if (bounds->step)
+        steps.push_back(fir::getBase(genExprValue(
+            *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
+      else // If `step` is not present, assume it is `1`.
+        steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
+      
+      ivTypes.push_back(idxTy);
+      ivLocs.push_back(crtLoc);
+      if (i < nestedLoops - 1)
+        loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
+    }
+
+    auto op = builder->create<fir::CUDAKernelOp>(
+        loc, gridValues, blockValues, streamValue, lbs, ubs, steps, n);
+    builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes, ivLocs);
+    mlir::Block &b = op.getRegion().back();
+    builder->setInsertionPointToStart(&b);
+
+    for (auto [arg, value] : llvm::zip(
+           op.getLoopRegions().front()->front().getArguments(), ivValues)) {
+      mlir::Value convArg = builder->createConvert(loc, fir::unwrapRefType(value.getType()), arg);
+      builder->create<fir::StoreOp>(loc, convArg, value);
+    }
+
+    builder->create<fir::FirEndOp>(loc);
+    builder->setInsertionPointToStart(&b);
+
+    Fortran::lower::pft::Evaluation *crtEval = &getEval();
+    if (crtEval->lowerAsStructured()) {
+      crtEval = &crtEval->getFirstNestedEvaluation();
+      for (int64_t i = 1; i < nestedLoops; i++)
+        crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+    }
+
+
+
+    // Generate loop body
+    for (Fortran::lower::pft::Evaluation &e : crtEval->getNestedEvaluations())
+      genFIR(e);
+
+    builder->setInsertionPointAfter(op);
+    localSymbols.popScope();
+  }
+
   void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     genOpenMPConstruct(*this, localSymbols, bridge.getSemanticsContext(),
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 0a534cdb3c4871..c2facb5a004a8f 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3866,6 +3866,104 @@ mlir::LogicalResult fir::DeclareOp::verify() {
   return fortranVar.verifyDeclareLikeOpImpl(getMemref());
 }
 
+llvm::SmallVector<mlir::Region *> fir::CUDAKernelOp::getLoopRegions() {
+  return {&getRegion()};
+}
+
+mlir::ParseResult
+parseCUFKernelValues(mlir::OpAsmParser &parser,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
+                     llvm::SmallVectorImpl<mlir::Type> &types) {
+  if (mlir::succeeded(parser.parseOptionalStar()))
+    return mlir::success();
+  
+  if (parser.parseOptionalLParen()) {
+    if (mlir::failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::None, [&]() {
+              if (parser.parseOperand(values.emplace_back()))
+                return mlir::failure();
+              return mlir::success();
+            })))
+      return mlir::failure();
+    if (parser.parseRParen())
+      return mlir::failure();
+  } else {
+    if (parser.parseOperand(values.emplace_back()))
+      return mlir::failure();
+    return mlir::success();
+  }
+  return mlir::success();
+}
+
+void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                          mlir::ValueRange values, mlir::TypeRange types) {
+  if (values.empty())
+    p << "*";
+
+  if (values.size() > 1)
+    p << "(";
+  llvm::interleaveComma(values, p,
+                        [&p](mlir::Value v) { p << v; });
+  if (values.size() > 1)
+    p << ")";
+}
+
+mlir::ParseResult
+parseCUFKernelLoopControl(mlir::OpAsmParser &parser, mlir::Region &region,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound,
+    llvm::SmallVectorImpl<mlir::Type> &lowerboundType,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound,
+    llvm::SmallVectorImpl<mlir::Type> &upperboundType,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &step,
+    llvm::SmallVectorImpl<mlir::Type> &stepType) {
+
+  llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars;
+  if (parser.parseLParen() ||
+      parser.parseArgumentList(inductionVars, mlir::OpAsmParser::Delimiter::None,
+                                /*allowType=*/true) ||
+      parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
+      parser.parseOperandList(lowerbound, inductionVars.size(),
+                              mlir::OpAsmParser::Delimiter::None) ||
+      parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
+      parser.parseKeyword("to") || parser.parseLParen() ||
+      parser.parseOperandList(upperbound, inductionVars.size(),
+                              mlir::OpAsmParser::Delimiter::None) ||
+      parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
+      parser.parseKeyword("step") || parser.parseLParen() ||
+      parser.parseOperandList(step, inductionVars.size(),
+                              mlir::OpAsmParser::Delimiter::None) ||
+      parser.parseColonTypeList(stepType) || parser.parseRParen())
+    return mlir::failure();
+  return parser.parseRegion(region, inductionVars);
+}
+
+void printCUFKernelLoopControl(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                      mlir::Region &region, mlir::ValueRange lowerbound,
+                      mlir::TypeRange lowerboundType,
+                      mlir::ValueRange upperbound,
+                      mlir::TypeRange upperboundType, mlir::ValueRange steps,
+                      mlir::TypeRange stepType) {
+  mlir::ValueRange regionArgs = region.front().getArguments();
+  if (!regionArgs.empty()) {
+    p << "(";
+    llvm::interleaveComma(regionArgs, p,
+                          [&p](mlir::Value v) { p << v << " : " << v.getType(); });
+    p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
+      << upperbound << " : " << upperboundType << ") "
+      << " step (" << steps << " : " << stepType << ") ";
+  }
+  p.printRegion(region, /*printEntryBlockArgs=*/false);
+}
+
+mlir::LogicalResult fir::CUDAKernelOp::verify() {
+  if (getLowerbound().size() != getUpperbound().size() ||
+      getLowerbound().size() != getStep().size())
+    return emitOpError(
+        "expect same number of values in lowerbound, upperbound and step");
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // FIROpsDialect
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
new file mode 100644
index 00000000000000..db628fe756b952
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
@@ -0,0 +1,51 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Test lowering of CUDA kernel loop directive.
+
+subroutine sub1()
+  integer :: i, j
+  integer, parameter :: n = 100
+  real :: a(n), b(n)
+  real :: c(n,n), d(n,n)
+
+! CHECK-LABEL: func.func @_QPsub1()
+! CHECK: %[[IV:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  !$cuf kernel do <<< 1, 2 >>>
+  do i = 1, n
+    a(i) = a(i) * b(i)
+  end do
+
+! CHECK: %[[LB:.*]] = fir.convert %c1{{.*}} : (i32) -> index
+! CHECK: %[[UB:.*]] = fir.convert %c100{{.*}} : (i32) -> index
+! CHECK: %[[STEP:.*]] = arith.constant 1 : index
+! CHECK: fir.cuda_kernel<<<%c1_i32, %c2_i32>>> (%[[ARG0:.*]] : index) = (%[[LB]] : index) to (%[[UB]] : index)  step (%[[STEP]] : index)
+! CHECK-NOT: fir.do_loop
+! CHECK: %[[ARG0_I32:.*]] = fir.convert %[[ARG0]] : (index) -> i32
+! CHECK: fir.store %[[ARG0_I32]] to %[[IV]]#1 : !fir.ref<i32>
+
+
+  !$cuf kernel do <<< *, * >>>
+  do i = 1, n
+    a(i) = a(i) * b(i)
+  end do
+
+! CHECK: fir.cuda_kernel<<<*, *>>> (%{{.*}} : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index)
+
+  !$cuf kernel do(2) <<< 1, (256,1) >>>
+  do i = 1, n
+    do j = 1, n
+      c(i,j) = c(i,j) * d(i,j)
+    end do
+  end do
+
+! CHECK: fir.cuda_kernel<<<%c1{{.*}}, (%c256{{.*}}, %c1{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index)
+! CHECK: {n = 2 : i64}
+
+! TODO: currently these trigger error in the parser
+! !$cuf kernel do(2) <<< (1,*), (256,1) >>>
+! !$cuf kernel do(2) <<< (*,*), (32,4) >>>
+end
+
+
+

Copy link

github-actions bot commented Feb 23, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

This patch introduces a new operation to represent the CUDA Fortran
kernel loop directive. This operation is modelled as a LoopLikeOp
operation in a similar way to acc.loop.

Lowering from the flang parse-tree to MLIR is also done.
Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@clementval clementval merged commit b3189b1 into llvm:main Feb 27, 2024
5 checks passed
@clementval clementval deleted the cuda_kernel branch February 27, 2024 19:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants