Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions flang/include/flang/Parser/openmp-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
const OpenMPDeclarativeConstruct *GetOmp(const DeclarationConstruct &x);
const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);

const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);

const OmpObjectList *GetOmpObjectList(const OmpClause &clause);

template <typename T>
Expand Down
7 changes: 7 additions & 0 deletions flang/include/flang/Semantics/openmp-directive-sets.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,17 @@ static const OmpDirectiveSet loopConstructSet{
Directive::OMPD_teams_distribute_parallel_do_simd,
Directive::OMPD_teams_distribute_simd,
Directive::OMPD_teams_loop,
Directive::OMPD_fuse,
Directive::OMPD_tile,
Directive::OMPD_unroll,
};

static const OmpDirectiveSet loopTransformationSet{
Directive::OMPD_tile,
Directive::OMPD_unroll,
Directive::OMPD_fuse,
};

static const OmpDirectiveSet nonPartialVarSet{
Directive::OMPD_allocate,
Directive::OMPD_allocators,
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ bool ClauseProcessor::processCollapse(
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {

int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
eval.getFirstNestedEvaluation(),
clauses, loopResult, iv);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
collapseResult.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
Expand Down
5 changes: 4 additions & 1 deletion flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,10 @@ Link make(const parser::OmpClause::Link &inp,

LoopRange make(const parser::OmpClause::Looprange &inp,
semantics::SemanticsContext &semaCtx) {
llvm_unreachable("Unimplemented: looprange");
auto &t0 = std::get<0>(inp.v.t);
auto &t1 = std::get<1>(inp.v.t);
return LoopRange{{/*First*/ makeExpr(t0, semaCtx),
/*Count*/ makeExpr(t1, semaCtx)}};
}

Map make(const parser::OmpClause::Map &inp,
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
mlir::omp::LoopRelatedClauseOps result;
llvm::SmallVector<const semantics::Symbol *> iv;
collectLoopRelatedInfo(converter, converter.getCurrentLocation(), eval,
clauses, result, iv);
eval.getFirstNestedEvaluation(), clauses, result,
iv);

// Update the original variable just before exiting the worksharing
// loop. Conversion as follows:
Expand Down
106 changes: 83 additions & 23 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1982,17 +1982,18 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
static void genCanonicalLoopNest(
lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
mlir::Location loc, const ConstructQueue &queue,
ConstructQueue::const_iterator item, size_t numLoops,
llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
lower::pft::Evaluation &nestedEval, mlir::Location loc,
const ConstructQueue &queue, ConstructQueue::const_iterator item,
size_t numLoops, llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
assert(loops.empty() && "Expecting empty list to fill");
assert(numLoops >= 1 && "Expecting at least one loop");

fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

mlir::omp::LoopRelatedClauseOps loopInfo;
llvm::SmallVector<const semantics::Symbol *, 3> ivs;
collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs);
collectLoopRelatedInfo(converter, loc, eval, nestedEval, numLoops, loopInfo,
ivs);
assert(ivs.size() == numLoops &&
"Expected to parse as many loop variables as there are loops");

Expand All @@ -2014,7 +2015,7 @@ static void genCanonicalLoopNest(

// Step 1: Loop prologues
// Computing the trip count must happen before entering the outermost loop
lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation();
lower::pft::Evaluation *innermostEval = &nestedEval;
for ([[maybe_unused]] auto iv : ivs) {
if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
// OpenMP specifies DO CONCURRENT only with the `!omp loop` construct.
Expand Down Expand Up @@ -2186,7 +2187,8 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops;
canonLoops.reserve(numLoops);

genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item,
genCanonicalLoopNest(converter, symTable, semaCtx, eval,
eval.getFirstNestedEvaluation(), loc, queue, item,
numLoops, canonLoops);
assert((canonLoops.size() == numLoops) &&
"Expecting the predetermined number of loops");
Expand Down Expand Up @@ -2217,6 +2219,58 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
sizesClause.sizes);
}

static void genFuseOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
lower::StatementContext &stmtCtx,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

int32_t first = 0;
int32_t count = 0;
auto iter = llvm::find_if(item->clauses, [](const Clause &clause) {
return clause.id == llvm::omp::Clause::OMPC_looprange;
});
if (iter != item->clauses.end()) {
const auto &looprange = std::get<clause::LoopRange>(iter->u);
first = evaluate::ToInt64(std::get<0>(looprange.t)).value();
count = evaluate::ToInt64(std::get<1>(looprange.t)).value();
}

llvm::SmallVector<mlir::Value> applyees;
for (auto &child : eval.getNestedEvaluations()) {
// Skip OmpEndLoopDirective
if (&child == &eval.getLastNestedEvaluation())
break;

// Emit the associated loop
llvm::SmallVector<mlir::omp::CanonicalLoopOp> canonLoops;
genCanonicalLoopNest(converter, symTable, semaCtx, eval, child, loc, queue,
item, 1, canonLoops);

auto cli = llvm::getSingleElement(canonLoops).getCli();
applyees.push_back(cli);
}
// One generated loop + one for each loop not inside the specified looprange
// if present
llvm::SmallVector<mlir::Value> generatees;
int64_t numGeneratees = count == 0 ? 1 : applyees.size() - count + 1;
for (int i = 0; i < numGeneratees; i++) {
auto fusedCLI = mlir::omp::NewCliOp::create(firOpBuilder, loc);
generatees.push_back(fusedCLI);
}
auto op = mlir::omp::FuseOp::create(firOpBuilder, loc, generatees, applyees);

if (count != 0) {
mlir::IntegerAttr firstAttr = firOpBuilder.getI32IntegerAttr(first);
mlir::IntegerAttr countAttr = firOpBuilder.getI32IntegerAttr(count);
op->setAttr("first", firstAttr);
op->setAttr("count", countAttr);
}
}

static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
lower::StatementContext &stmtCtx,
Expand All @@ -2233,7 +2287,8 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter,

// Emit the associated loop
llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops;
genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1,
genCanonicalLoopNest(converter, symTable, semaCtx, eval,
eval.getFirstNestedEvaluation(), loc, queue, item, 1,
canonLoops);

llvm::SmallVector<mlir::Value, 1> applyees;
Expand Down Expand Up @@ -3507,6 +3562,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_tile:
genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_fuse:
genFuseOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_unroll:
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
Expand Down Expand Up @@ -3962,22 +4020,24 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,

mlir::Location currentLocation = converter.genLocation(beginSpec.source);

if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
loopConstruct.GetNestedConstruct()) {
llvm::omp::Directive nestedDirective =
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
case llvm::omp::Directive::OMPD_tile:
// Skip OMPD_tile since the tile sizes will be retrieved when
// generating the omp.loop_nest op.
break;
default: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
TODO(currentLocation,
"Applying a loop-associated on the loop generated by the " +
llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
" construct");
}
for (auto &construct : std::get<parser::Block>(loopConstruct.t)) {
if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
parser::omp::GetOmpLoop(construct)) {
llvm::omp::Directive nestedDirective =
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
case llvm::omp::Directive::OMPD_tile:
// Skip OMPD_tile since the tile sizes will be retrieved when
// generating the omp.loop_nest op.
break;
default: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
TODO(currentLocation,
"Applying a loop-associated on the loop generated by the " +
llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
" construct");
}
}
}
}

