-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][cuda] CUF kernel loop directive #82836
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesThis patch introduces a new operation to represent the CUDA Fortran kernel loop directive. This operation is modeled as a LoopLikeOp operation in a similar way to acc.loop. The CUFKernelDoConstruct parse tree node is also placed correctly in the PFTBuilder to be available in PFT evaluations. Lowering from the flang parse-tree to MLIR is also done. Full diff: https://github.com/llvm/llvm-project/pull/82836.diff 5 Files Affected:
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index c2b0fdbf357cde..9913f584133faa 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -138,7 +138,8 @@ using Directives =
std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
parser::OpenACCRoutineConstruct,
parser::OpenACCDeclarativeConstruct, parser::OpenMPConstruct,
- parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective>;
+ parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective,
+ parser::CUFKernelDoConstruct>;
using DeclConstructs = std::tuple<parser::OpenMPDeclarativeConstruct,
parser::OpenACCDeclarativeConstruct>;
@@ -178,7 +179,7 @@ static constexpr bool isNopConstructStmt{common::HasMember<
template <typename A>
static constexpr bool isExecutableDirective{common::HasMember<
A, std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
- parser::OpenMPConstruct>>};
+ parser::OpenMPConstruct, parser::CUFKernelDoConstruct>>};
template <typename A>
static constexpr bool isFunctionLike{common::HasMember<
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 08239230f793f1..db5e5f4bc682e6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3127,4 +3127,31 @@ def fir_BoxOffsetOp : fir_Op<"box_offset", [NoMemoryEffect]> {
];
}
+def fir_CUDAKernelOp : fir_Op<"cuda_kernel", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
+
+ let arguments = (ins
+ Variadic<I32>:$grid, // empty means `*`
+ Variadic<I32>:$block, // empty means `*`
+ Optional<I32>:$stream,
+ Variadic<Index>:$lowerbound,
+ Variadic<Index>:$upperbound,
+ Variadic<Index>:$step,
+ OptionalAttr<I64Attr>:$n
+ );
+
+ let regions = (region AnyRegion:$region);
+
+ let assemblyFormat = [{
+ `<` `<` `<` custom<CUFKernelValues>($grid, type($grid)) `,`
+ custom<CUFKernelValues>($block, type($block))
+ ( `,` `stream` `=` $stream^ )? `>` `>` `>`
+ custom<CUFKernelLoopControl>($region, $lowerbound, type($lowerbound),
+ $upperbound, type($upperbound), $step, type($step))
+ attr-dict
+ }];
+
+ let hasVerifier = 1;
+}
+
#endif
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 2d7f748cefa2d8..2c4825fafdbee4 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2453,6 +2453,127 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &)
}
+ void genFIR(const Fortran::parser::CUFKernelDoConstruct &kernel) {
+ localSymbols.pushScope();
+ const Fortran::parser::CUFKernelDoConstruct::Directive &dir =
+ std::get<Fortran::parser::CUFKernelDoConstruct::Directive>(kernel.t);
+
+ mlir::Location loc = genLocation(dir.source);
+
+ Fortran::lower::StatementContext stmtCtx;
+
+ unsigned nestedLoops = 1;
+
+ const auto &nLoops =
+ std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(dir.t);
+ if (nLoops)
+ nestedLoops = *Fortran::semantics::GetIntValue(*nLoops);
+
+ mlir::IntegerAttr n;
+ if (nestedLoops > 1)
+ n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);
+
+ const std::list<Fortran::parser::ScalarIntExpr> &grid = std::get<1>(dir.t);
+ const std::list<Fortran::parser::ScalarIntExpr> &block = std::get<2>(dir.t);
+ const std::optional<Fortran::parser::ScalarIntExpr> &stream = std::get<3>(dir.t);
+
+ llvm::SmallVector<mlir::Value> gridValues;
+ for (const Fortran::parser::ScalarIntExpr &expr : grid)
+ gridValues.push_back(fir::getBase(genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+ llvm::SmallVector<mlir::Value> blockValues;
+ for (const Fortran::parser::ScalarIntExpr &expr : block)
+ blockValues.push_back(fir::getBase(genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+ mlir::Value streamValue;
+ if (stream)
+ streamValue = fir::getBase(genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx));
+
+ const auto &outerDoConstruct =
+ std::get<std::optional<Fortran::parser::DoConstruct>>(kernel.t);
+
+ llvm::SmallVector<mlir::Location> locs;
+ locs.push_back(loc);
+ llvm::SmallVector<mlir::Value> lbs, ubs, steps;
+
+ mlir::Type idxTy = builder->getIndexType();
+
+ llvm::SmallVector<mlir::Type> ivTypes;
+ llvm::SmallVector<mlir::Location> ivLocs;
+ llvm::SmallVector<mlir::Value> ivValues;
+ for (unsigned i = 0; i < nestedLoops; ++i) {
+ const Fortran::parser::LoopControl *loopControl;
+ Fortran::lower::pft::Evaluation *loopEval = &getEval().getFirstNestedEvaluation();
+
+ mlir::Location crtLoc = loc;
+ if (i == 0) {
+ loopControl = &*outerDoConstruct->GetLoopControl();
+ crtLoc = genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
+ } else {
+ auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
+ assert(doCons && "expect do construct");
+ loopControl = &*doCons->GetLoopControl();
+ crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
+ }
+
+ locs.push_back(crtLoc);
+
+ const Fortran::parser::LoopControl::Bounds *bounds =
+ std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+ assert(bounds && "Expected bounds on the loop construct");
+
+ Fortran::semantics::Symbol &ivSym =
+ bounds->name.thing.symbol->GetUltimate();
+ ivValues.push_back(getSymbolAddress(ivSym));
+
+ lbs.push_back(builder->createConvert(crtLoc, idxTy,
+ fir::getBase(genExprValue(
+ *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))));
+ ubs.push_back(builder->createConvert(crtLoc, idxTy,
+ fir::getBase(genExprValue(
+ *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))));
+ if (bounds->step)
+ steps.push_back(fir::getBase(genExprValue(
+ *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
+ else // If `step` is not present, assume it is `1`.
+ steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
+
+ ivTypes.push_back(idxTy);
+ ivLocs.push_back(crtLoc);
+ if (i < nestedLoops - 1)
+ loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
+ }
+
+ auto op = builder->create<fir::CUDAKernelOp>(
+ loc, gridValues, blockValues, streamValue, lbs, ubs, steps, n);
+ builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes, ivLocs);
+ mlir::Block &b = op.getRegion().back();
+ builder->setInsertionPointToStart(&b);
+
+ for (auto [arg, value] : llvm::zip(
+ op.getLoopRegions().front()->front().getArguments(), ivValues)) {
+ mlir::Value convArg = builder->createConvert(loc, fir::unwrapRefType(value.getType()), arg);
+ builder->create<fir::StoreOp>(loc, convArg, value);
+ }
+
+ builder->create<fir::FirEndOp>(loc);
+ builder->setInsertionPointToStart(&b);
+
+ Fortran::lower::pft::Evaluation *crtEval = &getEval();
+ if (crtEval->lowerAsStructured()) {
+ crtEval = &crtEval->getFirstNestedEvaluation();
+ for (int64_t i = 1; i < nestedLoops; i++)
+ crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+ }
+
+
+
+ // Generate loop body
+ for (Fortran::lower::pft::Evaluation &e : crtEval->getNestedEvaluations())
+ genFIR(e);
+
+ builder->setInsertionPointAfter(op);
+ localSymbols.popScope();
+ }
+
void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
genOpenMPConstruct(*this, localSymbols, bridge.getSemanticsContext(),
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 0a534cdb3c4871..c2facb5a004a8f 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3866,6 +3866,104 @@ mlir::LogicalResult fir::DeclareOp::verify() {
return fortranVar.verifyDeclareLikeOpImpl(getMemref());
}
+llvm::SmallVector<mlir::Region *> fir::CUDAKernelOp::getLoopRegions() {
+ return {&getRegion()};
+}
+
+mlir::ParseResult
+parseCUFKernelValues(mlir::OpAsmParser &parser,
+ llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
+ llvm::SmallVectorImpl<mlir::Type> &types) {
+ if (mlir::succeeded(parser.parseOptionalStar()))
+ return mlir::success();
+
+ if (parser.parseOptionalLParen()) {
+ if (mlir::failed(parser.parseCommaSeparatedList(
+ mlir::AsmParser::Delimiter::None, [&]() {
+ if (parser.parseOperand(values.emplace_back()))
+ return mlir::failure();
+ return mlir::success();
+ })))
+ return mlir::failure();
+ if (parser.parseRParen())
+ return mlir::failure();
+ } else {
+ if (parser.parseOperand(values.emplace_back()))
+ return mlir::failure();
+ return mlir::success();
+ }
+ return mlir::success();
+}
+
+void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::ValueRange values, mlir::TypeRange types) {
+ if (values.empty())
+ p << "*";
+
+ if (values.size() > 1)
+ p << "(";
+ llvm::interleaveComma(values, p,
+ [&p](mlir::Value v) { p << v; });
+ if (values.size() > 1)
+ p << ")";
+}
+
+mlir::ParseResult
+parseCUFKernelLoopControl(mlir::OpAsmParser &parser, mlir::Region ®ion,
+ llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound,
+ llvm::SmallVectorImpl<mlir::Type> &lowerboundType,
+ llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound,
+ llvm::SmallVectorImpl<mlir::Type> &upperboundType,
+ llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &step,
+ llvm::SmallVectorImpl<mlir::Type> &stepType) {
+
+ llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars;
+ if (parser.parseLParen() ||
+ parser.parseArgumentList(inductionVars, mlir::OpAsmParser::Delimiter::None,
+ /*allowType=*/true) ||
+ parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
+ parser.parseOperandList(lowerbound, inductionVars.size(),
+ mlir::OpAsmParser::Delimiter::None) ||
+ parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
+ parser.parseKeyword("to") || parser.parseLParen() ||
+ parser.parseOperandList(upperbound, inductionVars.size(),
+ mlir::OpAsmParser::Delimiter::None) ||
+ parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
+ parser.parseKeyword("step") || parser.parseLParen() ||
+ parser.parseOperandList(step, inductionVars.size(),
+ mlir::OpAsmParser::Delimiter::None) ||
+ parser.parseColonTypeList(stepType) || parser.parseRParen())
+ return mlir::failure();
+ return parser.parseRegion(region, inductionVars);
+}
+
+void printCUFKernelLoopControl(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::Region ®ion, mlir::ValueRange lowerbound,
+ mlir::TypeRange lowerboundType,
+ mlir::ValueRange upperbound,
+ mlir::TypeRange upperboundType, mlir::ValueRange steps,
+ mlir::TypeRange stepType) {
+ mlir::ValueRange regionArgs = region.front().getArguments();
+ if (!regionArgs.empty()) {
+ p << "(";
+ llvm::interleaveComma(regionArgs, p,
+ [&p](mlir::Value v) { p << v << " : " << v.getType(); });
+ p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
+ << upperbound << " : " << upperboundType << ") "
+ << " step (" << steps << " : " << stepType << ") ";
+ }
+ p.printRegion(region, /*printEntryBlockArgs=*/false);
+}
+
+mlir::LogicalResult fir::CUDAKernelOp::verify() {
+ if (getLowerbound().size() != getUpperbound().size() ||
+ getLowerbound().size() != getStep().size())
+ return emitOpError(
+ "expect same number of values in lowerbound, upperbound and step");
+
+ return mlir::success();
+}
+
//===----------------------------------------------------------------------===//
// FIROpsDialect
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
new file mode 100644
index 00000000000000..db628fe756b952
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
@@ -0,0 +1,51 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Test lowering of CUDA kernel loop directive.
+
+subroutine sub1()
+ integer :: i, j
+ integer, parameter :: n = 100
+ real :: a(n), b(n)
+ real :: c(n,n), d(n,n)
+
+! CHECK-LABEL: func.func @_QPsub1()
+! CHECK: %[[IV:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+ !$cuf kernel do <<< 1, 2 >>>
+ do i = 1, n
+ a(i) = a(i) * b(i)
+ end do
+
+! CHECK: %[[LB:.*]] = fir.convert %c1{{.*}} : (i32) -> index
+! CHECK: %[[UB:.*]] = fir.convert %c100{{.*}} : (i32) -> index
+! CHECK: %[[STEP:.*]] = arith.constant 1 : index
+! CHECK: fir.cuda_kernel<<<%c1_i32, %c2_i32>>> (%[[ARG0:.*]] : index) = (%[[LB]] : index) to (%[[UB]] : index) step (%[[STEP]] : index)
+! CHECK-NOT: fir.do_loop
+! CHECK: %[[ARG0_I32:.*]] = fir.convert %[[ARG0]] : (index) -> i32
+! CHECK: fir.store %[[ARG0_I32]] to %[[IV]]#1 : !fir.ref<i32>
+
+
+ !$cuf kernel do <<< *, * >>>
+ do i = 1, n
+ a(i) = a(i) * b(i)
+ end do
+
+! CHECK: fir.cuda_kernel<<<*, *>>> (%{{.*}} : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index)
+
+ !$cuf kernel do(2) <<< 1, (256,1) >>>
+ do i = 1, n
+ do j = 1, n
+ c(i,j) = c(i,j) * d(i,j)
+ end do
+ end do
+
+! CHECK: fir.cuda_kernel<<<%c1{{.*}}, (%c256{{.*}}, %c1{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index)
+! CHECK: {n = 2 : i64}
+
+! TODO: currently these trigger error in the parser
+! !$cuf kernel do(2) <<< (1,*), (256,1) >>>
+! !$cuf kernel do(2) <<< (*,*), (32,4) >>>
+end
+
+
+
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
This patch introduces a new operation to represent the CUDA Fortran kernel loop directive. This operation is modelled as a LoopLikeOp operation in a similar way to acc.loop. Lowering from the flang parse-tree to MLIR is also done.
e1169e6
to
46b322d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This patch introduces a new operation to represent the CUDA Fortran kernel loop directive. This operation is modeled as a LoopLikeOp operation in a similar way to acc.loop.
The CUFKernelDoConstruct parse tree node is also placed correctly in the PFTBuilder to be available in PFT evaluations.
Lowering from the flang parse-tree to MLIR is also done.