-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[Flang] Add standalone tile support #160298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/meinersbur/mlir_tile
Are you sure you want to change the base?
[Flang] Add standalone tile support #160298
Conversation
@llvm/pr-subscribers-flang-semantics @llvm/pr-subscribers-flang-fir-hlfir Author: Michael Kruse (Meinersbur) ChangesAdd support for the standalone OpenMP tile construct: !$omp tile sizes(...)
DO i = 1, 100
... This is complementary to #143715 which added support for the tile construct as part of another loop-associated construct such as worksharing-loop, distribute, etc. PR Stack:
Patch is 50.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160298.diff 25 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a96884f5680ba..55eda7e3404c1 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -431,6 +431,19 @@ bool ClauseProcessor::processNumTasks(
return false;
}
+bool ClauseProcessor::processSizes(StatementContext &stmtCtx,
+ mlir::omp::SizesClauseOps &result) const {
+ if (auto *clause = findUniqueClause<omp::clause::Sizes>()) {
+ result.sizes.reserve(clause->v.size());
+ for (const ExprTy &vv : clause->v)
+ result.sizes.push_back(fir::getBase(converter.genExprValue(vv, stmtCtx)));
+
+ return true;
+ }
+
+ return false;
+}
+
bool ClauseProcessor::processNumTeams(
lower::StatementContext &stmtCtx,
mlir::omp::NumTeamsClauseOps &result) const {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 324ea3c1047a5..9e352fa574a97 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -66,6 +66,8 @@ class ClauseProcessor {
mlir::omp::LoopRelatedClauseOps &loopResult,
mlir::omp::CollapseClauseOps &collapseResult,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
+ bool processSizes(StatementContext &stmtCtx,
+ mlir::omp::SizesClauseOps &result) const;
bool processDevice(lower::StatementContext &stmtCtx,
mlir::omp::DeviceClauseOps &result) const;
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 5681be664d450..7812d9fe00be2 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1984,125 +1984,241 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return loopOp;
}
-static mlir::omp::CanonicalLoopOp
-genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx,
- lower::pft::Evaluation &eval, mlir::Location loc,
- const ConstructQueue &queue,
- ConstructQueue::const_iterator item,
- llvm::ArrayRef<const semantics::Symbol *> ivs,
- llvm::omp::Directive directive) {
+static void genCanonicalLoopNest(
+ lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+ mlir::Location loc, const ConstructQueue &queue,
+ ConstructQueue::const_iterator item, size_t numLoops,
+ llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
+ assert(loops.empty() && "Expecting empty list to fill");
+ assert(numLoops >= 1 && "Expecting at least one loop");
+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- assert(ivs.size() == 1 && "Nested loops not yet implemented");
- const semantics::Symbol *iv = ivs[0];
+ mlir::omp::LoopRelatedClauseOps loopInfo;
+ llvm::SmallVector<const semantics::Symbol *, 3> ivs;
+ collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs);
+ assert(ivs.size() == numLoops &&
+ "Expected to parse as many loop variables as there are loops");
+
+ // Steps that follow:
+ // 1. Emit all of the loop's prologues (compute the tripcount)
+ // 2. Emit omp.canonical_loop nested inside each other (iteratively)
+ // 2.1. In the innermost omp.canonical_loop, emit the loop body prologue (in
+ // the body callback)
+ //
+ // Since emitting prologues and body code is split, remember prologue values
+ // for use when emitting the same loop's epilogues.
+ llvm::SmallVector<mlir::Value> tripcounts;
+ llvm::SmallVector<mlir::Value> clis;
+ llvm::SmallVector<lower::pft::Evaluation *> evals;
+ llvm::SmallVector<mlir::Type> loopVarTypes;
+ llvm::SmallVector<mlir::Value> loopStepVars;
+ llvm::SmallVector<mlir::Value> loopLBVars;
+ llvm::SmallVector<mlir::Value> blockArgs;
+
+ // Step 1: Loop prologues
+ // Computing the trip count must happen before entering the outermost loop
+ lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation();
+ for ([[maybe_unused]] auto iv : ivs) {
+ if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
+ // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct.
+ // Will need to add special cases for this combination.
+ TODO(loc, "DO CONCURRENT as canonical loop not supported");
+ }
+
+ auto &doLoopEval = innermostEval->getFirstNestedEvaluation();
+ evals.push_back(innermostEval);
+
+ // Get the loop bounds (and increment)
+ // auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
+ auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
+ assert(doStmt && "Expected do loop to be in the nested evaluation");
+ auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
+ assert(loopControl.has_value());
+ auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+ assert(bounds && "Expected bounds for canonical loop");
+ lower::StatementContext stmtCtx;
+ mlir::Value loopLBVar = fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
+ mlir::Value loopUBVar = fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
+ mlir::Value loopStepVar = [&]() {
+ if (bounds->step) {
+ return fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
+ }
- auto &nestedEval = eval.getFirstNestedEvaluation();
- if (nestedEval.getIf<parser::DoConstruct>()->IsDoConcurrent()) {
- // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. Will
- // need to add special cases for this combination.
- TODO(loc, "DO CONCURRENT as canonical loop not supported");
+ // If `step` is not present, assume it is `1`.
+ auto intTy = firOpBuilder.getI32Type();
+ return firOpBuilder.createIntegerConstant(loc, intTy, 1);
+ }();
+
+ // Get the integer kind for the loop variable and cast the loop bounds
+ size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
+ mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+ loopVarTypes.push_back(loopVarType);
+ loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
+ loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
+ loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);
+ loopLBVars.push_back(loopLBVar);
+ loopStepVars.push_back(loopStepVar);
+
+ // Start lowering
+ mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
+ mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
+ mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
+
+ // Ensure we are counting upwards. If not, negate step and swap lb and ub.
+ mlir::Value negStep =
+ firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar);
+ mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, negStep, loopStepVar);
+ mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, loopUBVar, loopLBVar);
+ mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, loopLBVar, loopUBVar);
+
+ // Compute the trip count assuming lb <= ub. This guarantees that the result
+ // is non-negative and we can use unsigned arithmetic.
+ mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>(
+ loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
+ mlir::Value tcMinusOne =
+ firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr);
+ mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>(
+ loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);
+
+ // Fall back to 0 if lb > ub
+ mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::slt, ub, lb);
+ mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isZeroTC, zero, tcIfLooping);
+ tripcounts.push_back(tripcount);
+
+ // Create the CLI handle.
+ auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ mlir::Value cli = newcli.getResult();
+ clis.push_back(cli);
+
+ innermostEval = &*std::next(innermostEval->getNestedEvaluations().begin());
}
- // Get the loop bounds (and increment)
- auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
- auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
- assert(doStmt && "Expected do loop to be in the nested evaluation");
- auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
- assert(loopControl.has_value());
- auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
- assert(bounds && "Expected bounds for canonical loop");
- lower::StatementContext stmtCtx;
- mlir::Value loopLBVar = fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
- mlir::Value loopUBVar = fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
- mlir::Value loopStepVar = [&]() {
- if (bounds->step) {
- return fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
- }
+ // Step 2: Create nested canoncial loops
+ for (auto i : llvm::seq<size_t>(numLoops)) {
+ bool isInnermost = (i == numLoops - 1);
+ mlir::Type loopVarType = loopVarTypes[i];
+ mlir::Value tripcount = tripcounts[i];
+ mlir::Value cli = clis[i];
+ auto &&eval = evals[i];
+
+ auto ivCallback = [&, i, isInnermost](mlir::Operation *op)
+ -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
+ mlir::Region ®ion = op->getRegion(0);
+
+ // Create the op's region skeleton (BB taking the iv as argument)
+ firOpBuilder.createBlock(®ion, {}, {loopVarType}, {loc});
+ blockArgs.push_back(region.front().getArgument(0));
+
+ // Step 2.1: Emit body prologue code
+ // Compute the translation from logical iteration number to the value of
+ // the loop's iteration variable only in the innermost body. Currently,
+ // loop transformations do not allow any instruction between loops, but
+ // this will change with
+ if (isInnermost) {
+ assert(blockArgs.size() == numLoops &&
+ "Expecting all block args to have been collected by now");
+ for (auto j : llvm::seq<size_t>(numLoops)) {
+ mlir::Value natIterNum = fir::getBase(blockArgs[j]);
+ mlir::Value scaled = firOpBuilder.create<mlir::arith::MulIOp>(
+ loc, natIterNum, loopStepVars[j]);
+ mlir::Value userVal = firOpBuilder.create<mlir::arith::AddIOp>(
+ loc, loopLBVars[j], scaled);
+
+ mlir::OpBuilder::InsertPoint insPt =
+ firOpBuilder.saveInsertionPoint();
+ firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
+ mlir::Type tempTy = converter.genType(*ivs[j]);
+ firOpBuilder.restoreInsertionPoint(insPt);
+
+ // Write the loop value into loop variable
+ mlir::Value cvtVal = firOpBuilder.createConvert(loc, tempTy, userVal);
+ hlfir::Entity lhs{converter.getSymbolAddress(*ivs[j])};
+ lhs = hlfir::derefPointersAndAllocatables(loc, firOpBuilder, lhs);
+ mlir::Operation *storeOp =
+ hlfir::AssignOp::create(firOpBuilder, loc, cvtVal, lhs);
+ firOpBuilder.setInsertionPointAfter(storeOp);
+ }
+ }
- // If `step` is not present, assume it is `1`.
- return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(),
- 1);
- }();
+ return {ivs[i]};
+ };
- // Get the integer kind for the loop variable and cast the loop bounds
- size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
- mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
- loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
- loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);
-
- // Start lowering
- mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
- mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
- mlir::Value isDownwards = mlir::arith::CmpIOp::create(
- firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
-
- // Ensure we are counting upwards. If not, negate step and swap lb and ub.
- mlir::Value negStep =
- mlir::arith::SubIOp::create(firOpBuilder, loc, zero, loopStepVar);
- mlir::Value incr = mlir::arith::SelectOp::create(
- firOpBuilder, loc, isDownwards, negStep, loopStepVar);
- mlir::Value lb = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards,
- loopUBVar, loopLBVar);
- mlir::Value ub = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards,
- loopLBVar, loopUBVar);
-
- // Compute the trip count assuming lb <= ub. This guarantees that the result
- // is non-negative and we can use unsigned arithmetic.
- mlir::Value span = mlir::arith::SubIOp::create(
- firOpBuilder, loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
- mlir::Value tcMinusOne =
- mlir::arith::DivUIOp::create(firOpBuilder, loc, span, incr);
- mlir::Value tcIfLooping =
- mlir::arith::AddIOp::create(firOpBuilder, loc, tcMinusOne, one,
- ::mlir::arith::IntegerOverflowFlags::nuw);
-
- // Fall back to 0 if lb > ub
- mlir::Value isZeroTC = mlir::arith::CmpIOp::create(
- firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, ub, lb);
- mlir::Value tripcount = mlir::arith::SelectOp::create(
- firOpBuilder, loc, isZeroTC, zero, tcIfLooping);
-
- // Create the CLI handle.
- auto newcli = mlir::omp::NewCliOp::create(firOpBuilder, loc);
- mlir::Value cli = newcli.getResult();
-
- auto ivCallback = [&](mlir::Operation *op)
- -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
- mlir::Region ®ion = op->getRegion(0);
-
- // Create the op's region skeleton (BB taking the iv as argument)
- firOpBuilder.createBlock(®ion, {}, {loopVarType}, {loc});
-
- // Compute the value of the loop variable from the logical iteration number.
- mlir::Value natIterNum = fir::getBase(region.front().getArgument(0));
- mlir::Value scaled =
- mlir::arith::MulIOp::create(firOpBuilder, loc, natIterNum, loopStepVar);
- mlir::Value userVal =
- mlir::arith::AddIOp::create(firOpBuilder, loc, loopLBVar, scaled);
-
- // Write loop value to loop variable
- mlir::Operation *storeOp = setLoopVar(converter, loc, userVal, iv);
-
- firOpBuilder.setInsertionPointAfter(storeOp);
- return {iv};
- };
+ // Create the omp.canonical_loop operation
+ auto opGenInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *eval,
+ llvm::omp::Directive::OMPD_unknown)
+ .setGenSkeletonOnly(!isInnermost)
+ .setClauses(&item->clauses)
+ .setPrivatize(false)
+ .setGenRegionEntryCb(ivCallback);
+ auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
+ std::move(opGenInfo), queue, item, tripcount, cli);
+ loops.push_back(canonLoop);
+
+ // Insert next loop nested inside last loop
+ firOpBuilder.setInsertionPoint(
+ canonLoop.getRegion().back().getTerminator());
+ }
- // Create the omp.canonical_loop operation
- auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
- OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval,
- directive)
- .setClauses(&item->clauses)
- .setPrivatize(false)
- .setGenRegionEntryCb(ivCallback),
- queue, item, tripcount, cli);
+ firOpBuilder.setInsertionPointAfter(loops.front());
+}
+
+static void genTileOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ lower::StatementContext &stmtCtx,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const ConstructQueue &queue,
+ ConstructQueue::const_iterator item) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- firOpBuilder.setInsertionPointAfter(canonLoop);
- return canonLoop;
+ mlir::omp::SizesClauseOps sizesClause;
+ ClauseProcessor cp(converter, semaCtx, item->clauses);
+ cp.processSizes(stmtCtx, sizesClause);
+
+ size_t numLoops = sizesClause.sizes.size();
+ llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops;
+ canonLoops.reserve(numLoops);
+
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item,
+ numLoops, canonLoops);
+ assert((canonLoops.size() == numLoops) &&
+ "Expecting the predetermined number of loops");
+
+ llvm::SmallVector<mlir::Value, 3> applyees;
+ applyees.reserve(numLoops);
+ for (mlir::omp::CanonicalLoopOp l : canonLoops)
+ applyees.push_back(l.getCli());
+
+ // Emit the associated loops and create a CLI for each affected loop
+ llvm::SmallVector<mlir::Value, 3> gridGeneratees;
+ llvm::SmallVector<mlir::Value, 3> intratileGeneratees;
+ gridGeneratees.reserve(numLoops);
+ intratileGeneratees.reserve(numLoops);
+ for ([[maybe_unused]] auto i : llvm::seq<int>(0, sizesClause.sizes.size())) {
+ auto gridCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ gridGeneratees.push_back(gridCLI.getResult());
+ auto intratileCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ intratileGeneratees.push_back(intratileCLI.getResult());
+ }
+
+ llvm::SmallVector<mlir::Value, 6> generatees;
+ generatees.reserve(2 * numLoops);
+ generatees.append(gridGeneratees);
+ generatees.append(intratileGeneratees);
+
+ firOpBuilder.create<mlir::omp::TileOp>(loc, generatees, applyees,
+ sizesClause.sizes);
}
static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
@@ -2114,22 +2230,22 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
ConstructQueue::const_iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::omp::LoopRelatedClauseOps loopInfo;
- llvm::SmallVector<const semantics::Symbol *> iv;
- collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopInfo, iv);
-
// Clauses for unrolling not yet implemnted
ClauseProcessor cp(converter, semaCtx, item->clauses);
cp.processTODO<clause::Partial, clause::Full>(
loc, llvm::omp::Directive::OMPD_unroll);
// Emit the associated loop
- auto canonLoop =
- genCanonicalLoopOp(converter, symTable, semaCtx, eval, loc, queue, item,
- iv, llvm::omp::Directive::OMPD_unroll);
+ llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops;
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1,
+ canonLoops);
+
+ llvm::SmallVector<mlir::Value, 1> applyees;
+ for (auto &&canonLoop : canonLoops)
+ applyees.push_back(canonLoop.getCli());
// Apply unrolling to it
- auto cli = canonLoop.getCli();
+ auto cli = llvm::getSingleElement(canonLoops).getCli();
mlir::omp::UnrollHeuristicOp::create(firOpBuilder, loc, cli);
}
@@ -3362,13 +3478,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
item);
break;
- case llvm::omp::Directive::OMPD_tile: {
- unsigned version = semaCtx.langOptions().OpenMPVersion;
- if (!semaCtx.langOptions().OpenMPSimd)
- T...
[truncated]
|
@llvm/pr-subscribers-flang-parser Author: Michael Kruse (Meinersbur) ChangesAdd support for the standalone OpenMP tile construct: !$omp tile sizes(...)
DO i = 1, 100
... This is complementary to #143715 which added support for the tile construct as part of another loop-associated construct such as worksharing-loop, distribute, etc. PR Stack:
Patch is 50.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160298.diff 25 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a96884f5680ba..55eda7e3404c1 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -431,6 +431,19 @@ bool ClauseProcessor::processNumTasks(
return false;
}
+bool ClauseProcessor::processSizes(StatementContext &stmtCtx,
+ mlir::omp::SizesClauseOps &result) const {
+ if (auto *clause = findUniqueClause<omp::clause::Sizes>()) {
+ result.sizes.reserve(clause->v.size());
+ for (const ExprTy &vv : clause->v)
+ result.sizes.push_back(fir::getBase(converter.genExprValue(vv, stmtCtx)));
+
+ return true;
+ }
+
+ return false;
+}
+
bool ClauseProcessor::processNumTeams(
lower::StatementContext &stmtCtx,
mlir::omp::NumTeamsClauseOps &result) const {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 324ea3c1047a5..9e352fa574a97 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -66,6 +66,8 @@ class ClauseProcessor {
mlir::omp::LoopRelatedClauseOps &loopResult,
mlir::omp::CollapseClauseOps &collapseResult,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
+ bool processSizes(StatementContext &stmtCtx,
+ mlir::omp::SizesClauseOps &result) const;
bool processDevice(lower::StatementContext &stmtCtx,
mlir::omp::DeviceClauseOps &result) const;
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 5681be664d450..7812d9fe00be2 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1984,125 +1984,241 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return loopOp;
}
-static mlir::omp::CanonicalLoopOp
-genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx,
- lower::pft::Evaluation &eval, mlir::Location loc,
- const ConstructQueue &queue,
- ConstructQueue::const_iterator item,
- llvm::ArrayRef<const semantics::Symbol *> ivs,
- llvm::omp::Directive directive) {
+static void genCanonicalLoopNest(
+ lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+ mlir::Location loc, const ConstructQueue &queue,
+ ConstructQueue::const_iterator item, size_t numLoops,
+ llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
+ assert(loops.empty() && "Expecting empty list to fill");
+ assert(numLoops >= 1 && "Expecting at least one loop");
+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- assert(ivs.size() == 1 && "Nested loops not yet implemented");
- const semantics::Symbol *iv = ivs[0];
+ mlir::omp::LoopRelatedClauseOps loopInfo;
+ llvm::SmallVector<const semantics::Symbol *, 3> ivs;
+ collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs);
+ assert(ivs.size() == numLoops &&
+ "Expected to parse as many loop variables as there are loops");
+
+ // Steps that follow:
+ // 1. Emit all of the loop's prologues (compute the tripcount)
+ // 2. Emit omp.canonical_loop nested inside each other (iteratively)
+ // 2.1. In the innermost omp.canonical_loop, emit the loop body prologue (in
+ // the body callback)
+ //
+ // Since emitting prologues and body code is split, remember prologue values
+ // for use when emitting the same loop's epilogues.
+ llvm::SmallVector<mlir::Value> tripcounts;
+ llvm::SmallVector<mlir::Value> clis;
+ llvm::SmallVector<lower::pft::Evaluation *> evals;
+ llvm::SmallVector<mlir::Type> loopVarTypes;
+ llvm::SmallVector<mlir::Value> loopStepVars;
+ llvm::SmallVector<mlir::Value> loopLBVars;
+ llvm::SmallVector<mlir::Value> blockArgs;
+
+ // Step 1: Loop prologues
+ // Computing the trip count must happen before entering the outermost loop
+ lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation();
+ for ([[maybe_unused]] auto iv : ivs) {
+ if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
+ // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct.
+ // Will need to add special cases for this combination.
+ TODO(loc, "DO CONCURRENT as canonical loop not supported");
+ }
+
+ auto &doLoopEval = innermostEval->getFirstNestedEvaluation();
+ evals.push_back(innermostEval);
+
+ // Get the loop bounds (and increment)
+ // auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
+ auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
+ assert(doStmt && "Expected do loop to be in the nested evaluation");
+ auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
+ assert(loopControl.has_value());
+ auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+ assert(bounds && "Expected bounds for canonical loop");
+ lower::StatementContext stmtCtx;
+ mlir::Value loopLBVar = fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
+ mlir::Value loopUBVar = fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
+ mlir::Value loopStepVar = [&]() {
+ if (bounds->step) {
+ return fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
+ }
- auto &nestedEval = eval.getFirstNestedEvaluation();
- if (nestedEval.getIf<parser::DoConstruct>()->IsDoConcurrent()) {
- // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. Will
- // need to add special cases for this combination.
- TODO(loc, "DO CONCURRENT as canonical loop not supported");
+ // If `step` is not present, assume it is `1`.
+ auto intTy = firOpBuilder.getI32Type();
+ return firOpBuilder.createIntegerConstant(loc, intTy, 1);
+ }();
+
+ // Get the integer kind for the loop variable and cast the loop bounds
+ size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
+ mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+ loopVarTypes.push_back(loopVarType);
+ loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
+ loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
+ loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);
+ loopLBVars.push_back(loopLBVar);
+ loopStepVars.push_back(loopStepVar);
+
+ // Start lowering
+ mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
+ mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
+ mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
+
+ // Ensure we are counting upwards. If not, negate step and swap lb and ub.
+ mlir::Value negStep =
+ firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar);
+ mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, negStep, loopStepVar);
+ mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, loopUBVar, loopLBVar);
+ mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, loopLBVar, loopUBVar);
+
+ // Compute the trip count assuming lb <= ub. This guarantees that the result
+ // is non-negative and we can use unsigned arithmetic.
+ mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>(
+ loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
+ mlir::Value tcMinusOne =
+ firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr);
+ mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>(
+ loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);
+
+ // Fall back to 0 if lb > ub
+ mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::slt, ub, lb);
+ mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isZeroTC, zero, tcIfLooping);
+ tripcounts.push_back(tripcount);
+
+ // Create the CLI handle.
+ auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ mlir::Value cli = newcli.getResult();
+ clis.push_back(cli);
+
+ innermostEval = &*std::next(innermostEval->getNestedEvaluations().begin());
}
- // Get the loop bounds (and increment)
- auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
- auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
- assert(doStmt && "Expected do loop to be in the nested evaluation");
- auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
- assert(loopControl.has_value());
- auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
- assert(bounds && "Expected bounds for canonical loop");
- lower::StatementContext stmtCtx;
- mlir::Value loopLBVar = fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
- mlir::Value loopUBVar = fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
- mlir::Value loopStepVar = [&]() {
- if (bounds->step) {
- return fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
- }
+ // Step 2: Create nested canoncial loops
+ for (auto i : llvm::seq<size_t>(numLoops)) {
+ bool isInnermost = (i == numLoops - 1);
+ mlir::Type loopVarType = loopVarTypes[i];
+ mlir::Value tripcount = tripcounts[i];
+ mlir::Value cli = clis[i];
+ auto &&eval = evals[i];
+
+ auto ivCallback = [&, i, isInnermost](mlir::Operation *op)
+ -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
+ mlir::Region ®ion = op->getRegion(0);
+
+ // Create the op's region skeleton (BB taking the iv as argument)
+ firOpBuilder.createBlock(®ion, {}, {loopVarType}, {loc});
+ blockArgs.push_back(region.front().getArgument(0));
+
+ // Step 2.1: Emit body prologue code
+ // Compute the translation from logical iteration number to the value of
+ // the loop's iteration variable only in the innermost body. Currently,
+ // loop transformations do not allow any instruction between loops, but
+ // this will change with
+ if (isInnermost) {
+ assert(blockArgs.size() == numLoops &&
+ "Expecting all block args to have been collected by now");
+ for (auto j : llvm::seq<size_t>(numLoops)) {
+ mlir::Value natIterNum = fir::getBase(blockArgs[j]);
+ mlir::Value scaled = firOpBuilder.create<mlir::arith::MulIOp>(
+ loc, natIterNum, loopStepVars[j]);
+ mlir::Value userVal = firOpBuilder.create<mlir::arith::AddIOp>(
+ loc, loopLBVars[j], scaled);
+
+ mlir::OpBuilder::InsertPoint insPt =
+ firOpBuilder.saveInsertionPoint();
+ firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
+ mlir::Type tempTy = converter.genType(*ivs[j]);
+ firOpBuilder.restoreInsertionPoint(insPt);
+
+ // Write the loop value into loop variable
+ mlir::Value cvtVal = firOpBuilder.createConvert(loc, tempTy, userVal);
+ hlfir::Entity lhs{converter.getSymbolAddress(*ivs[j])};
+ lhs = hlfir::derefPointersAndAllocatables(loc, firOpBuilder, lhs);
+ mlir::Operation *storeOp =
+ hlfir::AssignOp::create(firOpBuilder, loc, cvtVal, lhs);
+ firOpBuilder.setInsertionPointAfter(storeOp);
+ }
+ }
- // If `step` is not present, assume it is `1`.
- return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(),
- 1);
- }();
+ return {ivs[i]};
+ };
- // Get the integer kind for the loop variable and cast the loop bounds
- size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
- mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
- loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
- loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);
-
- // Start lowering
- mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
- mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
- mlir::Value isDownwards = mlir::arith::CmpIOp::create(
- firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
-
- // Ensure we are counting upwards. If not, negate step and swap lb and ub.
- mlir::Value negStep =
- mlir::arith::SubIOp::create(firOpBuilder, loc, zero, loopStepVar);
- mlir::Value incr = mlir::arith::SelectOp::create(
- firOpBuilder, loc, isDownwards, negStep, loopStepVar);
- mlir::Value lb = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards,
- loopUBVar, loopLBVar);
- mlir::Value ub = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards,
- loopLBVar, loopUBVar);
-
- // Compute the trip count assuming lb <= ub. This guarantees that the result
- // is non-negative and we can use unsigned arithmetic.
- mlir::Value span = mlir::arith::SubIOp::create(
- firOpBuilder, loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
- mlir::Value tcMinusOne =
- mlir::arith::DivUIOp::create(firOpBuilder, loc, span, incr);
- mlir::Value tcIfLooping =
- mlir::arith::AddIOp::create(firOpBuilder, loc, tcMinusOne, one,
- ::mlir::arith::IntegerOverflowFlags::nuw);
-
- // Fall back to 0 if lb > ub
- mlir::Value isZeroTC = mlir::arith::CmpIOp::create(
- firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, ub, lb);
- mlir::Value tripcount = mlir::arith::SelectOp::create(
- firOpBuilder, loc, isZeroTC, zero, tcIfLooping);
-
- // Create the CLI handle.
- auto newcli = mlir::omp::NewCliOp::create(firOpBuilder, loc);
- mlir::Value cli = newcli.getResult();
-
- auto ivCallback = [&](mlir::Operation *op)
- -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
- mlir::Region ®ion = op->getRegion(0);
-
- // Create the op's region skeleton (BB taking the iv as argument)
- firOpBuilder.createBlock(®ion, {}, {loopVarType}, {loc});
-
- // Compute the value of the loop variable from the logical iteration number.
- mlir::Value natIterNum = fir::getBase(region.front().getArgument(0));
- mlir::Value scaled =
- mlir::arith::MulIOp::create(firOpBuilder, loc, natIterNum, loopStepVar);
- mlir::Value userVal =
- mlir::arith::AddIOp::create(firOpBuilder, loc, loopLBVar, scaled);
-
- // Write loop value to loop variable
- mlir::Operation *storeOp = setLoopVar(converter, loc, userVal, iv);
-
- firOpBuilder.setInsertionPointAfter(storeOp);
- return {iv};
- };
+ // Create the omp.canonical_loop operation
+ auto opGenInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *eval,
+ llvm::omp::Directive::OMPD_unknown)
+ .setGenSkeletonOnly(!isInnermost)
+ .setClauses(&item->clauses)
+ .setPrivatize(false)
+ .setGenRegionEntryCb(ivCallback);
+ auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
+ std::move(opGenInfo), queue, item, tripcount, cli);
+ loops.push_back(canonLoop);
+
+ // Insert next loop nested inside last loop
+ firOpBuilder.setInsertionPoint(
+ canonLoop.getRegion().back().getTerminator());
+ }
- // Create the omp.canonical_loop operation
- auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
- OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval,
- directive)
- .setClauses(&item->clauses)
- .setPrivatize(false)
- .setGenRegionEntryCb(ivCallback),
- queue, item, tripcount, cli);
+ firOpBuilder.setInsertionPointAfter(loops.front());
+}
+
+static void genTileOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ lower::StatementContext &stmtCtx,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const ConstructQueue &queue,
+ ConstructQueue::const_iterator item) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- firOpBuilder.setInsertionPointAfter(canonLoop);
- return canonLoop;
+ mlir::omp::SizesClauseOps sizesClause;
+ ClauseProcessor cp(converter, semaCtx, item->clauses);
+ cp.processSizes(stmtCtx, sizesClause);
+
+ size_t numLoops = sizesClause.sizes.size();
+ llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops;
+ canonLoops.reserve(numLoops);
+
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item,
+ numLoops, canonLoops);
+ assert((canonLoops.size() == numLoops) &&
+ "Expecting the predetermined number of loops");
+
+ llvm::SmallVector<mlir::Value, 3> applyees;
+ applyees.reserve(numLoops);
+ for (mlir::omp::CanonicalLoopOp l : canonLoops)
+ applyees.push_back(l.getCli());
+
+ // Emit the associated loops and create a CLI for each affected loop
+ llvm::SmallVector<mlir::Value, 3> gridGeneratees;
+ llvm::SmallVector<mlir::Value, 3> intratileGeneratees;
+ gridGeneratees.reserve(numLoops);
+ intratileGeneratees.reserve(numLoops);
+ for ([[maybe_unused]] auto i : llvm::seq<int>(0, sizesClause.sizes.size())) {
+ auto gridCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ gridGeneratees.push_back(gridCLI.getResult());
+ auto intratileCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ intratileGeneratees.push_back(intratileCLI.getResult());
+ }
+
+ llvm::SmallVector<mlir::Value, 6> generatees;
+ generatees.reserve(2 * numLoops);
+ generatees.append(gridGeneratees);
+ generatees.append(intratileGeneratees);
+
+ firOpBuilder.create<mlir::omp::TileOp>(loc, generatees, applyees,
+ sizesClause.sizes);
}
static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
@@ -2114,22 +2230,22 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
ConstructQueue::const_iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::omp::LoopRelatedClauseOps loopInfo;
- llvm::SmallVector<const semantics::Symbol *> iv;
- collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopInfo, iv);
-
// Clauses for unrolling not yet implemnted
ClauseProcessor cp(converter, semaCtx, item->clauses);
cp.processTODO<clause::Partial, clause::Full>(
loc, llvm::omp::Directive::OMPD_unroll);
// Emit the associated loop
- auto canonLoop =
- genCanonicalLoopOp(converter, symTable, semaCtx, eval, loc, queue, item,
- iv, llvm::omp::Directive::OMPD_unroll);
+ llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops;
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1,
+ canonLoops);
+
+ llvm::SmallVector<mlir::Value, 1> applyees;
+ for (auto &&canonLoop : canonLoops)
+ applyees.push_back(canonLoop.getCli());
// Apply unrolling to it
- auto cli = canonLoop.getCli();
+ auto cli = llvm::getSingleElement(canonLoops).getCli();
mlir::omp::UnrollHeuristicOp::create(firOpBuilder, loc, cli);
}
@@ -3362,13 +3478,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
item);
break;
- case llvm::omp::Directive::OMPD_tile: {
- unsigned version = semaCtx.langOptions().OpenMPVersion;
- if (!semaCtx.langOptions().OpenMPSimd)
- T...
[truncated]
|
let allowedOnceClauses = [ | ||
VersionedClause<OMPC_Sizes, 51>, | ||
]; | ||
let requiredClauses = [ | ||
VersionedClause<OMPC_Sizes, 51>, | ||
]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sizes
is required exactly once. Listing it in "allowedOnceClauses" and "requiredClauses" at the same time does work. The alternative would to add a new "requiredOnce" field just for sizes
, would be considerably more work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
Add support for the standalone OpenMP tile construct:
This is complementary to #143715 which added support for the tile construct as part of another loop-associated construct such as worksharing-loop, distribute, etc.
PR Stack: