From e5dca46bb3de197b3341e0c364a76287bcc6019c Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 16 Jan 2024 16:40:47 -0600 Subject: [PATCH 1/4] [Flang][OpenMP] TableGen support for getting leaf constructs Implement getLeafConstructs(D), which for a composite directive D will return the list of the constituent leaf directives. --- .../llvm/Frontend/Directive/DirectiveBase.td | 3 + llvm/include/llvm/Frontend/OpenMP/OMP.td | 48 +++++++++++ llvm/include/llvm/TableGen/DirectiveEmitter.h | 4 + llvm/utils/TableGen/DirectiveEmitter.cpp | 81 +++++++++++++++++++ 4 files changed, 136 insertions(+) diff --git a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td index 31578710365b21..139c794cd49856 100644 --- a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td +++ b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td @@ -152,6 +152,9 @@ class Directive { // List of clauses that are required. list requiredClauses = []; + // List of names of leaf constituent directives. + list leafs = []; + // Set directive used by default when unknown. bit isDefault = false; } diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index 2388abac81ceb4..d51f471466669f 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -771,6 +771,7 @@ def OMP_TargetParallel : Directive<"target parallel"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "parallel"]; } def OMP_TargetParallelFor : Directive<"target parallel for"> { let allowedClauses = [ @@ -803,6 +804,7 @@ def OMP_TargetParallelFor : Directive<"target parallel for"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "parallel", "for"]; } def OMP_TargetParallelDo : Directive<"target parallel do"> { let allowedClauses = [ @@ -833,6 +835,7 @@ def OMP_TargetParallelDo : Directive<"target parallel do"> { VersionedClause, VersionedClause ]; + let leafs = ["target", "parallel", "do"]; } def OMP_TargetUpdate : Directive<"target update"> { let allowedClauses = [ @@ -866,6 +869,7 @@ def OMP_ParallelFor : Directive<"parallel for"> { VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "for"]; } def OMP_ParallelDo : Directive<"parallel do"> { let allowedClauses = [ @@ -887,6 +891,7 @@ def OMP_ParallelDo : Directive<"parallel do"> { VersionedClause, VersionedClause ]; + let leafs = ["parallel", "do"]; } def OMP_ParallelForSimd : Directive<"parallel for simd"> { let allowedClauses = [ @@ -912,6 +917,7 @@ def OMP_ParallelForSimd : Directive<"parallel for simd"> { VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "for", "simd"]; } def OMP_ParallelDoSimd : Directive<"parallel do simd"> { let allowedClauses = [ @@ -938,6 +944,7 @@ def OMP_ParallelDoSimd : Directive<"parallel do simd"> { VersionedClause, VersionedClause ]; + let leafs = ["parallel", "do", "simd"]; } def OMP_ParallelMaster : Directive<"parallel master"> { let allowedClauses = [ @@ -953,6 +960,7 @@ def OMP_ParallelMaster : Directive<"parallel master"> { VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "master"]; } def OMP_ParallelMasked : Directive<"parallel masked"> { let allowedClauses = [ @@ -969,6 +977,7 @@ def OMP_ParallelMasked : Directive<"parallel masked"> { VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "masked"]; } def OMP_ParallelSections : Directive<"parallel sections"> { let allowedClauses = [ @@ -987,6 +996,7 @@ def OMP_ParallelSections : Directive<"parallel sections"> { VersionedClause, VersionedClause ]; + let leafs = ["parallel", "sections"]; } def OMP_ForSimd : Directive<"for simd"> { let allowedClauses = [ @@ -1007,6 +1017,7 @@ def OMP_ForSimd : Directive<"for simd"> { VersionedClause, VersionedClause ]; + let leafs = ["for", "simd"]; } def OMP_DoSimd : Directive<"do simd"> { let allowedClauses = [ @@ -1027,6 +1038,7 @@ def OMP_DoSimd : Directive<"do simd"> { VersionedClause, VersionedClause ]; + let leafs = ["do", "simd"]; } def OMP_CancellationPoint : Directive<"cancellation point"> {} def OMP_DeclareReduction : Directive<"declare reduction"> {} @@ -1104,6 +1116,7 @@ def OMP_TaskLoopSimd : Directive<"taskloop simd"> { VersionedClause, VersionedClause ]; + let leafs = ["taskloop", "simd"]; } def OMP_Distribute : Directive<"distribute"> { let allowedClauses = [ @@ -1156,6 +1169,7 @@ def OMP_DistributeParallelFor : Directive<"distribute parallel for"> { VersionedClause, VersionedClause, ]; + let leafs = ["distribute", "parallel", "for"]; } def OMP_DistributeParallelDo : Directive<"distribute parallel do"> { let allowedClauses = [ @@ -1179,6 +1193,7 @@ def OMP_DistributeParallelDo : Directive<"distribute parallel do"> { VersionedClause, VersionedClause ]; + let leafs = ["distribute", "parallel", "do"]; } def OMP_DistributeParallelForSimd : Directive<"distribute parallel for simd"> { let allowedClauses = [ @@ -1204,6 +1219,7 @@ def OMP_DistributeParallelForSimd : Directive<"distribute parallel for simd"> { VersionedClause, VersionedClause, ]; + let leafs = ["distribute", "parallel", "for", "simd"]; } def OMP_DistributeParallelDoSimd : Directive<"distribute parallel do simd"> { let allowedClauses = [ @@ -1228,6 +1244,7 @@ def OMP_DistributeParallelDoSimd : Directive<"distribute parallel do simd"> { VersionedClause, VersionedClause ]; + let leafs = ["distribute", "parallel", "do", "simd"]; } def OMP_DistributeSimd : Directive<"distribute simd"> { let allowedClauses = [ @@ -1254,6 +1271,7 @@ def OMP_DistributeSimd : Directive<"distribute simd"> { VersionedClause, VersionedClause ]; + let leafs = ["distribute", "simd"]; } def OMP_TargetParallelForSimd : Directive<"target parallel for simd"> { @@ -1291,6 +1309,7 @@ def OMP_TargetParallelForSimd : Directive<"target parallel for simd"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "parallel", "for", "simd"]; } def OMP_TargetParallelDoSimd : Directive<"target parallel do simd"> { let allowedClauses = [ @@ -1322,6 +1341,7 @@ def OMP_TargetParallelDoSimd : Directive<"target parallel do simd"> { VersionedClause, VersionedClause ]; + let leafs = ["target", "parallel", "do", "simd"]; } def OMP_TargetSimd : Directive<"target simd"> { let allowedClauses = [ @@ -1356,6 +1376,7 @@ def OMP_TargetSimd : Directive<"target simd"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "simd"]; } def OMP_TeamsDistribute : Directive<"teams distribute"> { let allowedClauses = [ @@ -1375,6 +1396,7 @@ def OMP_TeamsDistribute : Directive<"teams distribute"> { let allowedOnceClauses = [ VersionedClause ]; + let leafs = ["teams", "distribute"]; } def OMP_TeamsDistributeSimd : Directive<"teams distribute simd"> { let allowedClauses = [ @@ -1400,6 +1422,7 @@ def OMP_TeamsDistributeSimd : Directive<"teams distribute simd"> { VersionedClause, VersionedClause ]; + let leafs = ["teams", "distribute", "simd"]; } def OMP_TeamsDistributeParallelForSimd : @@ -1428,6 +1451,7 @@ def OMP_TeamsDistributeParallelForSimd : VersionedClause, VersionedClause, ]; + let leafs = ["teams", "distribute", "parallel", "for", "simd"]; } def OMP_TeamsDistributeParallelDoSimd : Directive<"teams distribute parallel do simd"> { @@ -1456,6 +1480,7 @@ def OMP_TeamsDistributeParallelDoSimd : VersionedClause, VersionedClause ]; + let leafs = ["teams", "distribute", "parallel", "do", "simd"]; } def OMP_TeamsDistributeParallelFor : Directive<"teams distribute parallel for"> { @@ -1479,6 +1504,7 @@ def OMP_TeamsDistributeParallelFor : VersionedClause, VersionedClause, ]; + let leafs = ["teams", "distribute", "parallel", "for"]; } def OMP_TeamsDistributeParallelDo : Directive<"teams distribute parallel do"> { @@ -1505,6 +1531,7 @@ let allowedOnceClauses = [ VersionedClause, VersionedClause ]; + let leafs = ["teams", "distribute", "parallel", "do"]; } def OMP_TargetTeams : Directive<"target teams"> { let allowedClauses = [ @@ -1532,6 +1559,7 @@ def OMP_TargetTeams : Directive<"target teams"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "teams"]; } def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> { let allowedClauses = [ @@ -1560,6 +1588,7 @@ def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "teams", "distribute"]; } def OMP_TargetTeamsDistributeParallelFor : @@ -1594,6 +1623,7 @@ def OMP_TargetTeamsDistributeParallelFor : let allowedOnceClauses = [ VersionedClause, ]; + let leafs = ["target", "teams", "distribute", "parallel", "for"]; } def OMP_TargetTeamsDistributeParallelDo : Directive<"target teams distribute parallel do"> { @@ -1628,6 +1658,7 @@ def OMP_TargetTeamsDistributeParallelDo : VersionedClause, VersionedClause ]; + let leafs = ["target", "teams", "distribute", "parallel", "do"]; } def OMP_TargetTeamsDistributeParallelForSimd : Directive<"target teams distribute parallel for simd"> { @@ -1666,6 +1697,7 @@ def OMP_TargetTeamsDistributeParallelForSimd : let allowedOnceClauses = [ VersionedClause, ]; + let leafs = ["target", "teams", "distribute", "parallel", "for", "simd"]; } def OMP_TargetTeamsDistributeParallelDoSimd : Directive<"target teams distribute parallel do simd"> { @@ -1704,6 +1736,7 @@ def OMP_TargetTeamsDistributeParallelDoSimd : VersionedClause, VersionedClause ]; + let leafs = ["target", "teams", "distribute", "parallel", "do", "simd"]; } def OMP_TargetTeamsDistributeSimd : Directive<"target teams distribute simd"> { @@ -1738,6 +1771,7 @@ def OMP_TargetTeamsDistributeSimd : VersionedClause, VersionedClause ]; + let leafs = ["target", "teams", "distribute", "simd"]; } def OMP_Allocate : Directive<"allocate"> { let allowedOnceClauses = [ @@ -1779,6 +1813,7 @@ def OMP_MasterTaskloop : Directive<"master taskloop"> { VersionedClause, VersionedClause ]; + let leafs = ["master", "taskloop"]; } def OMP_MaskedTaskloop : Directive<"masked taskloop"> { let allowedClauses = [ @@ -1801,6 +1836,7 @@ def OMP_MaskedTaskloop : Directive<"masked taskloop"> { VersionedClause, VersionedClause ]; + let leafs = ["masked", "taskloop"]; } def OMP_ParallelMasterTaskloop : Directive<"parallel master taskloop"> { @@ -1826,6 +1862,7 @@ def OMP_ParallelMasterTaskloop : VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "master", "taskloop"]; } def OMP_ParallelMaskedTaskloop : Directive<"parallel masked taskloop"> { @@ -1852,6 +1889,7 @@ def OMP_ParallelMaskedTaskloop : VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "masked", "taskloop"]; } def OMP_MasterTaskloopSimd : Directive<"master taskloop simd"> { let allowedClauses = [ @@ -1879,6 +1917,7 @@ def OMP_MasterTaskloopSimd : Directive<"master taskloop simd"> { VersionedClause, VersionedClause ]; + let leafs = ["master", "taskloop", "simd"]; } def OMP_MaskedTaskloopSimd : Directive<"masked taskloop simd"> { let allowedClauses = [ @@ -1907,6 +1946,7 @@ def OMP_MaskedTaskloopSimd : Directive<"masked taskloop simd"> { VersionedClause, VersionedClause ]; + let leafs = ["masked", "taskloop", "simd"]; } def OMP_ParallelMasterTaskloopSimd : Directive<"parallel master taskloop simd"> { @@ -1938,6 +1978,7 @@ def OMP_ParallelMasterTaskloopSimd : VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "master", "taskloop", "simd"]; } def OMP_ParallelMaskedTaskloopSimd : Directive<"parallel masked taskloop simd"> { @@ -1970,6 +2011,7 @@ def OMP_ParallelMaskedTaskloopSimd : VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "masked", "taskloop", "simd"]; } def OMP_Depobj : Directive<"depobj"> { let allowedClauses = [ @@ -2016,6 +2058,7 @@ def OMP_ParallelWorkshare : Directive<"parallel workshare"> { VersionedClause, VersionedClause ]; + let leafs = ["parallel", "workshare"]; } def OMP_Workshare : Directive<"workshare"> {} def OMP_EndDo : Directive<"end do"> { @@ -2102,6 +2145,7 @@ def OMP_teams_loop : Directive<"teams loop"> { VersionedClause, VersionedClause, ]; + let leafs = ["teams", "loop"]; } def OMP_target_teams_loop : Directive<"target teams loop"> { let allowedClauses = [ @@ -2131,6 +2175,7 @@ def OMP_target_teams_loop : Directive<"target teams loop"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "teams", "loop"]; } def OMP_parallel_loop : Directive<"parallel loop"> { let allowedClauses = [ @@ -2152,6 +2197,7 @@ def OMP_parallel_loop : Directive<"parallel loop"> { VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "loop"]; } def OMP_target_parallel_loop : Directive<"target parallel loop"> { let allowedClauses = [ @@ -2183,11 +2229,13 @@ def OMP_target_parallel_loop : Directive<"target parallel loop"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "parallel", "loop"]; } def OMP_Metadirective : Directive<"metadirective"> { let allowedClauses = [VersionedClause]; let allowedOnceClauses = [VersionedClause]; } + def OMP_Unknown : Directive<"unknown"> { let isDefault = true; } diff --git a/llvm/include/llvm/TableGen/DirectiveEmitter.h b/llvm/include/llvm/TableGen/DirectiveEmitter.h index c86018715a48a1..88fef74e298bf4 100644 --- a/llvm/include/llvm/TableGen/DirectiveEmitter.h +++ b/llvm/include/llvm/TableGen/DirectiveEmitter.h @@ -121,6 +121,10 @@ class Directive : public BaseRecord { std::vector getRequiredClauses() const { return Def->getValueAsListOfDefs("requiredClauses"); } + + std::vector getLeafConstructNames() const { + return Def->getValueAsListOfStrings("leafs"); + } }; // Wrapper class that contains Clause's information defined in DirectiveBase.td diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp index b6aee665f8ee0b..232b1a4e6f7b5e 100644 --- a/llvm/utils/TableGen/DirectiveEmitter.cpp +++ b/llvm/utils/TableGen/DirectiveEmitter.cpp @@ -186,6 +186,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { if (DirLang.hasEnableBitmaskEnumInNamespace()) OS << "\n#include \"llvm/ADT/BitmaskEnum.h\"\n"; + OS << "#include \"llvm/ADT/SmallVector.h\"\n"; OS << "\n"; OS << "namespace llvm {\n"; @@ -231,6 +232,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { OS << "bool isAllowedClauseForDirective(Directive D, " << "Clause C, unsigned Version);\n"; OS << "\n"; + OS << "const llvm::SmallVector &getLeafConstructs(Directive D);\n"; if (EnumHelperFuncs.length() > 0) { OS << EnumHelperFuncs; OS << "\n"; @@ -435,6 +437,82 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang, OS << "}\n"; // End of function isAllowedClauseForDirective } +// Generate the getLeafConstructs function implementation. +static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang, + raw_ostream &OS) { + llvm::StringMap NameToRec; + for (Record *R : DirLang.getDirectives()) + NameToRec.insert(std::make_pair(BaseRecord(R).getName(), R)); + + auto getQualifiedName = [&](StringRef Formatted) -> std::string { + return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + + "::Directive::" + DirLang.getDirectivePrefix() + Formatted) + .str(); + }; + + // For each list of leafs, generate a static local object, then + // return a reference to that object for a given directive, e.g. + // + // static ListTy leafConstructs_A_B = { A, B }; + // static ListTy leafConstructs_C_D_E = { C, D, E }; + // switch (Dir) { + // case A_B: + // return leafConstructs_A_B; + // case C_D_E: + // return leafConstructs_C_D_E; + + // Map from a record that defines a directive to the name of the + // local object with the list of its leafs. + DenseMap ListNames; + + std::string DirectiveTypeName = + std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive"; + std::string DirectiveListTypeName = + std::string("llvm::SmallVector<") + DirectiveTypeName + ">"; + + // const Container &llvm::::GetLeafConstructs(llvm::::Directive Dir) + OS << "const " << DirectiveListTypeName + << " &llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs(" + << DirectiveTypeName << " Dir) "; + OS << "{\n"; + + // Generate the locals. + for (auto &[_, R] : NameToRec) { + Directive Dir{R}; + + std::vector LeafNames = Dir.getLeafConstructNames(); + if (LeafNames.empty()) + continue; + + std::string ListName = "leafConstructs_" + Dir.getFormattedName(); + OS << " static " << DirectiveListTypeName << ' ' << ListName << " {\n"; + for (StringRef L : LeafNames) { + Directive LeafDir{NameToRec.at(L)}; + OS << " " << getQualifiedName(LeafDir.getFormattedName()) << ",\n"; + } + OS << " };\n"; + ListNames.insert(std::make_pair(R, std::move(ListName))); + } + + OS << " static " << DirectiveListTypeName << " nothing {};\n"; + + OS << '\n'; + OS << " switch (Dir) {\n"; + for (auto &[_, R] : NameToRec) { + auto F = ListNames.find(R); + if (F == ListNames.end()) + continue; + + Directive Dir{R}; + OS << " case " << getQualifiedName(Dir.getFormattedName()) << ":\n"; + OS << " return " << F->second << ";\n"; + } + OS << " default:\n"; + OS << " return nothing;\n"; + OS << " } // switch (Dir)\n"; + OS << "}\n"; +} + // Generate a simple enum set with the give clauses. static void GenerateClauseSet(const std::vector &Clauses, raw_ostream &OS, StringRef ClauseSetPrefix, @@ -876,6 +954,9 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang, // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version) GenerateIsAllowedClause(DirLang, OS); + + // getLeafConstructs(Directive D) + GenerateGetLeafConstructs(DirLang, OS); } // Generate the implemenation section for the enumeration in the directive From c06c3b5cf2641ef43bcfae3368086ad5b43850b1 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Mon, 29 Jan 2024 09:45:25 -0600 Subject: [PATCH 2/4] [Flang][OpenMP] Set OpenMP attributes in MLIR module in bbc before lowering Right now attributes like OpenMP version or target attributes for offload are set after lowering in bbc. The flang frontend sets them before lowering, making them available in the lowering process. This change sets them before lowering in bbc as well. --- flang/tools/bbc/bbc.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp index 98d9258e023e55..b3a8c9fb80f9cf 100644 --- a/flang/tools/bbc/bbc.cpp +++ b/flang/tools/bbc/bbc.cpp @@ -342,7 +342,6 @@ static mlir::LogicalResult convertFortranSourceToMLIR( semanticsContext.targetCharacteristics(), parsing.allCooked(), targetTriple, kindMap, loweringOptions, {}, semanticsContext.languageFeatures(), targetMachine); - burnside.lower(parseTree, semanticsContext); mlir::ModuleOp mlirModule = burnside.getModule(); if (enableOpenMP) { if (enableOpenMPGPU && !enableOpenMPDevice) { @@ -358,6 +357,7 @@ static mlir::LogicalResult convertFortranSourceToMLIR( setOffloadModuleInterfaceAttributes(mlirModule, offloadModuleOpts); setOpenMPVersionAttribute(mlirModule, setOpenMPVersion); } + burnside.lower(parseTree, semanticsContext); std::error_code ec; std::string outputName = outputFilename; if (!outputName.size()) From 900187f8e7d4778d42830961e52b7e977e3ab38f Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 19 Jan 2024 16:07:12 -0600 Subject: [PATCH 3/4] getOpenMPVersion --- flang/lib/Lower/OpenMP.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index be2117efbabc0a..419eb80ab4631d 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -48,6 +48,11 @@ using DeclareTargetCapturePair = // Common helper functions //===----------------------------------------------------------------------===// +static uint32_t getOpenMPVersion(mlir::ModuleOp mod) { + mlir::Attribute verAttr = mod->getAttr("omp.version"); + return llvm::cast(verAttr).getVersion(); +} + static Fortran::semantics::Symbol * getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { Fortran::semantics::Symbol *sym = nullptr; From 5ae6a2021aa8d449f752b72010c26c0aa5e7cd69 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Thu, 18 Jan 2024 07:48:25 -0600 Subject: [PATCH 4/4] Split function complete Some TODOs still remain. --- flang/lib/Lower/OpenMP.cpp | 1125 +++++++++++++++++++++++++++++++++++- 1 file changed, 1101 insertions(+), 24 deletions(-) diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index 419eb80ab4631d..be9531b574c605 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -31,6 +31,7 @@ #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/CommandLine.h" @@ -48,32 +49,76 @@ using DeclareTargetCapturePair = // Common helper functions //===----------------------------------------------------------------------===// -static uint32_t getOpenMPVersion(mlir::ModuleOp mod) { - mlir::Attribute verAttr = mod->getAttr("omp.version"); - return llvm::cast(verAttr).getVersion(); +static llvm::ArrayRef getWorksharing() { + static llvm::omp::Directive worksharing[] = { + llvm::omp::Directive::OMPD_do, llvm::omp::Directive::OMPD_for, + llvm::omp::Directive::OMPD_scope, llvm::omp::Directive::OMPD_sections, + llvm::omp::Directive::OMPD_single, llvm::omp::Directive::OMPD_workshare, + }; + return worksharing; } -static Fortran::semantics::Symbol * -getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { - Fortran::semantics::Symbol *sym = nullptr; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (auto *arrayEle = - Fortran::parser::Unwrap( - designator)) { - sym = GetFirstName(arrayEle->base).symbol; - } else if (auto *structComp = Fortran::parser::Unwrap< - Fortran::parser::StructureComponent>(designator)) { - sym = structComp->component.symbol; - } else if (const Fortran::parser::Name *name = - Fortran::semantics::getDesignatorNameIfDataRef( - designator)) { - sym = name->symbol; - } - }, - [&](const Fortran::parser::Name &name) { sym = name.symbol; }}, - ompObject.u); +static llvm::ArrayRef getWorksharingLoop() { + static llvm::omp::Directive worksharingLoop[] = { + llvm::omp::Directive::OMPD_do, + llvm::omp::Directive::OMPD_for, + }; + return worksharingLoop; +} + +static uint32_t getOpenMPVersion(const mlir::ModuleOp &mod) { + if (mlir::Attribute verAttr = mod->getAttr("omp.version")) + return llvm::cast(verAttr).getVersion(); + llvm_unreachable("Exoecting OpenMP version attribute in module"); +} + +static std::pair +getOmpObjectSymbolAndBase(const Fortran::parser::Name &name) { + return std::make_pair(name.symbol, nullptr); +} + +static std::pair +getOmpObjectSymbolAndBase(const Fortran::parser::Designator &designator) { + if (auto *arrayEle = + Fortran::parser::Unwrap(designator)) { + auto *sym = GetFirstName(arrayEle->base).symbol; + // Array elements don't have their own symbols, instead the base symbol + // is used. + return std::make_pair(sym, sym); + } + if (auto *structComp = + Fortran::parser::Unwrap( + designator)) { + auto *sym = structComp->component.symbol; + auto *base = GetFirstName(structComp->base).symbol; + return std::make_pair(sym, base); + } + if (const Fortran::parser::Name *name = + Fortran::semantics::getDesignatorNameIfDataRef(designator)) { + return getOmpObjectSymbolAndBase(*name); + } + llvm_unreachable("Cannot obtain symbols for designtor"); +} + +static std::pair +getOmpObjectSymbolAndBase(const Fortran::parser::OmpObject &object) { + std::pair syms; + std::visit(Fortran::common::visitors{ + [&](const Fortran::parser::Designator &designator) { + syms = getOmpObjectSymbolAndBase(designator); + }, + [&](const Fortran::parser::Name &name) { + syms = getOmpObjectSymbolAndBase(name); + }}, + object.u); + return syms; +} + +template +static Fortran::semantics::Symbol *getOmpObjectSymbol(Object &&object) { + auto *sym = + std::get<0>(getOmpObjectSymbolAndBase(std::forward(object))); + assert(sym != nullptr); return sym; } @@ -147,6 +192,1034 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter, converter.genEval(e); } +//===----------------------------------------------------------------------===// +// Directive decomposition +//===----------------------------------------------------------------------===// + +namespace { +struct DirectiveInfo { + llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown; + llvm::SmallVector clauses; +}; + +struct CompositeInfo { + CompositeInfo(const mlir::ModuleOp &modOp, + Fortran::lower::pft::Evaluation &ev, + llvm::omp::Directive compDir, + const std::list &clauses); + using ClauseSet = std::set; + + bool split(); + void addClause(const Fortran::parser::OmpClause *clause); + + DirectiveInfo *findDirective(llvm::omp::Directive dirId) { + for (DirectiveInfo &dir : leafs) { + if (dir.id == dirId) + return &dir; + } + return nullptr; + } + ClauseSet *findClauses(const Fortran::parser::OmpObject &object) { + const Fortran::semantics::Symbol *sym = getOmpObjectSymbol(object); + if (auto found = syms.find(sym); found != syms.end()) + return &found->second; + return nullptr; + } + + llvm::SmallVector leafs; // Ordered outer to inner. + llvm::DenseMap syms; + llvm::DenseSet mapBases; + Fortran::lower::pft::Evaluation &eval; + const mlir::ModuleOp &mod; + + // List of clauses applied to the combined/composite directive. + // Processing of the LINEAR clause can result in FIRSTPRIVATE and/or + // LASTPRIVATE added to this list. + llvm::SmallVector clauses; + // Storage for the OmpClause's created during splitting. + llvm::SmallVector> storage; +}; +} // namespace + +static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const DirectiveInfo &dirInfo); +static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const CompositeInfo &compInfo); + +namespace detail { +template +llvm::omp::Clause getClauseIdForClass(C &&) { + using namespace Fortran; + using A = llvm::remove_cvref_t; // A is referenced in OMP.inc +#define GEN_FLANG_CLAUSE_PARSER_KIND_MAP +#include "llvm/Frontend/OpenMP/OMP.inc" +} + +template +typename std::remove_reference_t::iterator +find_unique(Container &&container, Predicate &&pred) { + auto first = std::find_if(container.begin(), container.end(), pred); + if (first == container.end()) + return first; + auto second = std::find_if(std::next(first), container.end(), pred); + if (second == container.end()) + return first; + return container.end(); +} +} // namespace detail + +static llvm::omp::Clause getClauseId(const Fortran::parser::OmpClause &clause) { + return std::visit([](auto &&s) { return detail::getClauseIdForClass(s); }, + clause.u); +} + +namespace detail { +template +auto clauseDispatch( + Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, CompositeInfo &compInfo, + Handler &&handler, + typename llvm::remove_cvref_t::EmptyTrait value = {}) { + return handler(std::forward(clause), clauseId, clauseNode, compInfo); +} + +template +auto clauseDispatch( + Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, CompositeInfo &compInfo, + Handler &&handler, + typename llvm::remove_cvref_t::WrapperTrait value = {}) { + return handler(std::forward(clause), clauseId, clause.v, clauseNode, + compInfo); +} + +template +auto visit_clause(const Fortran::parser::OmpClause &clause, Handler &&handler, + CompositeInfo &compInfo) { + return std::visit( + [&](auto &&actual) { + return detail::clauseDispatch(actual, getClauseId(clause), &clause, + compInfo, handler); + }, + clause.u); +} +} // namespace detail + +static Fortran::semantics::Symbol * +getIterationVariableSymbol(const Fortran::lower::pft::Evaluation &eval) { + return eval.visit(Fortran::common::visitors{ + [&](const Fortran::parser::DoConstruct &doLoop) { + if (const auto &maybeCtrl = doLoop.GetLoopControl()) { + using LoopControl = Fortran::parser::LoopControl; + if (auto *bounds = std::get_if(&maybeCtrl->u)) { + static_assert( + std::is_same_vname), + Fortran::parser::Scalar>); + return bounds->name.thing.symbol; + } + } + return static_cast(nullptr); + }, + [](auto &&) { + return static_cast(nullptr); + }, + }); +} + +static void addSymsToMap(const Fortran::parser::OmpObjectList &objects, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + for (const Fortran::parser::OmpObject &object : objects.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(object); + compInfo.syms[sym].insert(clauseNode); + } +} + +static void addSymToMap(const Fortran::parser::Name &name, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(name); + compInfo.syms[sym].insert(clauseNode); +} + +static void addSymToMap(const Fortran::parser::Designator &designator, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(designator); + compInfo.syms[sym].insert(clauseNode); +} + +template +static void addClauseSymsToMap(Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Do nothing for clauses represented by empty classes. +} + +template +static void addClauseSymsToMap(Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpObjectList &objects, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + addSymsToMap(objects, clauseNode, compInfo); +} + +static void +addClauseSymsToMap(const Fortran::parser::OmpClause::Aligned &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpAlignedClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // contents.t<0> -> OmpObjectList + addSymsToMap(std::get<0>(contents.t), clauseNode, compInfo); +} + +static void +addClauseSymsToMap(const Fortran::parser::OmpClause::Allocate &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpAllocateClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // contents.t<1> -> OmpObjectList + addSymsToMap(std::get<1>(contents.t), clauseNode, compInfo); +} + +static void addClauseSymsToMap(const Fortran::parser::OmpClause::Depend &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpDependClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + std::visit(Fortran::common::visitors{ + [&](const Fortran::parser::OmpDependClause::InOut &inout) { + // inout.t<1> -> std::list + for (const auto &designator : std::get<1>(inout.t)) + addSymToMap(designator, clauseNode, compInfo); + }, + [](auto &&) { + // No objects in the other alternatives. + }, + }, + contents.u); +} + +static void +addClauseSymsToMap(const Fortran::parser::OmpClause::InReduction &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpInReductionClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // contents.t<1> -> OmpObjectList + addSymsToMap(std::get<1>(contents.t), clauseNode, compInfo); +} + +static void addClauseSymsToMap(const Fortran::parser::OmpClause::Linear &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpLinearClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // contents.u is a variant where both alternatives have member `names` + // that is std::list. + std::visit( + [&](auto &&s) { + for (const Fortran::parser::Name &name : s.names) + addSymToMap(name, clauseNode, compInfo); + }, + contents.u); +} + +static void addClauseSymsToMap(const Fortran::parser::OmpClause::Map &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpMapClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // contents.t<1> -> OmpObjectList + const Fortran::parser::OmpObjectList &objects = std::get<1>(contents.t); + + addSymsToMap(objects, clauseNode, compInfo); + + // Additionally, add base symbols to the 'mapBases' set. + for (const Fortran::parser::OmpObject &object : objects.v) { + if (auto *base = std::get<1>(getOmpObjectSymbolAndBase(object))) + compInfo.mapBases.insert(base); + } +} + +template +static void +addClauseSymsToMap(Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpReductionClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Reduction, TaskReduction + // contents.t<1> -> OmpObjectList + addSymsToMap(std::get<1>(contents.t), clauseNode, compInfo); +} + +template +static void addClauseSymsToMap(Clause &&clause, llvm::omp::Clause clauseId, + const std::list &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // NonTemporal, Uniform + for (const Fortran::parser::Name &name : contents) + addSymToMap(name, clauseNode, compInfo); +} + +template +static void addClauseSymsToMap(Clause &&clause, llvm::omp::Clause clauseId, + WrappedType &&wrapped, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Make sure that we are not missing anything: list all the wrapped + // types that do not contain (or reference) any objects. + using namespace Fortran::parser; + static_assert( + llvm::is_one_of< + llvm::remove_cvref_t, OmpAtomicDefaultMemOrderClause, + ConstantExpr, ScalarIntConstantExpr, ScalarIntExpr, ScalarLogicalExpr, + std::optional, std::optional, + std::list, OmpDefaultClause, OmpDefaultmapClause, + OmpDeviceClause, OmpDeviceTypeClause, OmpIfClause, OmpOrderClause, + OmpProcBindClause, OmpScheduleClause>::value); +} + +CompositeInfo::CompositeInfo( + const mlir::ModuleOp &modOp, Fortran::lower::pft::Evaluation &ev, + llvm::omp::Directive compDir, + const std::list &clauses) + : eval(ev), mod(modOp) { + for (llvm::omp::Directive dir : llvm::omp::getLeafConstructs(compDir)) + leafs.push_back(DirectiveInfo{dir}); + + for (const Fortran::parser::OmpClause &clause : clauses) + addClause(&clause); +} + +void CompositeInfo::addClause(const Fortran::parser::OmpClause *clause) { + clauses.push_back(clause); + detail::visit_clause( + *clause, [](auto &&...args) { addClauseSymsToMap(args...); }, *this); +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const DirectiveInfo &dirInfo) { + os << llvm::omp::getOpenMPDirectiveName(dirInfo.id); + for (auto [index, clause] : llvm::enumerate(dirInfo.clauses)) { + os << (index == 0 ? '\t' : ' '); + os << llvm::omp::getOpenMPClauseName(getClauseId(*clause)); + } + return os; +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const CompositeInfo &compInfo) { + for (const auto &[index, dirInfo] : llvm::enumerate(compInfo.leafs)) + os << "leaf[" << index << "]: " << dirInfo << '\n'; + + os << "syms:\n"; + for (const auto &[sym, clauses] : compInfo.syms) { + os << *sym << " -> {"; + for (const auto *clause : clauses) + os << ' ' << llvm::omp::getOpenMPClauseName(getClauseId(*clause)); + os << " }\n"; + } + os << "mapBases: {"; + for (const auto &sym : compInfo.mapBases) + os << ' ' << *sym; + os << " }\n"; + return os; +} + +// Apply a clause to the only directive that allows it. If there are no +// directives that allow it, or if there is more that one, do not apply +// anything and return false, otherwise return true. +static bool applyToUnique(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + uint32_t version = getOpenMPVersion(compInfo.mod); + auto unique = detail::find_unique(compInfo.leafs, [=](const auto &dirInfo) { + return llvm::omp::isAllowedClauseForDirective(dirInfo.id, clauseId, + version); + }); + + if (unique != compInfo.leafs.end()) { + unique->clauses.push_back(clauseNode); + return true; + } + return false; +} + +// Apply a clause to the first directive in given range that allows it. +// If such a directive does not exist, return false, otherwise return true. +template +static bool applyToFirst(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + const mlir::ModuleOp &mod, + llvm::iterator_range range) { + if (range.empty()) + return false; + + uint32_t version = getOpenMPVersion(mod); + for (DirectiveInfo &dir : range) { + if (!llvm::omp::isAllowedClauseForDirective(dir.id, clauseId, version)) + continue; + dir.clauses.push_back(clauseNode); + return true; + } + return false; +} + +// Apply a clause to the innermost directive that allows it. If such a +// directive does not exist, return false, otherwise return true. +static bool applyToInnermost(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + return applyToFirst(clauseId, clauseNode, compInfo.mod, + llvm::reverse(compInfo.leafs)); +} + +// Apply a clause to the outermost directive that allows it. If such a +// directive does not exist, return false, otherwise return true. +static bool applyToOutermost(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + return applyToFirst(clauseId, clauseNode, compInfo.mod, + llvm::iterator_range(compInfo.leafs)); +} + +template +static bool applyIf(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo, + Predicate shouldApply) { + bool applied = false; + uint32_t version = getOpenMPVersion(compInfo.mod); + for (DirectiveInfo &dir : compInfo.leafs) { + if (!llvm::omp::isAllowedClauseForDirective(dir.id, clauseId, version)) + continue; + if (!shouldApply(dir)) + continue; + dir.clauses.push_back(clauseNode); + applied = true; + } + + return applied; +} + +static bool applyToAll(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + return applyIf(clauseId, clauseNode, compInfo, [](auto) { return true; }); +} + +template +static bool applyClause(Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // The default behavior is to find the unique directive to which the + // given clause may be applied. If there are no such directives, or + // if there are multiple ones, flag an error. + // From "OpenMP Application Programming Interface", Version 5.2: + // "Some clauses are permitted only on a single leaf construct of the + // combined or composite construct, in which case the effect is as if + // the clause is applied to that specific construct." (p339, 31-33) + if (applyToUnique(clauseId, clauseNode, compInfo)) + return true; + + llvm::errs() << "handle empty class:" + << llvm::omp::getOpenMPClauseName(clauseId) << '\n'; + return false; +} + +// Clauses that expected to only be applicable to a single leaf construct. +template +static bool applyClause(Clause &&clause, llvm::omp::Clause clauseId, + WrappedType &&, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + if (applyToUnique(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "handle wrapper class for generic type:" + << llvm::omp::getOpenMPClauseName(clauseId) << '\n'; + return false; +} + +// COLLAPSE +template +static bool applyClause(const Fortran::parser::OmpClause::Collapse &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Apply COLLAPSE to the innermost directive. If it's not one that + // allows it flag an error. + if (!compInfo.leafs.empty()) { + DirectiveInfo &last = compInfo.leafs.back(); + uint32_t version = getOpenMPVersion(compInfo.mod); + + if (llvm::omp::isAllowedClauseForDirective(last.id, clauseId, version)) { + last.clauses.push_back(clauseNode); + return true; + } + } + + llvm::errs() << "Cannot apply COLLAPSE\n"; + return false; +} + +// PRIVATE +static bool applyClause(const Fortran::parser::OmpClause::Private &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpObjectList &objects, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + if (applyToInnermost(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply PRIVATE\n"; + return false; +} + +// FIRSTPRIVATE +static bool applyClause(const Fortran::parser::OmpClause::Firstprivate &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpObjectList &objects, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + bool applied = false; + + // S Section 17.2 + // S The effect of the firstprivate clause is as if it is applied to one + // S or more leaf constructs as follows: + + // S - To the distribute construct if it is among the constituent constructs; + // S - To the teams construct if it is among the constituent constructs and + // S the distribute construct is not; + auto hasDistribute = compInfo.findDirective(llvm::omp::OMPD_distribute); + auto hasTeams = compInfo.findDirective(llvm::omp::OMPD_teams); + if (hasDistribute != nullptr) { + hasDistribute->clauses.push_back(clauseNode); + applied = true; + // S If the teams construct is among the constituent constructs and the + // S effect is not as if the firstprivate clause is applied to it by the + // S above rules, then the effect is as if the shared clause with the + // S same list item is applied to the teams construct. + if (hasTeams != nullptr) { + // TODO: Apply SHARED(objects) + } + } else if (hasTeams != nullptr) { + hasTeams->clauses.push_back(clauseNode); + applied = true; + } + + // S - To a worksharing construct that accepts the clause if one is among + // S the constituent constructs; + auto findWorksharing = [&]() { + auto worksharing = getWorksharing(); + for (DirectiveInfo &dir : compInfo.leafs) { + auto found = llvm::find(worksharing, dir.id); + if (found != std::end(worksharing)) + return &dir; + } + return static_cast(nullptr); + }; + + auto hasWorksharing = findWorksharing(); + if (hasWorksharing != nullptr) { + hasWorksharing->clauses.push_back(clauseNode); + applied = true; + } + + // S - To the taskloop construct if it is among the constituent constructs; + auto hasTaskloop = compInfo.findDirective(llvm::omp::OMPD_taskloop); + if (hasTaskloop != nullptr) { + hasTaskloop->clauses.push_back(clauseNode); + applied = true; + } + + // S - To the parallel construct if it is among the constituent constructs + // S and neither a taskloop construct nor a worksharing construct that + // S accepts the clause is among them; + auto hasParallel = compInfo.findDirective(llvm::omp::OMPD_parallel); + if (hasParallel != nullptr) { + if (hasTaskloop == nullptr && hasWorksharing == nullptr) { + hasParallel->clauses.push_back(clauseNode); + applied = true; + } else { + // S If the parallel construct is among the constituent constructs and + // S the effect is not as if the firstprivate clause is applied to it by + // S the above rules, then the effect is as if the shared clause with + // S the same list item is applied to the parallel construct. + // TODO: apply SHARED(objects) to PARALLEL + } + } + + // S - To the target construct if it is among the constituent constructs + // S and the same list item neither appears in a lastprivate clause nor + // S is the base variable or base pointer of a list item that appears in + // S a map clause. + auto objInLastprivate = [&](const Fortran::parser::OmpObject &object) { + if (CompositeInfo::ClauseSet *clauses = compInfo.findClauses(object)) { + for (const Fortran::parser::OmpClause *clause : *clauses) { + if (getClauseId(*clause) == llvm::omp::Clause::OMPC_lastprivate) + return true; + } + } + return false; + }; + + auto hasTarget = compInfo.findDirective(llvm::omp::OMPD_target); + if (hasTarget != nullptr) { + for (const Fortran::parser::OmpObject &object : objects.v) { + if (objInLastprivate(object)) + continue; + if (compInfo.mapBases.contains(getOmpObjectSymbol(object))) + continue; + // TODO: Add FIRSTPRIVATE(object) to clause list + // TODO: may need a new OmpObjectList here + applied = true; + } + } + + return applied; +} + +// LASTPRIVATE +static bool applyClause(const Fortran::parser::OmpClause::Lastprivate &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpObjectList &objects, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // S The effect of the lastprivate clause is as if it is applied to all leaf + // S constructs that permit the clause. + if (!applyToAll(clauseId, clauseNode, compInfo)) { + llvm::errs() << "Cannot apply LASTPRIVATE\n"; + return false; + } + + // S If the parallel construct is among the constituent constructs and the + // S list item is not also specified in the firstprivate clause, then the + // S effect of the lastprivate clause is as if the shared clause with the + // S same list item is applied to the parallel construct. + auto inFirstprivate = [&](const Fortran::parser::OmpObject &object) { + if (auto *clauses = compInfo.findClauses(object)) { + for (const Fortran::parser::OmpClause *clause : *clauses) { + if (getClauseId(*clause) == llvm::omp::Clause::OMPC_firstprivate) + return true; + } + } + return false; + }; + + if (auto hasParallel = compInfo.findDirective(llvm::omp::OMPD_parallel)) { + for (const Fortran::parser::OmpObject &object : objects.v) { + if (!inFirstprivate(object)) { + // TODO Add SHARED(object) to PARALLEL + } + } + } + + // S If the teams construct is among the constituent constructs and the + // S list item is not also specified in the firstprivate clause, then the + // S effect of the lastprivate clause is as if the shared clause with the + // S same list item is applied to the teams construct. + if (auto hasTeams = compInfo.findDirective(llvm::omp::OMPD_teams)) { + for (const Fortran::parser::OmpObject &object : objects.v) { + if (!inFirstprivate(object)) { + // TODO Add SHARED(object) to TEAMS + } + } + } + + // S If the target construct is among the constituent constructs and the + // S list item is not the base variable or base pointer of a list item that + // S appears in a map clause, the effect of the lastprivate clause is as if + // S the same list item appears in a map clause with a map-type of tofrom. + if (auto hasTarget = compInfo.findDirective(llvm::omp::OMPD_target)) { + for (const Fortran::parser::OmpObject &object : objects.v) { + const Fortran::semantics::Symbol *sym = getOmpObjectSymbol(object); + // See if symbol is a base symbol in MAP. + if (!compInfo.mapBases.contains(sym)) { + // TODO Add MAP(tofrom, object) to TARGET. + } + } + } + + return false; +} + +// SHARED +template +static bool applyClause(const Fortran::parser::OmpClause::Shared &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Apply SHARED to the all leafs that allow it. + if (applyToAll(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply SHARED\n"; + return false; +} + +// DEFAULT +template +static bool applyClause(const Fortran::parser::OmpClause::Default &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Apply DEFAULT to the all leafs that allow it. + if (applyToAll(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply DEFAULT\n"; + return false; +} + +// THREAD_LIMIT +template +static bool applyClause(const Fortran::parser::OmpClause::ThreadLimit &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Apply THREAD_LIMIT to the all leafs that allow it. + if (applyToAll(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply THREAD_LIMIT\n"; + return false; +} + +// ORDER +template +static bool applyClause(const Fortran::parser::OmpClause::Order &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Apply ORDER to the all leafs that allow it. + if (applyToAll(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply ORDER\n"; + return false; +} + +// ALLOCATE +template +static bool applyClause(const Fortran::parser::OmpClause::Allocate &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // S The effect of the allocate clause is as if it is applied to all leaf + // S constructs that permit the clause and to which a data-sharing attribute + // S clause that may create a private copy of the same list item is applied. + + // XXX This one may need to be applied at the end, once we know which leaf + // constructs have what data-sharing attributes. Or maybe do all data-sharing + // first, then the rest of the clauses? + + // TODO + llvm::errs() << "Cannot apply ALLOCATE\n"; + return false; +} + +// REDUCTION +static bool applyClause(const Fortran::parser::OmpClause::Reduction &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpReductionClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // S The effect of the reduction clause is as if it is applied to all leaf + // S constructs that permit the clause, except for the following constructs: + // S - The parallel construct, when combined with the sections, worksharing- + // S loop, loop, or taskloop construct; and + // S - The teams construct, when combined with the loop construct. + bool applyToParallel = true, applyToTeams = true; + + auto hasParallel = + compInfo.findDirective(llvm::omp::Directive::OMPD_parallel); + if (hasParallel) { + auto exclusions = llvm::concat( + getWorksharingLoop(), llvm::ArrayRef{ + llvm::omp::Directive::OMPD_loop, + llvm::omp::Directive::OMPD_sections, + llvm::omp::Directive::OMPD_taskloop, + }); + auto present = [&](llvm::omp::Directive id) { + return compInfo.findDirective(id) != nullptr; + }; + + if (llvm::any_of(exclusions, present)) + applyToParallel = false; + } + + auto hasTeams = compInfo.findDirective(llvm::omp::Directive::OMPD_teams); + if (hasTeams) { + // The only exclusion is OMPD_loop. + if (compInfo.findDirective(llvm::omp::Directive::OMPD_loop)) + applyToTeams = false; + } + + // S For the parallel and teams constructs above, the effect of the + // S reduction clause instead is as if each list item or, for any list + // S item that is an array item, its corresponding base array or base + // S pointer appears in a shared clause for the construct. + for (auto dir : {hasParallel, hasTeams}) { + if (dir == nullptr) + continue; + // TODO apply SHARED(objects) to *dir. + } + + // TODO: Apply the following. + // S If the task reduction-modifier is specified, the effect is as if + // S it only modifies the behavior of the reduction clause on the innermost + // S leaf construct that accepts the modifier (see Section 5.5.8). If the + // S inscan reduction-modifier is specified, the effect is as if it modifies + // S the behavior of the reduction clause on all constructs of the combined + // S construct to which the clause is applied and that accept the modifier. + + bool applied = + applyIf(clauseId, clauseNode, compInfo, [&](DirectiveInfo &dir) { + if (!applyToParallel && &dir == hasParallel) + return false; + if (!applyToTeams && &dir == hasTeams) + return false; + return true; + }); + + // TODO: Apply the following. + // S If a list item in a reduction clause on a combined target construct + // S does not have the same base variable or base pointer as a list item + // S in a map clause on the construct, then the effect is as if the list + // S item in the reduction clause appears as a list item in a map clause + // S with a map-type of tofrom. + + return applied; +} + +// IF +static bool applyClause(const Fortran::parser::OmpClause::If &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpIfClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + using namespace Fortran::parser; + auto &modifier = + std::get>(contents.t); + + if (modifier) { + llvm::omp::Directive dirId = llvm::omp::Directive::OMPD_unknown; + + switch (*modifier) { + case OmpIfClause::DirectiveNameModifier::Parallel: + dirId = llvm::omp::Directive::OMPD_parallel; + break; + case OmpIfClause::DirectiveNameModifier::Simd: + dirId = llvm::omp::Directive::OMPD_simd; + break; + case OmpIfClause::DirectiveNameModifier::Target: + dirId = llvm::omp::Directive::OMPD_target; + break; + case OmpIfClause::DirectiveNameModifier::Task: + dirId = llvm::omp::Directive::OMPD_task; + break; + case OmpIfClause::DirectiveNameModifier::Taskloop: + dirId = llvm::omp::Directive::OMPD_taskloop; + break; + case OmpIfClause::DirectiveNameModifier::Teams: + dirId = llvm::omp::Directive::OMPD_teams; + break; + + case OmpIfClause::DirectiveNameModifier::TargetData: + case OmpIfClause::DirectiveNameModifier::TargetEnterData: + case OmpIfClause::DirectiveNameModifier::TargetExitData: + case OmpIfClause::DirectiveNameModifier::TargetUpdate: + default: + llvm::errs() << "Invalid modifier in IF clause\n"; + return false; + } + + if (auto *hasDir = compInfo.findDirective(dirId)) { + hasDir->clauses.push_back(clauseNode); + return true; + } + llvm::errs() << "Directive from modifier not found\n"; + return false; + } + + if (applyToAll(clauseId, clauseNode, compInfo)) + return true; + + llvm::errs() << "Cannot apply IF\n"; + return false; +} + +// LINEAR +static bool applyClause(const Fortran::parser::OmpClause::Linear &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpLinearClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // S The effect of the linear clause is as if it is applied to the innermost + // S leaf construct. + if (applyToInnermost(clauseId, clauseNode, compInfo)) { + llvm::errs() << "Cannot apply LINEAR\n"; + return false; + } + + // The rest is about SIMD. + if (!compInfo.findDirective(llvm::omp::OMPD_simd)) + return true; + + const std::list &names = std::visit( + [](auto &&s) { + // Both alternatives have member "names". + return s.names; + }, + contents.u); + Fortran::semantics::Symbol *iterVarSym = + getIterationVariableSymbol(compInfo.eval); + + // S Additionally, if the list item is not the iteration variable of a + // S simd or worksharing-loop SIMD construct, the effect on the outer leaf + // S constructs is as if the list item was specified in firstprivate and + // S lastprivate clauses on the combined or composite construct, [...] + // + // S If a list item of the linear clause is the iteration variable of a + // S simd or worksharing-loop SIMD construct and it is not declared in + // S the construct, the effect on the outer leaf constructs is as if the + // S list item was specified in a lastprivate clause on the combined or + // S composite construct [...] + + // It's not clear how an object can be listed in a clause AND be the + // iteration variable of a construct in which is it declared. If an + // object is declared in the construct, then the declaration is located + // after the clause listing it. + + // Lists of objects that will be used to construct FIRSTPRIVATE and + // LASTPRIVATE clauses. + std::list first, last; + + auto makeObjectFromName = [](Fortran::parser::Name name) { + // Pass "name" by copy. + static_assert(!std::is_lvalue_reference_v); + + auto source = name.source; + Fortran::parser::Designator designator( + Fortran::parser::DataRef(std::move(name))); + designator.source = source; + + return Fortran::parser::OmpObject(std::move(designator)); + }; + + for (const Fortran::parser::Name &name : names) { + last.emplace_back(makeObjectFromName(name)); + if (getOmpObjectSymbol(name) != iterVarSym) + first.emplace_back(makeObjectFromName(name)); + } + + auto addClause = [&](auto &&specific) { + // Take a specific clause, i.e. Fortran::parse::OmpClause::Xyz, + // wrap it into a general OmpClause, and add it to compInfo. + auto general = + std::make_unique(std::move(specific)); + compInfo.storage.emplace_back(std::move(general)); + compInfo.addClause(compInfo.storage.back().get()); + }; + + if (!first.empty()) { + Fortran::parser::OmpObjectList objList(std::move(first)); + addClause(Fortran::parser::OmpClause::Firstprivate(std::move(objList))); + } + if (!last.empty()) { + Fortran::parser::OmpObjectList objList(std::move(first)); + addClause(Fortran::parser::OmpClause::Lastprivate(std::move(objList))); + } + + return true; +} + +// NOWAIT +static bool applyClause(const Fortran::parser::OmpClause::Nowait &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + if (applyToOutermost(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply NOWAIT\n"; + return false; +} + +bool CompositeInfo::split() { + bool success = true; + + // First we need to apply LINEAR, because it can generate additional + // FIRSTPRIVATE and LASTPRIVATE clauses that apply to the compound/ + // composite construct. + // Collect them separately, because they may modify the clause list. + llvm::SmallVector linears; + for (const auto *clause : clauses) { + if (getClauseId(*clause) == llvm::omp::Clause::OMPC_linear) + linears.push_back(clause); + } + for (const auto *clause : linears) { + success = success && + detail::visit_clause( + *clause, [&](auto &&...args) { return applyClause(args...); }, + *this); + } + + // ALLOCATE clauses need to be applied last since they need to see + // which directives have data-privatizing clauses. + auto skip = [&](auto *clause) { + switch (getClauseId(*clause)) { + case llvm::omp::Clause::OMPC_allocate: + case llvm::omp::Clause::OMPC_linear: + return true; + default: + return false; + } + }; + + // Apply (almost) all clauses. + for (const auto *clause : clauses) { + if (skip(clause)) + continue; + success = success && + detail::visit_clause( + *clause, [&](auto &&...args) { return applyClause(args...); }, + *this); + } + + // Apply ALLOCATE. + for (const auto *clause : clauses) { + if (getClauseId(*clause) != llvm::omp::Clause::OMPC_allocate) + continue; + success = success && + detail::visit_clause( + *clause, [&](auto &&...args) { return applyClause(args...); }, + *this); + } + + return success; +} + +static void +splitCompositeConstruct(const mlir::ModuleOp &modOp, + Fortran::lower::pft::Evaluation &eval, + llvm::omp::Directive compDir, + const std::list &clauses) { + llvm::errs() << "composite name:" + << llvm::omp::getOpenMPDirectiveName(compDir) << '\n'; + llvm::errs() << "clause list:"; + for (auto &clause : clauses) + llvm::errs() << ' ' << llvm::omp::getOpenMPClauseName(getClauseId(clause)); + llvm::errs() << '\n'; + + CompositeInfo compInfo(modOp, eval, compDir, clauses); + llvm::errs() << "compInfo.1\n" << compInfo << '\n'; + + bool success = compInfo.split(); + + // Dump + llvm::errs() << "success:" << success << '\n'; + llvm::errs() << "compInfo.2\n" << compInfo << '\n'; +} + //===----------------------------------------------------------------------===// // DataSharingProcessor //===----------------------------------------------------------------------===// @@ -3352,6 +4425,10 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { const auto &beginLoopDirective = std::get(loopConstruct.t); + // Test call + splitCompositeConstruct(converter.getFirOpBuilder().getModule(), eval, + std::get<0>(beginLoopDirective.t).v, + std::get<1>(beginLoopDirective.t).v); const auto &loopOpClauseList = std::get(beginLoopDirective.t); mlir::Location currentLocation =