Expand Down
28 changes: 17 additions & 11 deletions flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -812,13 +812,14 @@ void collectTileSizesFromOpenMPConstruct(

int64_t collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
int64_t numCollapse = 1;

// Collect the loops to collapse.
lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
lower::pft::Evaluation *doConstructEval = &nestedEval;
if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
}
Expand All @@ -830,32 +831,37 @@ int64_t collectLoopRelatedInfo(
numCollapse = collapseValue;
}

collectLoopRelatedInfo(converter, currentLocation, eval, numCollapse, result,
iv);
collectLoopRelatedInfo(converter, currentLocation, eval, nestedEval,
numCollapse, result, iv);
return numCollapse;
}

void collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
lower::pft::Evaluation &eval, int64_t numCollapse,
mlir::omp::LoopRelatedClauseOps &result,
lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
int64_t numCollapse, mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {

fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

// Collect the loops to collapse.
lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
lower::pft::Evaluation *doConstructEval = &nestedEval;
if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
}

// Collect sizes from tile directive if present.
std::int64_t sizesLengthValue = 0l;
if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
processTileSizesFromOpenMPConstruct(
ompCons, [&](const parser::OmpClause::Sizes *tclause) {
sizesLengthValue = tclause->v.size();
});
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
const parser::OmpDirectiveSpecification &beginSpec{ompLoop->BeginDir()};
if (beginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
processTileSizesFromOpenMPConstruct(
ompCons, [&](const parser::OmpClause::Sizes *tclause) {
sizesLengthValue = tclause->v.size();
});
}
}
}

std::int64_t collapseValue = std::max(numCollapse, sizesLengthValue);
Expand Down
6 changes: 4 additions & 2 deletions flang/lib/Lower/OpenMP/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,15 @@ void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,

int64_t collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);

void collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
lower::pft::Evaluation &eval, std::int64_t collapseValue,
lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
std::int64_t collapseValue,
// const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Parser/openmp-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2260,6 +2260,7 @@ static constexpr DirectiveSet GetLoopDirectives() {
unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
unsigned(Directive::OMPD_teams_distribute_simd),
unsigned(Directive::OMPD_teams_loop),
unsigned(Directive::OMPD_fuse),
unsigned(Directive::OMPD_tile),
unsigned(Directive::OMPD_unroll),
};
Expand Down
17 changes: 17 additions & 0 deletions flang/lib/Parser/openmp-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x) {
return nullptr;
}

const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x) {
if (auto *construct{GetOmp(x)}) {
if (auto *omp{std::get_if<OpenMPLoopConstruct>(&construct->u)}) {
return omp;
}
}
return nullptr;
}
const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x) {
if (auto *y{std::get_if<ExecutableConstruct>(&x.u)}) {
if (auto *z{std::get_if<common::Indirection<DoConstruct>>(&y->u)}) {
return &z->value();
}
}
return nullptr;
}

const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
// Clauses with OmpObjectList as its data member
using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
Expand Down
Loading