diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index d257da1a70964..d5713cfe420a2 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -148,6 +148,14 @@ inline Expr AsGenericExpr(Expr &&x) { return std::move(x); } std::optional> AsGenericExpr(DataRef &&); std::optional> AsGenericExpr(const Symbol &); +// Propagate std::optional from input to output. +template +std::optional> AsGenericExpr(std::optional &&x) { + if (!x) + return std::nullopt; + return AsGenericExpr(std::move(*x)); +} + template common::IfNoLvalue::category>>, A> AsCategoryExpr( A &&x) { @@ -430,6 +438,29 @@ template std::optional ExtractCoarrayRef(const A &x) { } } +struct ExtractSubstringHelper { + template static std::optional visit(T &&) { + return std::nullopt; + } + + static std::optional visit(const Substring &e) { return e; } + + template + static std::optional visit(const Designator &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } + + template + static std::optional visit(const Expr &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } +}; + +template +std::optional ExtractSubstring(const A &x) { + return ExtractSubstringHelper::visit(x); +} + // If an expression is simply a whole symbol data designator, // extract and return that symbol, else null. template const Symbol *UnwrapWholeSymbolDataRef(const A &x) { diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h index 8d560db34e05b..2fa90572bc63e 100644 --- a/flang/lib/Lower/DirectivesCommon.h +++ b/flang/lib/Lower/DirectivesCommon.h @@ -808,6 +808,75 @@ genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, return bounds; } +namespace detail { +template // +static T &&AsRvalueRef(T &&t) { + return std::move(t); +} +template // +static T AsRvalueRef(T &t) { + return t; +} +template // +static T AsRvalueRef(const T &t) { + return t; +} + +// Helper class for stripping enclosing parentheses and a conversion that +// preserves type category. This is used for triplet elements, which are +// always of type integer(kind=8). The lower/upper bounds are converted to +// an "index" type, which is 64-bit, so the explicit conversion to kind=8 +// (if present) is not needed. When it's present, though, it causes generated +// names to contain "int(..., kind=8)". +struct PeelConvert { + template + static Fortran::semantics::MaybeExpr visit_with_category( + const Fortran::evaluate::Expr> + &expr) { + return std::visit( + [](auto &&s) { return visit_with_category(s); }, + expr.u); + } + template + static Fortran::semantics::MaybeExpr visit_with_category( + const Fortran::evaluate::Convert, + Category> &expr) { + return AsGenericExpr(AsRvalueRef(expr.left())); + } + template + static Fortran::semantics::MaybeExpr visit_with_category(const T &) { + return std::nullopt; // + } + template + static Fortran::semantics::MaybeExpr visit_with_category(const T &) { + return std::nullopt; // + } + + template + static Fortran::semantics::MaybeExpr + visit(const Fortran::evaluate::Expr> + &expr) { + return std::visit([](auto &&s) { return visit_with_category(s); }, + expr.u); + } + static Fortran::semantics::MaybeExpr + visit(const Fortran::evaluate::Expr &expr) { + return std::visit([](auto &&s) { return visit(s); }, expr.u); + } + template // + static Fortran::semantics::MaybeExpr visit(const T &) { + return std::nullopt; + } +}; + +static Fortran::semantics::SomeExpr +peelOuterConvert(Fortran::semantics::SomeExpr &expr) { + if (auto peeled = PeelConvert::visit(expr)) + return *peeled; + return expr; +} +} // namespace detail + /// Generate bounds operations for an array section when subscripts are /// provided. template @@ -815,7 +884,7 @@ llvm::SmallVector genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext &stmtCtx, - const std::list &subscripts, + const std::vector &subscripts, std::stringstream &asFortran, fir::ExtendedValue &dataExv, bool dataExvIsAssumedSize, AddrAndBoundsInfo &info, bool treatIndexAsSection = false) { @@ -828,8 +897,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); const int dataExvRank = static_cast(dataExv.rank()); for (const auto &subscript : subscripts) { - const auto *triplet{ - std::get_if(&subscript.u)}; + const auto *triplet{std::get_if(&subscript.u)}; if (triplet || treatIndexAsSection) { if (dimension != 0) asFortran << ','; @@ -868,13 +936,18 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, strideInBytes = true; } - const Fortran::lower::SomeExpr *lower{nullptr}; + Fortran::semantics::MaybeExpr lower; if (triplet) { - if (const auto &tripletLb{std::get<0>(triplet->t)}) - lower = Fortran::semantics::GetExpr(*tripletLb); + if ((lower = Fortran::evaluate::AsGenericExpr(triplet->lower()))) + lower = detail::peelOuterConvert(*lower); } else { - const auto &index{std::get(subscript.u)}; - lower = Fortran::semantics::GetExpr(index); + // Case of IndirectSubscriptIntegerExpr + using IndirectSubscriptIntegerExpr = + Fortran::evaluate::IndirectSubscriptIntegerExpr; + using SubscriptInteger = Fortran::evaluate::SubscriptInteger; + Fortran::evaluate::Expr oneInt = + std::get(subscript.u).value(); + lower = Fortran::evaluate::AsGenericExpr(std::move(oneInt)); if (lower->Rank() > 0) { mlir::emitError( loc, "vector subscript cannot be used for an array section"); @@ -912,10 +985,12 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, extent = one; } else { asFortran << ':'; - const auto &upper{std::get<1>(triplet->t)}; + Fortran::semantics::MaybeExpr upper = + Fortran::evaluate::AsGenericExpr(triplet->upper()); if (upper) { - uval = Fortran::semantics::GetIntValue(upper); + upper = detail::peelOuterConvert(*upper); + uval = Fortran::evaluate::ToInt64(*upper); if (uval) { if (defaultLb) { ubound = builder.createIntegerConstant(loc, idxTy, *uval - 1); @@ -925,22 +1000,21 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, } asFortran << *uval; } else { - const Fortran::lower::SomeExpr *uexpr = - Fortran::semantics::GetExpr(*upper); mlir::Value ub = - fir::getBase(converter.genExprValue(loc, *uexpr, stmtCtx)); + fir::getBase(converter.genExprValue(loc, *upper, stmtCtx)); ub = builder.createConvert(loc, baseLb.getType(), ub); ubound = builder.create(loc, ub, baseLb); - asFortran << uexpr->AsFortran(); + asFortran << upper->AsFortran(); } } if (lower && upper) { if (lval && uval && *uval < *lval) { mlir::emitError(loc, "zero sized array section"); break; - } else if (std::get<2>(triplet->t)) { - const auto &strideExpr{std::get<2>(triplet->t)}; - if (strideExpr) { + } else { + // Stride is mandatory in evaluate::Triplet. Make sure it's 1. + auto val = Fortran::evaluate::ToInt64(triplet->GetStride()); + if (!val || *val != 1) { mlir::emitError(loc, "stride cannot be specified on " "an array section"); break; @@ -993,150 +1067,157 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, return bounds; } -template +namespace detail { +template // +std::optional getRef(Expr &&expr) { + if constexpr (std::is_same_v, + Fortran::evaluate::DataRef>) { + if (auto *ref = std::get_if(&expr.u)) + return *ref; + return std::nullopt; + } else { + auto maybeRef = Fortran::evaluate::ExtractDataRef(expr); + if (!maybeRef || !std::holds_alternative(maybeRef->u)) + return std::nullopt; + return std::get(maybeRef->u); + } +} +} // namespace detail + +template AddrAndBoundsInfo gatherDataOperandAddrAndBounds( Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::StatementContext &stmtCtx, const ObjectType &object, + semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + Fortran::semantics::SymbolRef symbol, + const Fortran::semantics::MaybeExpr &maybeDesignator, mlir::Location operandLocation, std::stringstream &asFortran, llvm::SmallVector &bounds, bool treatIndexAsSection = false) { + using namespace Fortran; + AddrAndBoundsInfo info; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (auto expr{Fortran::semantics::AnalyzeExpr(semanticsContext, - designator)}) { - if (((*expr).Rank() > 0 || treatIndexAsSection) && - Fortran::parser::Unwrap( - designator)) { - const auto *arrayElement = - Fortran::parser::Unwrap( - designator); - const auto *dataRef = - std::get_if(&designator.u); - fir::ExtendedValue dataExv; - bool dataExvIsAssumedSize = false; - if (Fortran::parser::Unwrap< - Fortran::parser::StructureComponent>( - arrayElement->base)) { - auto exprBase = Fortran::semantics::AnalyzeExpr( - semanticsContext, arrayElement->base); - dataExv = converter.genExprAddr(operandLocation, *exprBase, - stmtCtx); - info.addr = fir::getBase(dataExv); - info.rawInput = info.addr; - asFortran << (*exprBase).AsFortran(); - } else { - const Fortran::parser::Name &name = - Fortran::parser::GetLastName(*dataRef); - dataExvIsAssumedSize = Fortran::semantics::IsAssumedSizeArray( - name.symbol->GetUltimate()); - info = getDataOperandBaseAddr(converter, builder, - *name.symbol, operandLocation); - dataExv = converter.getSymbolExtendedValue(*name.symbol); - asFortran << name.ToString(); - } - - if (!arrayElement->subscripts.empty()) { - asFortran << '('; - bounds = genBoundsOps( - builder, operandLocation, converter, stmtCtx, - arrayElement->subscripts, asFortran, dataExv, - dataExvIsAssumedSize, info, treatIndexAsSection); - } - asFortran << ')'; - } else if (auto structComp = Fortran::parser::Unwrap< - Fortran::parser::StructureComponent>(designator)) { - fir::ExtendedValue compExv = - converter.genExprAddr(operandLocation, *expr, stmtCtx); - info.addr = fir::getBase(compExv); - info.rawInput = info.addr; - if (fir::unwrapRefType(info.addr.getType()) - .isa()) - bounds = genBaseBoundsOps( - builder, operandLocation, converter, compExv, - /*isAssumedSize=*/false); - asFortran << (*expr).AsFortran(); - - bool isOptional = Fortran::semantics::IsOptional( - *Fortran::parser::GetLastName(*structComp).symbol); - if (isOptional) - info.isPresent = builder.create( - operandLocation, builder.getI1Type(), info.rawInput); - - if (auto loadOp = mlir::dyn_cast_or_null( - info.addr.getDefiningOp())) { - if (fir::isAllocatableType(loadOp.getType()) || - fir::isPointerType(loadOp.getType())) - info.addr = builder.create(operandLocation, - info.addr); - info.rawInput = info.addr; - } - - // If the component is an allocatable or pointer the result of - // genExprAddr will be the result of a fir.box_addr operation or - // a fir.box_addr has been inserted just before. - // Retrieve the box so we handle it like other descriptor. - if (auto boxAddrOp = mlir::dyn_cast_or_null( - info.addr.getDefiningOp())) { - info.addr = boxAddrOp.getVal(); - info.rawInput = info.addr; - bounds = genBoundsOpsFromBox( - builder, operandLocation, converter, compExv, info); - } - } else { - if (Fortran::parser::Unwrap( - designator)) { - // Single array element. - const auto *arrayElement = - Fortran::parser::Unwrap( - designator); - (void)arrayElement; - fir::ExtendedValue compExv = - converter.genExprAddr(operandLocation, *expr, stmtCtx); - info.addr = fir::getBase(compExv); - info.rawInput = info.addr; - asFortran << (*expr).AsFortran(); - } else if (const auto *dataRef{ - std::get_if( - &designator.u)}) { - // Scalar or full array. - const Fortran::parser::Name &name = - Fortran::parser::GetLastName(*dataRef); - fir::ExtendedValue dataExv = - converter.getSymbolExtendedValue(*name.symbol); - info = getDataOperandBaseAddr(converter, builder, - *name.symbol, operandLocation); - if (fir::unwrapRefType(info.addr.getType()) - .isa()) { - bounds = genBoundsOpsFromBox( - builder, operandLocation, converter, dataExv, info); - } - bool dataExvIsAssumedSize = - Fortran::semantics::IsAssumedSizeArray( - name.symbol->GetUltimate()); - if (fir::unwrapRefType(info.addr.getType()) - .isa()) - bounds = genBaseBoundsOps( - builder, operandLocation, converter, dataExv, - dataExvIsAssumedSize); - asFortran << name.ToString(); - } else { // Unsupported - llvm::report_fatal_error( - "Unsupported type of OpenACC operand"); - } - } - } - }, - [&](const Fortran::parser::Name &name) { - info = getDataOperandBaseAddr(converter, builder, *name.symbol, - operandLocation); - asFortran << name.ToString(); - }}, - object.u); + + if (!maybeDesignator) { + info = getDataOperandBaseAddr(converter, builder, symbol, operandLocation); + asFortran << symbol->name().ToString(); + return info; + } + + semantics::SomeExpr designator = *maybeDesignator; + + if ((designator.Rank() > 0 || treatIndexAsSection) && + IsArrayElement(designator)) { + auto arrayRef = detail::getRef(designator); + // This shouldn't fail after IsArrayElement(designator). + assert(arrayRef && "Expecting ArrayRef"); + + fir::ExtendedValue dataExv; + bool dataExvIsAssumedSize = false; + + auto toMaybeExpr = [&](auto &&base) { + using BaseType = llvm::remove_cvref_t; + evaluate::ExpressionAnalyzer ea{semaCtx}; + + if constexpr (std::is_same_v) { + if (auto *ref = base.UnwrapSymbolRef()) + return ea.Designate(evaluate::DataRef{*ref}); + if (auto *ref = base.UnwrapComponent()) + return ea.Designate(evaluate::DataRef{*ref}); + llvm_unreachable("Unexpected NamedEntity"); + } else { + static_assert(std::is_same_v); + return ea.Designate(evaluate::DataRef{base}); + } + }; + + auto arrayBase = toMaybeExpr(arrayRef->base()); + assert(arrayBase); + + if (detail::getRef(*arrayBase)) { + dataExv = converter.genExprAddr(operandLocation, *arrayBase, stmtCtx); + info.addr = fir::getBase(dataExv); + info.rawInput = info.addr; + asFortran << arrayBase->AsFortran(); + } else { + const semantics::Symbol &sym = arrayRef->GetLastSymbol(); + dataExvIsAssumedSize = + Fortran::semantics::IsAssumedSizeArray(sym.GetUltimate()); + info = getDataOperandBaseAddr(converter, builder, sym, operandLocation); + dataExv = converter.getSymbolExtendedValue(sym); + asFortran << sym.name().ToString(); + } + + if (!arrayRef->subscript().empty()) { + asFortran << '('; + bounds = genBoundsOps( + builder, operandLocation, converter, stmtCtx, arrayRef->subscript(), + asFortran, dataExv, dataExvIsAssumedSize, info, treatIndexAsSection); + } + asFortran << ')'; + } else if (auto compRef = detail::getRef(designator)) { + fir::ExtendedValue compExv = + converter.genExprAddr(operandLocation, designator, stmtCtx); + info.addr = fir::getBase(compExv); + info.rawInput = info.addr; + if (fir::unwrapRefType(info.addr.getType()).isa()) + bounds = genBaseBoundsOps(builder, operandLocation, + converter, compExv, + /*isAssumedSize=*/false); + asFortran << designator.AsFortran(); + + if (semantics::IsOptional(compRef->GetLastSymbol())) { + info.isPresent = builder.create( + operandLocation, builder.getI1Type(), info.rawInput); + } + + if (auto loadOp = + mlir::dyn_cast_or_null(info.addr.getDefiningOp())) { + if (fir::isAllocatableType(loadOp.getType()) || + fir::isPointerType(loadOp.getType())) + info.addr = builder.create(operandLocation, info.addr); + info.rawInput = info.addr; + } + + // If the component is an allocatable or pointer the result of + // genExprAddr will be the result of a fir.box_addr operation or + // a fir.box_addr has been inserted just before. + // Retrieve the box so we handle it like other descriptor. + if (auto boxAddrOp = + mlir::dyn_cast_or_null(info.addr.getDefiningOp())) { + info.addr = boxAddrOp.getVal(); + info.rawInput = info.addr; + bounds = genBoundsOpsFromBox( + builder, operandLocation, converter, compExv, info); + } + } else { + if (detail::getRef(designator)) { + fir::ExtendedValue compExv = + converter.genExprAddr(operandLocation, designator, stmtCtx); + info.addr = fir::getBase(compExv); + info.rawInput = info.addr; + asFortran << designator.AsFortran(); + } else if (auto symRef = detail::getRef(designator)) { + // Scalar or full array. + fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*symRef); + info = + getDataOperandBaseAddr(converter, builder, *symRef, operandLocation); + if (fir::unwrapRefType(info.addr.getType()).isa()) { + bounds = genBoundsOpsFromBox( + builder, operandLocation, converter, dataExv, info); + } + bool dataExvIsAssumedSize = + Fortran::semantics::IsAssumedSizeArray(symRef->get().GetUltimate()); + if (fir::unwrapRefType(info.addr.getType()).isa()) + bounds = genBaseBoundsOps( + builder, operandLocation, converter, dataExv, dataExvIsAssumedSize); + asFortran << symRef->get().name().ToString(); + } else { // Unsupported + llvm::report_fatal_error("Unsupported type of OpenACC operand"); + } + } + return info; } - } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 6ae270f63f5cf..a444682306ac2 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -269,6 +269,11 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) { Fortran::parser::GetLastName(arrayElement->base); return *name.symbol; } + if (const auto *component = + Fortran::parser::Unwrap( + *designator)) { + return *component->component.symbol; + } } else if (const auto *name = std::get_if(&accObject.u)) { return *name->symbol; @@ -286,17 +291,20 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, mlir::acc::DataClause dataClause, bool structured, bool implicit, bool setDeclareAttr = false) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objectList.v) { llvm::SmallVector bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds, - /*treatIndexAsSection=*/true); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds, + /*treatIndexAsSection=*/true); // If the input value is optional and is not a descriptor, we use the // rawInput directly. @@ -321,16 +329,19 @@ static void genDeclareDataOperandOperations( llvm::SmallVectorImpl &dataOperands, mlir::acc::DataClause dataClause, bool structured, bool implicit) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objectList.v) { llvm::SmallVector bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds); EntryOp op = createDataEntryOp( builder, operandLocation, info.addr, asFortran, bounds, structured, implicit, dataClause, info.addr.getType()); @@ -339,8 +350,7 @@ static void genDeclareDataOperandOperations( if (mlir::isa(fir::unwrapRefType(info.addr.getType()))) { mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion()); modBuilder.setInsertionPointAfter(builder.getFunction()); - std::string prefix = - converter.mangleName(getSymbolFromAccObject(accObject)); + std::string prefix = converter.mangleName(symbol); createDeclareAllocFuncWithArg( modBuilder, builder, operandLocation, info.addr.getType(), prefix, asFortran, dataClause); @@ -783,16 +793,19 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList, llvm::SmallVectorImpl &dataOperands, llvm::SmallVector &privatizations) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objectList.v) { llvm::SmallVector bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds); RecipeOp recipe; mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType()); if constexpr (std::is_same_v) { @@ -1361,16 +1374,19 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList, const auto &op = std::get(objectList.t); mlir::acc::ReductionOperator mlirOp = getReductionOperator(op); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objects.v) { llvm::SmallVector bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds); mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType()); if (auto seqTy = mlir::dyn_cast(reductionTy)) diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index 06850bebd7d05..2df6d0560e3c7 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -31,6 +31,7 @@ #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/CommandLine.h" @@ -48,6 +49,29 @@ using DeclareTargetCapturePair = // Common helper functions //===----------------------------------------------------------------------===// +static llvm::ArrayRef getWorksharing() { + static llvm::omp::Directive worksharing[] = { + llvm::omp::Directive::OMPD_do, llvm::omp::Directive::OMPD_for, + llvm::omp::Directive::OMPD_scope, llvm::omp::Directive::OMPD_sections, + llvm::omp::Directive::OMPD_single, llvm::omp::Directive::OMPD_workshare, + }; + return worksharing; +} + +static llvm::ArrayRef getWorksharingLoop() { + static llvm::omp::Directive worksharingLoop[] = { + llvm::omp::Directive::OMPD_do, + llvm::omp::Directive::OMPD_for, + }; + return worksharingLoop; +} + +static uint32_t getOpenMPVersion(const mlir::ModuleOp &mod) { + if (mlir::Attribute verAttr = mod->getAttr("omp.version")) + return llvm::cast(verAttr).getVersion(); + llvm_unreachable("Expecting OpenMP version attribute in module"); +} + static Fortran::semantics::Symbol * getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { Fortran::semantics::Symbol *sym = nullptr; @@ -72,74 +96,2191 @@ getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { return sym; } -static void genObjectList(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl &operands) { - auto addOperands = [&](Fortran::lower::SymbolRef sym) { - const mlir::Value variable = converter.getSymbolAddress(sym); - if (variable) { - operands.push_back(variable); +static void genObjectList2(const Fortran::parser::OmpObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl &operands) { + auto addOperands = [&](Fortran::lower::SymbolRef sym) { + const mlir::Value variable = converter.getSymbolAddress(sym); + if (variable) { + operands.push_back(variable); + } else { + if (const auto *details = + sym->detailsIf()) { + operands.push_back(converter.getSymbolAddress(details->symbol())); + converter.copySymbolBinding(details->symbol(), sym); + } + } + }; + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + addOperands(*sym); + } +} + +static Fortran::lower::pft::Evaluation * +getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) { + // Return the Evaluation of the innermost collapsed loop, or the current one + // if there was no COLLAPSE. + if (collapseValue == 0) + return &eval; + + Fortran::lower::pft::Evaluation *curEval = &eval.getFirstNestedEvaluation(); + for (int i = 1; i < collapseValue; i++) { + // The nested evaluations should be DoConstructs (i.e. they should form + // a loop nest). Each DoConstruct is a tuple . + assert(curEval->isA()); + curEval = &*std::next(curEval->getNestedEvaluations().begin()); + } + return curEval; +} + +static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + int collapseValue = 0) { + Fortran::lower::pft::Evaluation *curEval = + getCollapsedLoopEval(eval, collapseValue); + + for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) + converter.genEval(e); +} + +//===----------------------------------------------------------------------===// +// Clauses +//===----------------------------------------------------------------------===// + +namespace detail { +template // +llvm::omp::Clause getClauseIdForClass(C &&) { + using namespace Fortran; + using A = llvm::remove_cvref_t; // A is referenced in OMP.inc + // The code included below contains a sequence of checks like the following + // for each OpenMP clause + // if constexpr (std::is_same_v) + // return llvm::omp::Clause::OMPC_acq_rel; + // [...] +#define GEN_FLANG_CLAUSE_PARSER_KIND_MAP +#include "llvm/Frontend/OpenMP/OMP.inc" +} +} // namespace detail + +static llvm::omp::Clause getClauseId(const Fortran::parser::OmpClause &clause) { + return std::visit([](auto &&s) { return detail::getClauseIdForClass(s); }, + clause.u); +} + +namespace omp { +using namespace Fortran; +using SomeType = evaluate::SomeType; +using SomeExpr = semantics::SomeExpr; +using MaybeExpr = semantics::MaybeExpr; + +template // +using List = std::vector; + +struct SymDsgExtractor { + using SymDsg = std::tuple; + + template // + static T &&AsRvalueRef(T &&t) { + return std::move(t); + } + template // + static T AsRvalueRef(const T &t) { + return t; + } + + static semantics::Symbol *symbol_addr(const evaluate::SymbolRef &ref) { + // Symbols cannot be created after semantic checks, so all symbol + // pointers that are non-null must point to one of those pre-existing + // objects. Throughout the code, symbols are often pointed to by + // non-const pointers, so there is no harm in casting the constness + // away. + return const_cast(&ref.get()); + } + + template // + static SymDsg visit(T &&) { + // Use this to see missing overloads: + // llvm::errs() << "NULL: " << __PRETTY_FUNCTION__ << '\n'; + return SymDsg{}; + } + + template // + static SymDsg visit(const evaluate::Designator &e) { + return std::make_tuple(symbol_addr(*e.GetLastSymbol()), + evaluate::AsGenericExpr(AsRvalueRef(e))); + } + + static SymDsg visit(const evaluate::ProcedureDesignator &e) { + return std::make_tuple(symbol_addr(*e.GetSymbol()), std::nullopt); + } + + template // + static SymDsg visit(const evaluate::Expr &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } + + static bool verify(const SymDsg &sd) { + const semantics::Symbol *symbol = std::get<0>(sd); + assert(symbol && "Expecting Symbol"); + auto &maybeDsg = std::get<1>(sd); + if (!maybeDsg) + return true; + std::optional maybeRef = + evaluate::ExtractDataRef(*maybeDsg); + if (maybeRef) { + assert(&maybeRef->GetLastSymbol() == symbol && + "Designator not for symbol"); + return true; + } + + // This could still be a Substring or ComplexPart, but at least Substring + // is not allowed in OpenMP. + maybeDsg->dump(); + llvm_unreachable("Expecting DataRef"); + } +}; + +SymDsgExtractor::SymDsg getSymbolAndDesignator(const MaybeExpr &expr) { + if (!expr) + return SymDsgExtractor::SymDsg{}; + return std::visit([](auto &&s) { return SymDsgExtractor::visit(s); }, + expr->u); +} + +struct Object { + semantics::Symbol *sym; // symbol + MaybeExpr dsg; // designator ending with symbol +}; + +using ObjectList = List; + +Object makeObject(const parser::OmpObject &object, + semantics::SemanticsContext &semaCtx) { + // If object is a common block, expression analyzer won't be able to + // do anything. + if (const auto *name = std::get_if(&object.u)) { + assert(name->symbol && "Expecting Symbol"); + return Object{name->symbol, std::nullopt}; + } + evaluate::ExpressionAnalyzer ea{semaCtx}; + SymDsgExtractor::SymDsg sd = std::visit( + [&](auto &&s) { return getSymbolAndDesignator(ea.Analyze(s)); }, + object.u); + SymDsgExtractor::verify(sd); + return Object{std::get<0>(sd), std::move(std::get<1>(sd))}; +} + +Object makeObject(const parser::Name &name, + semantics::SemanticsContext &semaCtx) { + assert(name.symbol && "Expecting Symbol"); + return Object{name.symbol, std::nullopt}; +} + +Object makeObject(const parser::Designator &dsg, + semantics::SemanticsContext &semaCtx) { + evaluate::ExpressionAnalyzer ea{semaCtx}; + SymDsgExtractor::SymDsg sd = getSymbolAndDesignator(ea.Analyze(dsg)); + SymDsgExtractor::verify(sd); + return Object{std::get<0>(sd), std::move(std::get<1>(sd))}; +} + +Object makeObject(const parser::StructureComponent &comp, + semantics::SemanticsContext &semaCtx) { + evaluate::ExpressionAnalyzer ea{semaCtx}; + SymDsgExtractor::SymDsg sd = getSymbolAndDesignator(ea.Analyze(comp)); + SymDsgExtractor::verify(sd); + return Object{std::get<0>(sd), std::move(std::get<1>(sd))}; +} + +auto makeObject(semantics::SemanticsContext &semaCtx) { + return [&](auto &&s) { return makeObject(s, semaCtx); }; +} + +template +SomeExpr makeExpr(T &&inp, semantics::SemanticsContext &semaCtx) { + auto maybeExpr = evaluate::ExpressionAnalyzer(semaCtx).Analyze(inp); + assert(maybeExpr); + return std::move(*maybeExpr); +} + +auto makeExpr(semantics::SemanticsContext &semaCtx) { + return [&](auto &&s) { return makeExpr(s, semaCtx); }; +} + +template ::value_type, + typename R = std::invoke_result_t> +List makeList(C &&container, F &&func) { + List v; + llvm::transform(container, std::back_inserter(v), func); + return v; +} + +ObjectList makeList(const parser::OmpObjectList &objects, + semantics::SemanticsContext &semaCtx) { + return makeList(objects.v, makeObject(semaCtx)); +} + +template // +U enum_cast(T t) { + using BareT = llvm::remove_cvref_t; + using BareU = llvm::remove_cvref_t; + static_assert(std::is_enum_v && std::is_enum_v); + + return U{static_cast>(t)}; +} + +template > +std::optional maybeApply(F &&func, const std::optional &inp) { + if (!inp) + return std::nullopt; + return std::move(func(*inp)); +} + +std::optional +getBaseObject(const Object &object, + Fortran::semantics::SemanticsContext &semaCtx) { + // If it's just the symbol, then there is no base. + if (!object.dsg) + return std::nullopt; + + auto maybeRef = evaluate::ExtractDataRef(*object.dsg); + if (!maybeRef) + return std::nullopt; + + evaluate::DataRef ref = *maybeRef; + + if (std::get_if(&ref.u)) { + return std::nullopt; + } else if (auto *comp = std::get_if(&ref.u)) { + const evaluate::DataRef &base = comp->base(); + return Object{SymDsgExtractor::symbol_addr(base.GetLastSymbol()), + evaluate::AsGenericExpr(SymDsgExtractor::AsRvalueRef(base))}; + } else if (auto *arr = std::get_if(&ref.u)) { + const evaluate::NamedEntity &base = arr->base(); + evaluate::ExpressionAnalyzer ea{semaCtx}; + if (auto *comp = base.UnwrapComponent()) { + return Object{ + SymDsgExtractor::symbol_addr(comp->symbol()), + ea.Designate(evaluate::DataRef{SymDsgExtractor::AsRvalueRef(*comp)})}; + } else if (base.UnwrapSymbolRef()) { + return std::nullopt; + } + } else { + assert(std::holds_alternative(ref.u)); + llvm_unreachable("Coarray reference not supported at the moment"); + } + return std::nullopt; +} + +namespace clause { +#ifdef EMPTY_CLASS +#undef EMPTY_CLASS +#endif +#define EMPTY_CLASS(cls) \ + struct cls { \ + using EmptyTrait = std::true_type; \ + }; \ + cls make(const parser::OmpClause::cls &, semantics::SemanticsContext &) { \ + return cls{}; \ + } + +#ifdef WRAPPER_CLASS +#undef WRAPPER_CLASS +#endif +#define WRAPPER_CLASS(cls, content) // Nothing +#define GEN_FLANG_CLAUSE_PARSER_CLASSES +#include "llvm/Frontend/OpenMP/OMP.inc" +#undef EMPTY_CLASS + +// Helper objects + +struct DefinedOperator { + struct DefinedOpName { + using WrapperTrait = std::true_type; + Object v; + }; + ENUM_CLASS(IntrinsicOperator, Power, Multiply, Divide, Add, Subtract, Concat, + LT, LE, EQ, NE, GE, GT, NOT, AND, OR, EQV, NEQV) + using UnionTrait = std::true_type; + std::variant u; +}; + +DefinedOperator makeDefOp(const parser::DefinedOperator &inp, + semantics::SemanticsContext &semaCtx) { + return DefinedOperator{ + std::visit(common::visitors{ + [&](const parser::DefinedOpName &s) { + return DefinedOperator{DefinedOperator::DefinedOpName{ + makeObject(s.v, semaCtx)}}; + }, + [&](const parser::DefinedOperator::IntrinsicOperator &s) { + return DefinedOperator{ + enum_cast(s)}; + }, + }, + inp.u), + }; +} + +struct ProcedureDesignator { + using WrapperTrait = std::true_type; + Object v; +}; + +ProcedureDesignator makeProcDsg(const parser::ProcedureDesignator &inp, + semantics::SemanticsContext &semaCtx) { + return ProcedureDesignator{std::visit( + common::visitors{ + [&](const parser::Name &t) { return makeObject(t, semaCtx); }, + [&](const parser::ProcComponentRef &t) { + return makeObject(t.v.thing, semaCtx); + }, + }, + inp.u)}; +} + +struct ReductionOperator { + using UnionTrait = std::true_type; + std::variant u; +}; + +ReductionOperator makeRedOp(const parser::OmpReductionOperator &inp, + semantics::SemanticsContext &semaCtx) { + return std::visit(common::visitors{ + [&](const parser::DefinedOperator &s) { + return ReductionOperator{makeDefOp(s, semaCtx)}; + }, + [&](const parser::ProcedureDesignator &s) { + return ReductionOperator{makeProcDsg(s, semaCtx)}; + }, + }, + inp.u); +} + +// Actual clauses. Each T (where OmpClause::T exists) has its "make". + +struct Aligned { + using TupleTrait = std::true_type; + std::tuple t; +}; + +Aligned make(const parser::OmpClause::Aligned &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpAlignedClause + auto &t0 = std::get(inp.v.t); + auto &t1 = std::get>(inp.v.t); + + return Aligned{{ + makeList(t0, semaCtx), + maybeApply(makeExpr(semaCtx), t1), + }}; +} + +struct Allocate { + struct Modifier { + struct Allocator { + using WrapperTrait = std::true_type; + SomeExpr v; + }; + struct Align { + using WrapperTrait = std::true_type; + SomeExpr v; + }; + struct ComplexModifier { + using TupleTrait = std::true_type; + std::tuple t; + }; + using UnionTrait = std::true_type; + std::variant u; + }; + using TupleTrait = std::true_type; + std::tuple, ObjectList> t; +}; + +Allocate make(const parser::OmpClause::Allocate &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpAllocateClause + using wrapped = parser::OmpAllocateClause; + auto &t0 = std::get>(inp.v.t); + auto &t1 = std::get(inp.v.t); + + auto convert = [&](auto &&s) -> Allocate::Modifier { + using Modifier = Allocate::Modifier; + using Allocator = Modifier::Allocator; + using Align = Modifier::Align; + using ComplexModifier = Modifier::ComplexModifier; + + return Modifier{ + std::visit( + common::visitors{ + [&](const wrapped::AllocateModifier::Allocator &v) { + return Modifier{Allocator{makeExpr(v.v, semaCtx)}}; + }, + [&](const wrapped::AllocateModifier::ComplexModifier &v) { + auto &s0 = + std::get(v.t); + auto &s1 = std::get(v.t); + return Modifier{ComplexModifier{{ + Allocator{makeExpr(s0.v, semaCtx)}, + Align{makeExpr(s1.v, semaCtx)}, + }}}; + }, + [&](const wrapped::AllocateModifier::Align &v) { + return Modifier{Align{makeExpr(v.v, semaCtx)}}; + }, + }, + s.u), + }; + }; + + return Allocate{{maybeApply(convert, t0), makeList(t1, semaCtx)}}; +} + +struct Allocator { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Allocator make(const parser::OmpClause::Allocator &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntExpr + return Allocator{makeExpr(inp.v, semaCtx)}; +} + +struct AtomicDefaultMemOrder { + using WrapperTrait = std::true_type; + common::OmpAtomicDefaultMemOrderType v; +}; + +AtomicDefaultMemOrder make(const parser::OmpClause::AtomicDefaultMemOrder &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpAtomicDefaultMemOrderClause + return AtomicDefaultMemOrder{inp.v.v}; +} + +struct Collapse { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Collapse make(const parser::OmpClause::Collapse &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntConstantExpr + return Collapse{makeExpr(inp.v, semaCtx)}; +} + +struct Copyin { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +Copyin make(const parser::OmpClause::Copyin &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return Copyin{makeList(inp.v, semaCtx)}; +} + +struct Copyprivate { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +Copyprivate make(const parser::OmpClause::Copyprivate &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return Copyprivate{makeList(inp.v, semaCtx)}; +} + +struct Defaultmap { + ENUM_CLASS(ImplicitBehavior, Alloc, To, From, Tofrom, Firstprivate, None, + Default) + ENUM_CLASS(VariableCategory, Scalar, Aggregate, Allocatable, Pointer) + using TupleTrait = std::true_type; + std::tuple> t; +}; + +Defaultmap make(const parser::OmpClause::Defaultmap &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpDefaultmapClause + using wrapped = parser::OmpDefaultmapClause; + + auto convert = [](auto &&s) -> Defaultmap::VariableCategory { + return enum_cast(s); + }; + auto &t0 = std::get(inp.v.t); + auto &t1 = std::get>(inp.v.t); + auto v0 = enum_cast(t0); + return Defaultmap{{v0, maybeApply(convert, t1)}}; +} + +struct Default { + ENUM_CLASS(Type, Private, Firstprivate, Shared, None) + using WrapperTrait = std::true_type; + Type v; +}; + +Default make(const parser::OmpClause::Default &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpDefaultClause + return Default{enum_cast(inp.v.v)}; +} + +struct Depend { + struct Source { + using EmptyTrait = std::true_type; + }; + struct Sink { + using Length = std::tuple; + using Vec = std::tuple>; + using WrapperTrait = std::true_type; + List v; + }; + ENUM_CLASS(Type, In, Out, Inout, Source, Sink) + struct InOut { + using TupleTrait = std::true_type; + std::tuple t; + }; + using UnionTrait = std::true_type; + std::variant u; +}; + +Depend make(const parser::OmpClause::Depend &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpDependClause + using wrapped = parser::OmpDependClause; + + return std::visit( + common::visitors{ + [&](const wrapped::Source &s) { return Depend{Depend::Source{}}; }, + [&](const wrapped::Sink &s) { + auto convert = [&](const parser::OmpDependSinkVec &v) { + auto &t0 = std::get(v.t); + auto &t1 = + std::get>(v.t); + auto convert1 = [&](const parser::OmpDependSinkVecLength &u) { + auto &s0 = std::get(u.t); + auto &s1 = std::get(u.t); + return Depend::Sink::Length{makeDefOp(s0, semaCtx), + makeExpr(s1, semaCtx)}; + }; + return Depend::Sink::Vec{makeObject(t0, semaCtx), + maybeApply(convert1, t1)}; + }; + return Depend{Depend::Sink{makeList(s.v, convert)}}; + }, + [&](const wrapped::InOut &s) { + auto &t0 = std::get(s.t); + auto &t1 = std::get>(s.t); + auto convert = [&](const parser::Designator &t) { + return makeObject(t, semaCtx); + }; + return Depend{Depend::InOut{ + {enum_cast(t0.v), makeList(t1, convert)}}}; + }, + }, + inp.v.u); +} + +struct Device { + ENUM_CLASS(DeviceModifier, Ancestor, Device_Num) + using TupleTrait = std::true_type; + std::tuple, SomeExpr> t; +}; + +Device make(const parser::OmpClause::Device &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpDeviceClause + using wrapped = parser::OmpDeviceClause; + + auto convert = [](auto &&s) -> Device::DeviceModifier { + return enum_cast(s); + }; + auto &t0 = std::get>(inp.v.t); + auto &t1 = std::get(inp.v.t); + return Device{{maybeApply(convert, t0), makeExpr(t1, semaCtx)}}; +} + +struct DeviceType { + ENUM_CLASS(Type, Any, Host, Nohost) + using WrapperTrait = std::true_type; + Type v; +}; + +DeviceType make(const parser::OmpClause::DeviceType &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpDeviceTypeClause + return DeviceType{enum_cast(inp.v.v)}; +} + +struct DistSchedule { + using WrapperTrait = std::true_type; + MaybeExpr v; +}; + +DistSchedule make(const parser::OmpClause::DistSchedule &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional + return DistSchedule{maybeApply(makeExpr(semaCtx), inp.v)}; +} + +struct Enter { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +Enter make(const parser::OmpClause::Enter &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return Enter{makeList(inp.v, semaCtx)}; +} + +struct Filter { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Filter make(const parser::OmpClause::Filter &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntExpr + return Filter{makeExpr(inp.v, semaCtx)}; +} + +struct Final { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Final make(const parser::OmpClause::Final &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarLogicalExpr + return Final{makeExpr(inp.v, semaCtx)}; +} + +struct Firstprivate { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +Firstprivate make(const parser::OmpClause::Firstprivate &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return Firstprivate{makeList(inp.v, semaCtx)}; +} + +struct From { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +From make(const parser::OmpClause::From &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return From{makeList(inp.v, semaCtx)}; +} + +struct Grainsize { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Grainsize make(const parser::OmpClause::Grainsize &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntExpr + return Grainsize{makeExpr(inp.v, semaCtx)}; +} + +struct HasDeviceAddr { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +HasDeviceAddr make(const parser::OmpClause::HasDeviceAddr &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return HasDeviceAddr{makeList(inp.v, semaCtx)}; +} + +struct Hint { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Hint make(const parser::OmpClause::Hint &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ConstantExpr + return Hint{makeExpr(inp.v, semaCtx)}; +} + +struct If { + ENUM_CLASS(DirectiveNameModifier, Parallel, Simd, Target, TargetData, + TargetEnterData, TargetExitData, TargetUpdate, Task, Taskloop, + Teams) + using TupleTrait = std::true_type; + std::tuple, SomeExpr> t; +}; + +If make(const parser::OmpClause::If &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpIfClause + using wrapped = parser::OmpIfClause; + + auto &t0 = std::get>(inp.v.t); + auto &t1 = std::get(inp.v.t); + auto convert = [](auto &&s) -> If::DirectiveNameModifier { + return enum_cast(s); + }; + return If{{maybeApply(convert, t0), makeExpr(t1, semaCtx)}}; +} + +struct InReduction { + using TupleTrait = std::true_type; + std::tuple t; +}; + +InReduction make(const parser::OmpClause::InReduction &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpInReductionClause + auto &t0 = std::get(inp.v.t); + auto &t1 = std::get(inp.v.t); + return InReduction{{makeRedOp(t0, semaCtx), makeList(t1, semaCtx)}}; +} + +struct IsDevicePtr { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +IsDevicePtr make(const parser::OmpClause::IsDevicePtr &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return IsDevicePtr{makeList(inp.v, semaCtx)}; +} + +struct Lastprivate { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +Lastprivate make(const parser::OmpClause::Lastprivate &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return Lastprivate{makeList(inp.v, semaCtx)}; +} + +struct Linear { + struct Modifier { + ENUM_CLASS(Type, Ref, Val, Uval) + using WrapperTrait = std::true_type; + Type v; + }; + using TupleTrait = std::true_type; + std::tuple, ObjectList, MaybeExpr> t; +}; + +Linear make(const parser::OmpClause::Linear &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpLinearClause + using wrapped = parser::OmpLinearClause; + + return std::visit( + common::visitors{ + [&](const wrapped::WithModifier &s) { + auto v = enum_cast(s.modifier.v); + return Linear{{Linear::Modifier{v}, + makeList(s.names, makeObject(semaCtx)), + maybeApply(makeExpr(semaCtx), s.step)}}; + }, + [&](const wrapped::WithoutModifier &s) { + return Linear{{std::nullopt, makeList(s.names, makeObject(semaCtx)), + maybeApply(makeExpr(semaCtx), s.step)}}; + }, + }, + inp.v.u); +} + +struct Link { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +Link make(const parser::OmpClause::Link &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return Link{makeList(inp.v, semaCtx)}; +} + +struct Map { + struct MapType { + struct Always { + using EmptyTrait = std::true_type; + }; + ENUM_CLASS(Type, To, From, Tofrom, Alloc, Release, Delete) + using TupleTrait = std::true_type; + std::tuple, Type> t; + }; + using TupleTrait = std::true_type; + std::tuple, ObjectList> t; +}; + +Map make(const parser::OmpClause::Map &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpMapClause + auto &t0 = std::get>(inp.v.t); + auto &t1 = std::get(inp.v.t); + auto convert = [](const parser::OmpMapType &s) { + auto &s0 = std::get>(s.t); + auto &s1 = std::get(s.t); + auto convertT = [](parser::OmpMapType::Always) { + return Map::MapType::Always{}; + }; + return Map::MapType{ + {maybeApply(convertT, s0), enum_cast(s1)}}; + }; + return Map{{maybeApply(convert, t0), makeList(t1, semaCtx)}}; +} + +struct Nocontext { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Nocontext make(const parser::OmpClause::Nocontext &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarLogicalExpr + return Nocontext{makeExpr(inp.v, semaCtx)}; +} + +struct Nontemporal { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +Nontemporal make(const parser::OmpClause::Nontemporal &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::list + return Nontemporal{makeList(inp.v, makeObject(semaCtx))}; +} + +struct Novariants { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Novariants make(const parser::OmpClause::Novariants &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarLogicalExpr + return Novariants{makeExpr(inp.v, semaCtx)}; +} + +struct NumTasks { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +NumTasks make(const parser::OmpClause::NumTasks &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntExpr + return NumTasks{makeExpr(inp.v, semaCtx)}; +} + +struct NumTeams { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +NumTeams make(const parser::OmpClause::NumTeams &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntExpr + return NumTeams{makeExpr(inp.v, semaCtx)}; +} + +struct NumThreads { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +NumThreads make(const parser::OmpClause::NumThreads &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntExpr + return NumThreads{makeExpr(inp.v, semaCtx)}; +} + +struct OmpxDynCgroupMem { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +OmpxDynCgroupMem make(const parser::OmpClause::OmpxDynCgroupMem &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntExpr + return OmpxDynCgroupMem{makeExpr(inp.v, semaCtx)}; +} + +struct Ordered { + using WrapperTrait = std::true_type; + MaybeExpr v; +}; + +Ordered make(const parser::OmpClause::Ordered &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional + return Ordered{maybeApply(makeExpr(semaCtx), inp.v)}; +} + +struct Order { + ENUM_CLASS(Kind, Reproducible, Unconstrained) + ENUM_CLASS(Type, Concurrent) + using TupleTrait = std::true_type; + std::tuple, Type> t; +}; + +Order make(const parser::OmpClause::Order &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpOrderClause + using wrapped = parser::OmpOrderClause; + auto &t0 = std::get>(inp.v.t); + auto &t1 = std::get(inp.v.t); + auto convert = [](const parser::OmpOrderModifier &s) -> Order::Kind { + return enum_cast( + std::get(s.u)); + }; + return Order{{maybeApply(convert, t0), enum_cast(t1)}}; +} + +struct Partial { + using WrapperTrait = std::true_type; + MaybeExpr v; +}; + +Partial make(const parser::OmpClause::Partial &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional + return Partial{maybeApply(makeExpr(semaCtx), inp.v)}; +} + +struct Priority { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Priority make(const parser::OmpClause::Priority &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntExpr + return Priority{makeExpr(inp.v, semaCtx)}; +} + +struct Private { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +Private make(const parser::OmpClause::Private &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return Private{makeList(inp.v, semaCtx)}; +} + +struct ProcBind { + ENUM_CLASS(Type, Close, Master, Spread, Primary) + using WrapperTrait = std::true_type; + Type v; +}; + +ProcBind make(const parser::OmpClause::ProcBind &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpProcBindClause + return ProcBind{enum_cast(inp.v.v)}; +} + +struct Reduction { + using TupleTrait = std::true_type; + std::tuple t; +}; + +Reduction make(const parser::OmpClause::Reduction &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpReductionClause + auto &t0 = std::get(inp.v.t); + auto &t1 = std::get(inp.v.t); + return Reduction{{makeRedOp(t0, semaCtx), makeList(t1, semaCtx)}}; +} + +struct Safelen { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Safelen make(const parser::OmpClause::Safelen &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntConstantExpr + return Safelen{makeExpr(inp.v, semaCtx)}; +} + +struct Schedule { + ENUM_CLASS(ModType, Monotonic, Nonmonotonic, Simd) + struct ScheduleModifier { + using TupleTrait = std::true_type; + std::tuple> t; + }; + ENUM_CLASS(ScheduleType, Static, Dynamic, Guided, Auto, Runtime) + using TupleTrait = std::true_type; + std::tuple, ScheduleType, MaybeExpr> t; +}; + +Schedule make(const parser::OmpClause::Schedule &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpScheduleClause + using wrapped = parser::OmpScheduleClause; + + auto &t0 = std::get>(inp.v.t); + auto &t1 = std::get(inp.v.t); + auto &t2 = std::get>(inp.v.t); + + auto convert = [](auto &&s) -> Schedule::ScheduleModifier { + auto &s0 = std::get(s.t); + auto &s1 = + std::get>(s.t); + + auto convert1 = [](auto &&v) { // Modifier1 or Modifier2 + return enum_cast(v.v.v); + }; + return Schedule::ScheduleModifier{{convert1(s0), maybeApply(convert1, s1)}}; + }; + + return Schedule{{maybeApply(convert, t0), + enum_cast(t1), + maybeApply(makeExpr(semaCtx), t2)}}; +} + +struct Shared { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +Shared make(const parser::OmpClause::Shared &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return Shared{makeList(inp.v, semaCtx)}; +} + +struct Simdlen { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +Simdlen make(const parser::OmpClause::Simdlen &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntConstantExpr + return Simdlen{makeExpr(inp.v, semaCtx)}; +} + +struct Sizes { + using WrapperTrait = std::true_type; + List v; +}; + +Sizes make(const parser::OmpClause::Sizes &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::list + return Sizes{makeList(inp.v, makeExpr(semaCtx))}; +} + +struct TaskReduction { + using TupleTrait = std::true_type; + std::tuple t; +}; + +TaskReduction make(const parser::OmpClause::TaskReduction &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpReductionClause + auto &t0 = std::get(inp.v.t); + auto &t1 = std::get(inp.v.t); + return TaskReduction{{makeRedOp(t0, semaCtx), makeList(t1, semaCtx)}}; +} + +struct ThreadLimit { + using WrapperTrait = std::true_type; + SomeExpr v; +}; + +ThreadLimit make(const parser::OmpClause::ThreadLimit &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::ScalarIntExpr + return ThreadLimit{makeExpr(inp.v, semaCtx)}; +} + +struct To { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +To make(const parser::OmpClause::To &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return To{makeList(inp.v, semaCtx)}; +} + +struct Uniform { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +Uniform make(const parser::OmpClause::Uniform &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::list + return Uniform{makeList(inp.v, makeObject(semaCtx))}; +} + +struct UseDeviceAddr { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +UseDeviceAddr make(const parser::OmpClause::UseDeviceAddr &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return UseDeviceAddr{makeList(inp.v, semaCtx)}; +} + +struct UseDevicePtr { + using WrapperTrait = std::true_type; + ObjectList v; +}; + +UseDevicePtr make(const parser::OmpClause::UseDevicePtr &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> parser::OmpObjectList + return UseDevicePtr{makeList(inp.v, semaCtx)}; +} + +using UnionOfAllClauses = std::variant< + AcqRel, Acquire, AdjustArgs, Affinity, Align, Aligned, Allocate, Allocator, + AppendArgs, At, AtomicDefaultMemOrder, Bind, CancellationConstructType, + Capture, Collapse, Compare, Copyprivate, Copyin, Default, Defaultmap, + Depend, Depobj, Destroy, Detach, Device, DeviceType, DistSchedule, Doacross, + DynamicAllocators, Enter, Exclusive, Fail, Filter, Final, Firstprivate, + Flush, From, Full, Grainsize, HasDeviceAddr, Hint, If, InReduction, + Inbranch, Inclusive, Indirect, Init, IsDevicePtr, Lastprivate, Linear, Link, + Map, Match, MemoryOrder, Mergeable, Message, Nogroup, Nowait, Nocontext, + Nontemporal, Notinbranch, Novariants, NumTasks, NumTeams, NumThreads, + OmpxAttribute, OmpxDynCgroupMem, OmpxBare, Order, Ordered, Partial, + Priority, Private, ProcBind, Read, Reduction, Relaxed, Release, + ReverseOffload, Safelen, Schedule, SeqCst, Severity, Shared, Simd, Simdlen, + Sizes, TaskReduction, ThreadLimit, Threadprivate, Threads, To, + UnifiedAddress, UnifiedSharedMemory, Uniform, Unknown, Untied, Update, Use, + UseDeviceAddr, UseDevicePtr, UsesAllocators, Weak, When, Write>; + +} // namespace clause + +struct Clause { + parser::CharBlock source; + llvm::omp::Clause id; // The numeric id of the clause + using UnionTrait = std::true_type; + clause::UnionOfAllClauses u; +}; + +template +Clause makeClause(llvm::omp::Clause id, Specific &&specific, + parser::CharBlock source = {}) { + return Clause{source, id, specific}; +} + +Clause makeClause(const Fortran::parser::OmpClause &cls, + semantics::SemanticsContext &semaCtx) { + return std::visit( + [&](auto &&s) { + return makeClause(getClauseId(cls), clause::make(s, semaCtx), + cls.source); + }, + cls.u); +} + +List makeList(const parser::OmpClauseList &clauses, + semantics::SemanticsContext &semaCtx) { + return makeList(clauses.v, [&](const parser::OmpClause &s) { + return makeClause(s, semaCtx); + }); +} +} // namespace omp + +static void genObjectList(const omp::ObjectList &objects, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl &operands) { + for (const omp::Object &object : objects) { + const Fortran::semantics::Symbol *sym = object.sym; + assert(sym && "Expected Symbol"); + if (mlir::Value variable = converter.getSymbolAddress(*sym)) { + operands.push_back(variable); + } else { + if (const auto *details = + sym->detailsIf()) { + operands.push_back(converter.getSymbolAddress(details->symbol())); + converter.copySymbolBinding(details->symbol(), *sym); + } + } + } +} + +static void gatherFuncAndVarSyms( + const omp::ObjectList &objects, + mlir::omp::DeclareTargetCaptureClause clause, + llvm::SmallVectorImpl &symbolAndClause) { + for (const omp::Object &object : objects) + symbolAndClause.emplace_back(clause, *object.sym); +} + +//===----------------------------------------------------------------------===// +// Directive decomposition +//===----------------------------------------------------------------------===// + +namespace { +struct DirectiveInfo { + llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown; + llvm::SmallVector clauses; +}; + +struct CompositeInfo { + CompositeInfo(const mlir::ModuleOp &modOp, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &ev, + llvm::omp::Directive compDir, + const Fortran::parser::OmpClauseList &clauseList); + using ClauseSet = std::set; + + bool split(); + void addClauseSymbols(const omp::Clause &clause); + + DirectiveInfo *findDirective(llvm::omp::Directive dirId) { + for (DirectiveInfo &dir : leafs) { + if (dir.id == dirId) + return &dir; + } + return nullptr; + } + ClauseSet *findClauses(const omp::Object &object) { + if (auto found = syms.find(object.sym); found != syms.end()) + return &found->second; + return nullptr; + } + + Fortran::semantics::SemanticsContext &semaCtx; + const mlir::ModuleOp &mod; + Fortran::lower::pft::Evaluation &eval; + + llvm::SmallVector leafs; // Ordered outer to inner. + omp::List clauses; + llvm::DenseMap syms; + llvm::DenseSet mapBases; + // Storage for newly created clauses. Beware of invalidating addresses. + std::list extras; + +private: + void addClauseSymsToMap(const omp::Object &object, const omp::Clause *); + void addClauseSymsToMap(const omp::ObjectList &objects, const omp::Clause *); + void addClauseSymsToMap(const omp::SomeExpr &item, const omp::Clause *); + void addClauseSymsToMap(const omp::clause::Map &item, const omp::Clause *); + + template + void addClauseSymsToMap(const std::optional &item, const omp::Clause *); + template + void addClauseSymsToMap(const omp::List &item, const omp::Clause *); + template + void addClauseSymsToMap(const std::tuple &item, const omp::Clause *, + std::index_sequence = {}); + template >, int> = 0> + void addClauseSymsToMap(T &&item, const omp::Clause *); + template < + typename T, + std::enable_if_t::EmptyTrait::value, int> = 0> + void addClauseSymsToMap(T &&item, const omp::Clause *); + template < + typename T, + std::enable_if_t::WrapperTrait::value, int> = 0> + void addClauseSymsToMap(T &&item, const omp::Clause *); + template < + typename T, + std::enable_if_t::TupleTrait::value, int> = 0> + void addClauseSymsToMap(T &&item, const omp::Clause *); + template < + typename T, + std::enable_if_t::UnionTrait::value, int> = 0> + void addClauseSymsToMap(T &&item, const omp::Clause *); + + // Apply a clause to the only directive that allows it. If there are no + // directives that allow it, or if there is more that one, do not apply + // anything and return false, otherwise return true. + bool applyToUnique(const omp::Clause *node); + + // Apply a clause to the first directive in given range that allows it. + // If such a directive does not exist, return false, otherwise return true. + template + bool applyToFirst(const omp::Clause *node, const mlir::ModuleOp &mod, + llvm::iterator_range range); + + // Apply a clause to the innermost directive that allows it. If such a + // directive does not exist, return false, otherwise return true. + bool applyToInnermost(const omp::Clause *node); + + // Apply a clause to the outermost directive that allows it. If such a + // directive does not exist, return false, otherwise return true. + bool applyToOutermost(const omp::Clause *node); + + template + bool applyIf(const omp::Clause *node, Predicate shouldApply); + + bool applyToAll(const omp::Clause *node); + + template + bool applyClause(Clause &&clause, const omp::Clause *node); + + bool applyClause(const omp::clause::Collapse &clause, const omp::Clause *); + bool applyClause(const omp::clause::Private &clause, const omp::Clause *); + bool applyClause(const omp::clause::Firstprivate &clause, + const omp::Clause *); + bool applyClause(const omp::clause::Lastprivate &clause, const omp::Clause *); + bool applyClause(const omp::clause::Shared &clause, const omp::Clause *); + bool applyClause(const omp::clause::Default &clause, const omp::Clause *); + bool applyClause(const omp::clause::ThreadLimit &clause, const omp::Clause *); + bool applyClause(const omp::clause::Order &clause, const omp::Clause *); + bool applyClause(const omp::clause::Allocate &clause, const omp::Clause *); + bool applyClause(const omp::clause::Reduction &clause, const omp::Clause *); + bool applyClause(const omp::clause::If &clause, const omp::Clause *); + bool applyClause(const omp::clause::Linear &clause, const omp::Clause *); + bool applyClause(const omp::clause::Nowait &clause, const omp::Clause *); +}; +} // namespace + +CompositeInfo::CompositeInfo(const mlir::ModuleOp &modOp, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &ev, + llvm::omp::Directive compDir, + const Fortran::parser::OmpClauseList &clauseList) + : semaCtx(semaCtx), mod(modOp), eval(ev), + clauses(omp::makeList(clauseList, semaCtx)) { + for (llvm::omp::Directive dir : llvm::omp::getLeafConstructs(compDir)) + leafs.push_back(DirectiveInfo{dir}); + + for (const omp::Clause &clause : clauses) + addClauseSymsToMap(clause, &clause); +} + +[[maybe_unused]] static llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const DirectiveInfo &dirInfo) { + os << llvm::omp::getOpenMPDirectiveName(dirInfo.id); + for (auto [index, clause] : llvm::enumerate(dirInfo.clauses)) { + os << (index == 0 ? '\t' : ' '); + os << llvm::omp::getOpenMPClauseName(clause->id); + } + return os; +} + +[[maybe_unused]] static llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const CompositeInfo &compInfo) { + for (const auto &[index, dirInfo] : llvm::enumerate(compInfo.leafs)) + os << "leaf[" << index << "]: " << dirInfo << '\n'; + + os << "syms:\n"; + for (const auto &[sym, clauses] : compInfo.syms) { + os << *sym << " -> {"; + for (const auto *clause : clauses) + os << ' ' << llvm::omp::getOpenMPClauseName(clause->id); + os << " }\n"; + } + os << "mapBases: {"; + for (const auto &sym : compInfo.mapBases) + os << ' ' << *sym; + os << " }\n"; + return os; +} + +namespace detail { +template +typename std::remove_reference_t::iterator +find_unique(Container &&container, Predicate &&pred) { + auto first = std::find_if(container.begin(), container.end(), pred); + if (first == container.end()) + return first; + auto second = std::find_if(std::next(first), container.end(), pred); + if (second == container.end()) + return first; + return container.end(); +} +} // namespace detail + +static Fortran::semantics::Symbol * +getIterationVariableSymbol(const Fortran::lower::pft::Evaluation &eval) { + return eval.visit(Fortran::common::visitors{ + [&](const Fortran::parser::DoConstruct &doLoop) { + if (const auto &maybeCtrl = doLoop.GetLoopControl()) { + using LoopControl = Fortran::parser::LoopControl; + if (auto *bounds = std::get_if(&maybeCtrl->u)) { + static_assert( + std::is_same_vname), + Fortran::parser::Scalar>); + return bounds->name.thing.symbol; + } + } + return static_cast(nullptr); + }, + [](auto &&) { + return static_cast(nullptr); + }, + }); +} + +void CompositeInfo::addClauseSymsToMap(const omp::Object &object, + const omp::Clause *node) { + syms[object.sym].insert(node); +} + +void CompositeInfo::addClauseSymsToMap(const omp::ObjectList &objects, + const omp::Clause *node) { + for (auto &object : objects) + syms[object.sym].insert(node); +} + +void CompositeInfo::addClauseSymsToMap(const omp::SomeExpr &expr, + const omp::Clause *node) { + // Nothing to do for expressions. +} + +void CompositeInfo::addClauseSymsToMap(const omp::clause::Map &item, + const omp::Clause *node) { + auto &objects = std::get(item.t); + addClauseSymsToMap(objects, node); + for (auto &object : objects) { + if (auto base = omp::getBaseObject(object, semaCtx)) + mapBases.insert(base->sym); + } +} + +template +void CompositeInfo::addClauseSymsToMap(const std::optional &item, + const omp::Clause *node) { + if (item) + addClauseSymsToMap(*item, node); +} + +template +void CompositeInfo::addClauseSymsToMap(const omp::List &item, + const omp::Clause *node) { + for (auto &s : item) + addClauseSymsToMap(s, node); +} + +template +void CompositeInfo::addClauseSymsToMap(const std::tuple &item, + const omp::Clause *node, + std::index_sequence) { + (void)node; // Silence strange warning from GCC. + (addClauseSymsToMap(std::get(item), node), ...); +} + +template >, int> = 0> +void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) { + // Nothing to do for enums. +} + +template ::EmptyTrait::value, int> = 0> +void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) { + // Nothing to do for an empty class. +} + +template < + typename T, + std::enable_if_t::WrapperTrait::value, int> = 0> +void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) { + addClauseSymsToMap(item.v, node); +} + +template ::TupleTrait::value, int> = 0> +void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) { + constexpr size_t tuple_size = + std::tuple_size_v>; + addClauseSymsToMap(item.t, node, std::make_index_sequence{}); +} + +template ::UnionTrait::value, int> = 0> +void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) { + std::visit([&](auto &&s) { addClauseSymsToMap(s, node); }, item.u); +} + +#if 1 +// Apply a clause to the only directive that allows it. If there are no +// directives that allow it, or if there is more that one, do not apply +// anything and return false, otherwise return true. +bool CompositeInfo::applyToUnique(const omp::Clause *node) { + uint32_t version = getOpenMPVersion(mod); + auto unique = detail::find_unique(leafs, [=](const auto &dirInfo) { + return llvm::omp::isAllowedClauseForDirective(dirInfo.id, node->id, + version); + }); + + if (unique != leafs.end()) { + unique->clauses.push_back(node); + return true; + } + return false; +} + +// Apply a clause to the first directive in given range that allows it. +// If such a directive does not exist, return false, otherwise return true. +template +bool CompositeInfo::applyToFirst(const omp::Clause *node, + const mlir::ModuleOp &mod, + llvm::iterator_range range) { + if (range.empty()) + return false; + + uint32_t version = getOpenMPVersion(mod); + for (DirectiveInfo &dir : range) { + if (!llvm::omp::isAllowedClauseForDirective(dir.id, node->id, version)) + continue; + dir.clauses.push_back(node); + return true; + } + return false; +} + +// Apply a clause to the innermost directive that allows it. If such a +// directive does not exist, return false, otherwise return true. +bool CompositeInfo::applyToInnermost(const omp::Clause *node) { + return applyToFirst(node, mod, llvm::reverse(leafs)); +} + +// Apply a clause to the outermost directive that allows it. If such a +// directive does not exist, return false, otherwise return true. +bool CompositeInfo::applyToOutermost(const omp::Clause *node) { + return applyToFirst(node, mod, llvm::iterator_range(leafs)); +} + +template +bool CompositeInfo::applyIf(const omp::Clause *node, Predicate shouldApply) { + bool applied = false; + uint32_t version = getOpenMPVersion(mod); + for (DirectiveInfo &dir : leafs) { + if (!llvm::omp::isAllowedClauseForDirective(dir.id, node->id, version)) + continue; + if (!shouldApply(dir)) + continue; + dir.clauses.push_back(node); + applied = true; + } + + return applied; +} + +bool CompositeInfo::applyToAll(const omp::Clause *node) { + return applyIf(node, [](auto) { return true; }); +} + +template +bool CompositeInfo::applyClause(Clause &&clause, const omp::Clause *node) { + // The default behavior is to find the unique directive to which the + // given clause may be applied. If there are no such directives, or + // if there are multiple ones, flag an error. + // From "OpenMP Application Programming Interface", Version 5.2: + // S Some clauses are permitted only on a single leaf construct of the + // S combined or composite construct, in which case the effect is as if + // S the clause is applied to that specific construct. (p339, 31-33) + if (applyToUnique(node)) + return true; + + return false; +} + +// COLLAPSE +bool CompositeInfo::applyClause(const omp::clause::Collapse &clause, + const omp::Clause *node) { + // Apply COLLAPSE to the innermost directive. If it's not one that + // allows it flag an error. + if (!leafs.empty()) { + DirectiveInfo &last = leafs.back(); + uint32_t version = getOpenMPVersion(mod); + + if (llvm::omp::isAllowedClauseForDirective(last.id, node->id, version)) { + last.clauses.push_back(node); + return true; + } + } + + llvm::errs() << "Cannot apply COLLAPSE\n"; + return false; +} + +// PRIVATE +bool CompositeInfo::applyClause(const omp::clause::Private &clause, + const omp::Clause *node) { + if (applyToInnermost(node)) + return true; + llvm::errs() << "Cannot apply PRIVATE\n"; + return false; +} + +// FIRSTPRIVATE +bool CompositeInfo::applyClause(const omp::clause::Firstprivate &clause, + const omp::Clause *node) { + bool applied = false; + + // S Section 17.2 + // S The effect of the firstprivate clause is as if it is applied to one + // S or more leaf constructs as follows: + + // S - To the distribute construct if it is among the constituent constructs; + // S - To the teams construct if it is among the constituent constructs and + // S the distribute construct is not; + auto hasDistribute = findDirective(llvm::omp::OMPD_distribute); + auto hasTeams = findDirective(llvm::omp::OMPD_teams); + if (hasDistribute != nullptr) { + hasDistribute->clauses.push_back(node); + applied = true; + // S If the teams construct is among the constituent constructs and the + // S effect is not as if the firstprivate clause is applied to it by the + // S above rules, then the effect is as if the shared clause with the + // S same list item is applied to the teams construct. + if (hasTeams != nullptr) { + auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared, + omp::clause::Shared{clause.v}); + const omp::Clause &n = *extras.insert(extras.end(), std::move(shared)); + hasTeams->clauses.push_back(&n); + } + } else if (hasTeams != nullptr) { + hasTeams->clauses.push_back(node); + applied = true; + } + + // S - To a worksharing construct that accepts the clause if one is among + // S the constituent constructs; + auto findWorksharing = [&]() { + auto worksharing = getWorksharing(); + for (DirectiveInfo &dir : leafs) { + auto found = llvm::find(worksharing, dir.id); + if (found != std::end(worksharing)) + return &dir; + } + return static_cast(nullptr); + }; + + auto hasWorksharing = findWorksharing(); + if (hasWorksharing != nullptr) { + hasWorksharing->clauses.push_back(node); + applied = true; + } + + // S - To the taskloop construct if it is among the constituent constructs; + auto hasTaskloop = findDirective(llvm::omp::OMPD_taskloop); + if (hasTaskloop != nullptr) { + hasTaskloop->clauses.push_back(node); + applied = true; + } + + // S - To the parallel construct if it is among the constituent constructs + // S and neither a taskloop construct nor a worksharing construct that + // S accepts the clause is among them; + auto hasParallel = findDirective(llvm::omp::OMPD_parallel); + if (hasParallel != nullptr) { + if (hasTaskloop == nullptr && hasWorksharing == nullptr) { + hasParallel->clauses.push_back(node); + applied = true; } else { - if (const auto *details = - sym->detailsIf()) { - operands.push_back(converter.getSymbolAddress(details->symbol())); - converter.copySymbolBinding(details->symbol(), sym); - } + // S If the parallel construct is among the constituent constructs and + // S the effect is not as if the firstprivate clause is applied to it by + // S the above rules, then the effect is as if the shared clause with + // S the same list item is applied to the parallel construct. + auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared, + omp::clause::Shared{clause.v}); + const omp::Clause &n = *extras.insert(extras.end(), std::move(shared)); + hasParallel->clauses.push_back(&n); + } + } + + // S - To the target construct if it is among the constituent constructs + // S and the same list item neither appears in a lastprivate clause nor + // S is the base variable or base pointer of a list item that appears in + // S a map clause. + auto inLastprivate = [&](const omp::Object &object) { + if (ClauseSet *set = findClauses(object)) { + return llvm::find_if(*set, [](const omp::Clause *c) { + return c->id == llvm::omp::Clause::OMPC_lastprivate; + }) != set->end(); } + return false; }; - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - addOperands(*sym); + + auto hasTarget = findDirective(llvm::omp::OMPD_target); + if (hasTarget != nullptr) { + omp::ObjectList objects; + llvm::copy_if( + clause.v, std::back_inserter(objects), [&](const omp::Object &object) { + return !inLastprivate(object) && !mapBases.contains(object.sym); + }); + if (!objects.empty()) { + auto firstp = omp::makeClause(llvm::omp::Clause::OMPC_firstprivate, + omp::clause::Firstprivate{objects}); + const omp::Clause &n = *extras.insert(extras.end(), std::move(firstp)); + hasTarget->clauses.push_back(&n); + applied = true; + } } + + return applied; } -static void gatherFuncAndVarSyms( - const Fortran::parser::OmpObjectList &objList, - mlir::omp::DeclareTargetCaptureClause clause, - llvm::SmallVectorImpl &symbolAndClause) { - for (const Fortran::parser::OmpObject &ompObject : objList.v) { - Fortran::common::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (const Fortran::parser::Name *name = - Fortran::semantics::getDesignatorNameIfDataRef( - designator)) { - symbolAndClause.emplace_back(clause, *name->symbol); - } - }, - [&](const Fortran::parser::Name &name) { - symbolAndClause.emplace_back(clause, *name.symbol); - }}, - ompObject.u); +// LASTPRIVATE +bool CompositeInfo::applyClause(const omp::clause::Lastprivate &clause, + const omp::Clause *node) { + bool applied = false; + + // S The effect of the lastprivate clause is as if it is applied to all leaf + // S constructs that permit the clause. + if (!applyToAll(node)) { + llvm::errs() << "Cannot apply LASTPRIVATE\n"; + return false; + } + + auto inFirstprivate = [&](const omp::Object &object) { + if (ClauseSet *set = findClauses(object)) { + return llvm::find_if(*set, [](const omp::Clause *c) { + return c->id == llvm::omp::Clause::OMPC_firstprivate; + }) != set->end(); + } + return false; + }; + + // Prepare list of objects that could end up in a SHARED clause. + omp::ObjectList sharedObjects; + llvm::copy_if( + clause.v, std::back_inserter(sharedObjects), + [&](const omp::Object &object) { return !inFirstprivate(object); }); + + if (!sharedObjects.empty()) { + // S If the parallel construct is among the constituent constructs and the + // S list item is not also specified in the firstprivate clause, then the + // S effect of the lastprivate clause is as if the shared clause with the + // S same list item is applied to the parallel construct. + if (auto hasParallel = findDirective(llvm::omp::OMPD_parallel)) { + auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared, + omp::clause::Shared{sharedObjects}); + const omp::Clause &n = *extras.insert(extras.end(), std::move(shared)); + hasParallel->clauses.push_back(&n); + applied = true; + } + + // S If the teams construct is among the constituent constructs and the + // S list item is not also specified in the firstprivate clause, then the + // S effect of the lastprivate clause is as if the shared clause with the + // S same list item is applied to the teams construct. + if (auto hasTeams = findDirective(llvm::omp::OMPD_teams)) { + auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared, + omp::clause::Shared{sharedObjects}); + const omp::Clause &n = *extras.insert(extras.end(), std::move(shared)); + hasTeams->clauses.push_back(&n); + applied = true; + } + } + + // S If the target construct is among the constituent constructs and the + // S list item is not the base variable or base pointer of a list item that + // S appears in a map clause, the effect of the lastprivate clause is as if + // S the same list item appears in a map clause with a map-type of tofrom. + if (auto hasTarget = findDirective(llvm::omp::OMPD_target)) { + omp::ObjectList tofrom; + llvm::copy_if(clause.v, std::back_inserter(tofrom), + [&](const omp::Object &object) { + return !mapBases.contains(object.sym); + }); + + if (!tofrom.empty()) { + auto mapType = omp::clause::Map::MapType{ + {std::nullopt, omp::clause::Map::MapType::Type::Tofrom}}; + auto map = + omp::makeClause(llvm::omp::Clause::OMPC_map, + omp::clause::Map{{mapType, std::move(tofrom)}}); + const omp::Clause &n = *extras.insert(extras.end(), std::move(map)); + hasTarget->clauses.push_back(&n); + applied = true; + } } + + return applied; } -static Fortran::lower::pft::Evaluation * -getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) { - // Return the Evaluation of the innermost collapsed loop, or the current one - // if there was no COLLAPSE. - if (collapseValue == 0) - return &eval; +// SHARED +bool CompositeInfo::applyClause(const omp::clause::Shared &clause, + const omp::Clause *node) { + // Apply SHARED to the all leafs that allow it. + if (applyToAll(node)) + return true; + llvm::errs() << "Cannot apply SHARED\n"; + return false; +} - Fortran::lower::pft::Evaluation *curEval = &eval.getFirstNestedEvaluation(); - for (int i = 1; i < collapseValue; i++) { - // The nested evaluations should be DoConstructs (i.e. they should form - // a loop nest). Each DoConstruct is a tuple . - assert(curEval->isA()); - curEval = &*std::next(curEval->getNestedEvaluations().begin()); +// DEFAULT +bool CompositeInfo::applyClause(const omp::clause::Default &clause, + const omp::Clause *node) { + // Apply DEFAULT to the all leafs that allow it. + if (applyToAll(node)) + return true; + llvm::errs() << "Cannot apply DEFAULT\n"; + return false; +} + +// THREAD_LIMIT +bool CompositeInfo::applyClause(const omp::clause::ThreadLimit &clause, + const omp::Clause *node) { + // Apply THREAD_LIMIT to the all leafs that allow it. + if (applyToAll(node)) + return true; + llvm::errs() << "Cannot apply THREAD_LIMIT\n"; + return false; +} + +// ORDER +bool CompositeInfo::applyClause(const omp::clause::Order &clause, + const omp::Clause *node) { + // Apply ORDER to the all leafs that allow it. + if (applyToAll(node)) + return true; + llvm::errs() << "Cannot apply ORDER\n"; + return false; +} + +// ALLOCATE +bool CompositeInfo::applyClause(const omp::clause::Allocate &clause, + const omp::Clause *node) { + // This one needs to be applied at the end, once we know which clauses are + // assigned to which leaf constructs. + + // S The effect of the allocate clause is as if it is applied to all leaf + // S constructs that permit the clause and to which a data-sharing attribute + // S clause that may create a private copy of the same list item is applied. + + auto canMakePrivateCopy = [](llvm::omp::Clause id) { + switch (id) { + case llvm::omp::Clause::OMPC_firstprivate: + case llvm::omp::Clause::OMPC_lastprivate: + case llvm::omp::Clause::OMPC_private: + return true; + default: + return false; + } + }; + + bool applied = applyIf(node, [&](const DirectiveInfo &dir) { + return llvm::any_of(dir.clauses, [&](const omp::Clause *n) { + return canMakePrivateCopy(n->id); + }); + }); + + return applied; +} + +// REDUCTION +bool CompositeInfo::applyClause(const omp::clause::Reduction &clause, + const omp::Clause *node) { + // S The effect of the reduction clause is as if it is applied to all leaf + // S constructs that permit the clause, except for the following constructs: + // S - The parallel construct, when combined with the sections, worksharing- + // S loop, loop, or taskloop construct; and + // S - The teams construct, when combined with the loop construct. + bool applyToParallel = true, applyToTeams = true; + + auto hasParallel = findDirective(llvm::omp::Directive::OMPD_parallel); + if (hasParallel) { + auto exclusions = llvm::concat( + getWorksharingLoop(), llvm::ArrayRef{ + llvm::omp::Directive::OMPD_loop, + llvm::omp::Directive::OMPD_sections, + llvm::omp::Directive::OMPD_taskloop, + }); + auto present = [&](llvm::omp::Directive id) { + return findDirective(id) != nullptr; + }; + + if (llvm::any_of(exclusions, present)) + applyToParallel = false; } - return curEval; + + auto hasTeams = findDirective(llvm::omp::Directive::OMPD_teams); + if (hasTeams) { + // The only exclusion is OMPD_loop. + if (findDirective(llvm::omp::Directive::OMPD_loop)) + applyToTeams = false; + } + + auto &objects = std::get(clause.t); + + omp::ObjectList sharedObjects; + llvm::transform(objects, std::back_inserter(sharedObjects), + [&](const omp::Object &object) { + auto maybeBase = getBaseObject(object, semaCtx); + return maybeBase ? *maybeBase : object; + }); + + // S For the parallel and teams constructs above, the effect of the + // S reduction clause instead is as if each list item or, for any list + // S item that is an array item, its corresponding base array or base + // S pointer appears in a shared clause for the construct. + if (!sharedObjects.empty()) { + if (hasParallel && !applyToParallel) { + auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared, + omp::clause::Shared{sharedObjects}); + const omp::Clause &n = *extras.insert(extras.end(), std::move(shared)); + hasParallel->clauses.push_back(&n); + } + if (hasTeams && !applyToTeams) { + auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared, + omp::clause::Shared{sharedObjects}); + const omp::Clause &n = *extras.insert(extras.end(), std::move(shared)); + hasTeams->clauses.push_back(&n); + } + } + + // TODO(not implemented in parser yet): Apply the following. + // S If the task reduction-modifier is specified, the effect is as if + // S it only modifies the behavior of the reduction clause on the innermost + // S leaf construct that accepts the modifier (see Section 5.5.8). If the + // S inscan reduction-modifier is specified, the effect is as if it modifies + // S the behavior of the reduction clause on all constructs of the combined + // S construct to which the clause is applied and that accept the modifier. + + bool applied = applyIf(node, [&](DirectiveInfo &dir) { + if (!applyToParallel && &dir == hasParallel) + return false; + if (!applyToTeams && &dir == hasTeams) + return false; + return true; + }); + + // S If a list item in a reduction clause on a combined target construct + // S does not have the same base variable or base pointer as a list item + // S in a map clause on the construct, then the effect is as if the list + // S item in the reduction clause appears as a list item in a map clause + // S with a map-type of tofrom. + auto hasTarget = findDirective(llvm::omp::Directive::OMPD_target); + if (hasTarget && leafs.size() > 1) { + omp::ObjectList tofrom; + llvm::copy_if(objects, std::back_inserter(tofrom), + [&](const omp::Object &object) { + if (auto maybeBase = getBaseObject(object, semaCtx)) + return !mapBases.contains(maybeBase->sym); + return !mapBases.contains(object.sym); // XXX is this ok? + }); + if (!tofrom.empty()) { + auto mapType = omp::clause::Map::MapType{ + {std::nullopt, omp::clause::Map::MapType::Type::Tofrom}}; + auto map = + omp::makeClause(llvm::omp::Clause::OMPC_map, + omp::clause::Map{{mapType, std::move(tofrom)}}); + + const omp::Clause &n = *extras.insert(extras.end(), std::move(map)); + hasTarget->clauses.push_back(&n); + applied = true; + } + } + + return applied; } -static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - int collapseValue = 0) { - Fortran::lower::pft::Evaluation *curEval = - getCollapsedLoopEval(eval, collapseValue); +// IF +bool CompositeInfo::applyClause(const omp::clause::If &clause, + const omp::Clause *node) { + using DirectiveNameModifier = omp::clause::If::DirectiveNameModifier; + auto &modifier = std::get>(clause.t); - for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) - converter.genEval(e); + if (modifier) { + llvm::omp::Directive dirId = llvm::omp::Directive::OMPD_unknown; + + switch (*modifier) { + case DirectiveNameModifier::Parallel: + dirId = llvm::omp::Directive::OMPD_parallel; + break; + case DirectiveNameModifier::Simd: + dirId = llvm::omp::Directive::OMPD_simd; + break; + case DirectiveNameModifier::Target: + dirId = llvm::omp::Directive::OMPD_target; + break; + case DirectiveNameModifier::Task: + dirId = llvm::omp::Directive::OMPD_task; + break; + case DirectiveNameModifier::Taskloop: + dirId = llvm::omp::Directive::OMPD_taskloop; + break; + case DirectiveNameModifier::Teams: + dirId = llvm::omp::Directive::OMPD_teams; + break; + + case DirectiveNameModifier::TargetData: + case DirectiveNameModifier::TargetEnterData: + case DirectiveNameModifier::TargetExitData: + case DirectiveNameModifier::TargetUpdate: + default: + llvm::errs() << "Invalid modifier in IF clause\n"; + return false; + } + + if (auto *hasDir = findDirective(dirId)) { + hasDir->clauses.push_back(node); + return true; + } + llvm::errs() << "Directive from modifier not found\n"; + return false; + } + + if (applyToAll(node)) + return true; + + llvm::errs() << "Cannot apply IF\n"; + return false; +} + +// LINEAR +bool CompositeInfo::applyClause(const omp::clause::Linear &clause, + const omp::Clause *node) { + // S The effect of the linear clause is as if it is applied to the innermost + // S leaf construct. + if (applyToInnermost(node)) { + llvm::errs() << "Cannot apply LINEAR\n"; + return false; + } + + // The rest is about SIMD. + if (!findDirective(llvm::omp::OMPD_simd)) + return true; + + // S Additionally, if the list item is not the iteration variable of a + // S simd or worksharing-loop SIMD construct, the effect on the outer leaf + // S constructs is as if the list item was specified in firstprivate and + // S lastprivate clauses on the combined or composite construct, [...] + // + // S If a list item of the linear clause is the iteration variable of a + // S simd or worksharing-loop SIMD construct and it is not declared in + // S the construct, the effect on the outer leaf constructs is as if the + // S list item was specified in a lastprivate clause on the combined or + // S composite construct [...] + + // It's not clear how an object can be listed in a clause AND be the + // iteration variable of a construct in which is it declared. If an + // object is declared in the construct, then the declaration is located + // after the clause listing it. + + Fortran::semantics::Symbol *iterVarSym = getIterationVariableSymbol(eval); + const auto &objects = std::get(clause.t); + + // Lists of objects that will be used to construct FIRSTPRIVATE and + // LASTPRIVATE clauses. + omp::ObjectList first, last; + + for (const omp::Object &object : objects) { + last.push_back(object); + if (object.sym != iterVarSym) + first.push_back(object); + } + + if (!first.empty()) { + auto firstp = omp::makeClause(llvm::omp::Clause::OMPC_firstprivate, + omp::clause::Firstprivate{first}); + clauses.push_back(std::move(firstp)); // Appending to the main clause list. + } + if (!last.empty()) { + auto lastp = omp::makeClause(llvm::omp::Clause::OMPC_lastprivate, + omp::clause::Lastprivate{last}); + clauses.push_back(std::move(lastp)); // Appending to the main clause list. + } + return true; +} + +// NOWAIT +bool CompositeInfo::applyClause(const omp::clause::Nowait &clause, + const omp::Clause *node) { + if (applyToOutermost(node)) + return true; + llvm::errs() << "Cannot apply NOWAIT\n"; + return false; +} + +bool CompositeInfo::split() { + bool success = true; + + // First we need to apply LINEAR, because it can generate additional + // FIRSTPRIVATE and LASTPRIVATE clauses that apply to the combined/ + // composite construct. + // Collect them separately, because they may modify the clause list. + llvm::SmallVector linears; + for (const omp::Clause &node : clauses) { + if (node.id == llvm::omp::Clause::OMPC_linear) + linears.push_back(&node); + } + for (const auto *node : linears) { + success = + success && applyClause(std::get(node->u), node); + } + + // ALLOCATE clauses need to be applied last since they need to see + // which directives have data-privatizing clauses. + auto skip = [](const omp::Clause *node) { + switch (node->id) { + case llvm::omp::Clause::OMPC_allocate: + case llvm::omp::Clause::OMPC_linear: + return true; + default: + return false; + } + }; + + // Apply (almost) all clauses. + for (const omp::Clause &node : clauses) { + if (skip(&node)) + continue; + success = + success && + std::visit([&](auto &&s) { return applyClause(s, &node); }, node.u); + } + + // Apply ALLOCATE. + for (const omp::Clause &node : clauses) { + if (node.id != llvm::omp::Clause::OMPC_allocate) + continue; + success = + success && + std::visit([&](auto &&s) { return applyClause(s, &node); }, node.u); + } + + return success; +} +#endif + +static void splitCompositeConstruct( + const mlir::ModuleOp &modOp, Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive compDir, + const Fortran::parser::OmpClauseList &clauseList) { + // llvm::errs() << "composite name:" + // << llvm::omp::getOpenMPDirectiveName(compDir) << '\n'; + // llvm::errs() << "clause list:"; + for (auto &clause : clauseList.v) { + // std::visit([&](auto &&s) { omp::clause::make(s, semaCtx); }, + // clause.u); llvm::errs() << ' ' << + // llvm::omp::getOpenMPClauseName(getClauseId(clause)); + } + // llvm::errs() << '\n'; + + CompositeInfo compInfo(modOp, semaCtx, eval, compDir, clauseList); + // llvm::errs() << "compInfo.1\n" << compInfo << '\n'; + + bool success = compInfo.split(); + + // Dump + // llvm::errs() << "success:" << success << '\n'; + // llvm::errs() << "compInfo.2\n" << compInfo << '\n'; } //===----------------------------------------------------------------------===// @@ -157,14 +2298,15 @@ class DataSharingProcessor { llvm::SetVector symbolsInNestedRegions; llvm::SetVector symbolsInParentRegions; Fortran::lower::AbstractConverter &converter; + Fortran::semantics::SemanticsContext &semaCtx; fir::FirOpBuilder &firOpBuilder; - const Fortran::parser::OmpClauseList &opClauseList; + omp::List clauses; Fortran::lower::pft::Evaluation &eval; bool needBarrier(); void collectSymbols(Fortran::semantics::Symbol::Flag flag); void collectOmpObjectListSymbol( - const Fortran::parser::OmpObjectList &ompObjectList, + const omp::ObjectList &objects, llvm::SetVector &symbolSet); void collectSymbolsForPrivatization(); void insertBarrier(); @@ -181,11 +2323,12 @@ class DataSharingProcessor { public: DataSharingProcessor(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, const Fortran::parser::OmpClauseList &opClauseList, Fortran::lower::pft::Evaluation &eval) - : hasLastPrivateOp(false), converter(converter), - firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList), - eval(eval) {} + : hasLastPrivateOp(false), converter(converter), semaCtx(semaCtx), + firOpBuilder(converter.getFirOpBuilder()), + clauses(omp::makeList(opClauseList, semaCtx)), eval(eval) {} // Privatisation is split into two steps. // Step1 performs cloning of all privatisation clauses and copying for // firstprivates. Step1 is performed at the place where process/processStep1 @@ -263,30 +2406,28 @@ void DataSharingProcessor::copyLastPrivateSymbol( } void DataSharingProcessor::collectOmpObjectListSymbol( - const Fortran::parser::OmpObjectList &ompObjectList, + const omp::ObjectList &objects, llvm::SetVector &symbolSet) { - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + for (const omp::Object &object : objects) { + Fortran::semantics::Symbol *sym = object.sym; symbolSet.insert(sym); } } void DataSharingProcessor::collectSymbolsForPrivatization() { bool hasCollapse = false; - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { + for (const omp::Clause &clause : clauses) { if (const auto &privateClause = - std::get_if(&clause.u)) { + std::get_if(&clause.u)) { collectOmpObjectListSymbol(privateClause->v, privatizedSymbols); } else if (const auto &firstPrivateClause = - std::get_if( - &clause.u)) { + std::get_if(&clause.u)) { collectOmpObjectListSymbol(firstPrivateClause->v, privatizedSymbols); } else if (const auto &lastPrivateClause = - std::get_if( - &clause.u)) { + std::get_if(&clause.u)) { collectOmpObjectListSymbol(lastPrivateClause->v, privatizedSymbols); hasLastPrivateOp = true; - } else if (std::get_if(&clause.u)) { + } else if (std::get_if(&clause.u)) { hasCollapse = true; } } @@ -319,138 +2460,135 @@ void DataSharingProcessor::insertBarrier() { void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) { bool cmpCreated = false; mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint(); - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { - if (std::get_if(&clause.u)) { - // TODO: Add lastprivate support for simd construct - if (mlir::isa(op)) { - if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) { - // For `omp.sections`, lastprivatized variables occur in - // lexically final `omp.section` operation. The following FIR - // shall be generated for the same: - // - // omp.sections lastprivate(...) { - // omp.section {...} - // omp.section {...} - // omp.section { - // fir.allocate for `private`/`firstprivate` - // - // fir.if %true { - // ^%lpv_update_blk - // } - // } - // } - // - // To keep code consistency while handling privatization - // through this control flow, add a `fir.if` operation - // that always evaluates to true, in order to create - // a dedicated sub-region in `omp.section` where - // lastprivate FIR can reside. Later canonicalizations - // will optimize away this operation. - if (!eval.lowerAsUnstructured()) { - auto ifOp = firOpBuilder.create( - op->getLoc(), - firOpBuilder.createIntegerConstant( - op->getLoc(), firOpBuilder.getIntegerType(1), 0x1), - /*else*/ false); - firOpBuilder.setInsertionPointToStart( - &ifOp.getThenRegion().front()); - - const Fortran::parser::OpenMPConstruct *parentOmpConstruct = - eval.parentConstruct->getIf(); - assert(parentOmpConstruct && - "Expected a valid enclosing OpenMP construct"); - const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct = - std::get_if( - &parentOmpConstruct->u); - assert(sectionsConstruct && - "Expected an enclosing omp.sections construct"); - const Fortran::parser::OmpClauseList §ionsEndClauseList = - std::get( - std::get( - sectionsConstruct->t) - .t); - for (const Fortran::parser::OmpClause &otherClause : - sectionsEndClauseList.v) - if (std::get_if( - &otherClause.u)) - // Emit implicit barrier to synchronize threads and avoid data - // races on post-update of lastprivate variables when `nowait` - // clause is present. - firOpBuilder.create( - converter.getCurrentLocation()); - firOpBuilder.setInsertionPointToStart( - &ifOp.getThenRegion().front()); - lastPrivIP = firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPoint(ifOp); - insPt = firOpBuilder.saveInsertionPoint(); - } else { - // Lastprivate operation is inserted at the end - // of the lexically last section in the sections - // construct - mlir::OpBuilder::InsertPoint unstructuredSectionsIP = - firOpBuilder.saveInsertionPoint(); - mlir::Operation *lastOper = op->getRegion(0).back().getTerminator(); - firOpBuilder.setInsertionPoint(lastOper); - lastPrivIP = firOpBuilder.saveInsertionPoint(); - firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP); - } - } - } else if (mlir::isa(op)) { - // Update the original variable just before exiting the worksharing - // loop. Conversion as follows: + for (const omp::Clause &clause : clauses) { + if (clause.id != llvm::omp::OMPC_lastprivate) + continue; + // TODO: Add lastprivate support for simd construct + if (mlir::isa(op)) { + if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) { + // For `omp.sections`, lastprivatized variables occur in + // lexically final `omp.section` operation. The following FIR + // shall be generated for the same: // - // omp.wsloop { - // omp.wsloop { ... - // ... store - // store ===> %v = arith.addi %iv, %step - // omp.yield %cmp = %step < 0 ? %v < %ub : %v > %ub - // } fir.if %cmp { - // fir.store %v to %loopIV - // ^%lpv_update_blk: - // } - // omp.yield - // } + // omp.sections lastprivate(...) { + // omp.section {...} + // omp.section {...} + // omp.section { + // fir.allocate for `private`/`firstprivate` + // + // fir.if %true { + // ^%lpv_update_blk + // } + // } + // } // - - // Only generate the compare once in presence of multiple LastPrivate - // clauses. - if (cmpCreated) - continue; - cmpCreated = true; - - mlir::Location loc = op->getLoc(); - mlir::Operation *lastOper = op->getRegion(0).back().getTerminator(); - firOpBuilder.setInsertionPoint(lastOper); - - mlir::Value iv = op->getRegion(0).front().getArguments()[0]; - mlir::Value ub = - mlir::dyn_cast(op).getUpperBound()[0]; - mlir::Value step = mlir::dyn_cast(op).getStep()[0]; - - // v = iv + step - // cmp = step < 0 ? v < ub : v > ub - mlir::Value v = firOpBuilder.create(loc, iv, step); - mlir::Value zero = - firOpBuilder.createIntegerConstant(loc, step.getType(), 0); - mlir::Value negativeStep = firOpBuilder.create( - loc, mlir::arith::CmpIPredicate::slt, step, zero); - mlir::Value vLT = firOpBuilder.create( - loc, mlir::arith::CmpIPredicate::slt, v, ub); - mlir::Value vGT = firOpBuilder.create( - loc, mlir::arith::CmpIPredicate::sgt, v, ub); - mlir::Value cmpOp = firOpBuilder.create( - loc, negativeStep, vLT, vGT); - - auto ifOp = firOpBuilder.create(loc, cmpOp, /*else*/ false); - firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - assert(loopIV && "loopIV was not set"); - firOpBuilder.create(op->getLoc(), v, loopIV); - lastPrivIP = firOpBuilder.saveInsertionPoint(); - } else { - TODO(converter.getCurrentLocation(), - "lastprivate clause in constructs other than " - "simd/worksharing-loop"); + // To keep code consistency while handling privatization + // through this control flow, add a `fir.if` operation + // that always evaluates to true, in order to create + // a dedicated sub-region in `omp.section` where + // lastprivate FIR can reside. Later canonicalizations + // will optimize away this operation. + if (!eval.lowerAsUnstructured()) { + auto ifOp = firOpBuilder.create( + op->getLoc(), + firOpBuilder.createIntegerConstant( + op->getLoc(), firOpBuilder.getIntegerType(1), 0x1), + /*else*/ false); + firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + + const Fortran::parser::OpenMPConstruct *parentOmpConstruct = + eval.parentConstruct->getIf(); + assert(parentOmpConstruct && + "Expected a valid enclosing OpenMP construct"); + const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct = + std::get_if( + &parentOmpConstruct->u); + assert(sectionsConstruct && + "Expected an enclosing omp.sections construct"); + const Fortran::parser::OmpClauseList §ionsEndClauseList = + std::get( + std::get( + sectionsConstruct->t) + .t); + for (const Fortran::parser::OmpClause &otherClause : + sectionsEndClauseList.v) + if (std::get_if(&otherClause.u)) + // Emit implicit barrier to synchronize threads and avoid data + // races on post-update of lastprivate variables when `nowait` + // clause is present. + firOpBuilder.create( + converter.getCurrentLocation()); + firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + lastPrivIP = firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPoint(ifOp); + insPt = firOpBuilder.saveInsertionPoint(); + } else { + // Lastprivate operation is inserted at the end + // of the lexically last section in the sections + // construct + mlir::OpBuilder::InsertPoint unstructuredSectionsIP = + firOpBuilder.saveInsertionPoint(); + mlir::Operation *lastOper = op->getRegion(0).back().getTerminator(); + firOpBuilder.setInsertionPoint(lastOper); + lastPrivIP = firOpBuilder.saveInsertionPoint(); + firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP); + } } + } else if (mlir::isa(op)) { + // Update the original variable just before exiting the worksharing + // loop. Conversion as follows: + // + // omp.wsloop { + // omp.wsloop { ... + // ... store + // store ===> %v = arith.addi %iv, %step + // omp.yield %cmp = %step < 0 ? %v < %ub : %v > %ub + // } fir.if %cmp { + // fir.store %v to %loopIV + // ^%lpv_update_blk: + // } + // omp.yield + // } + // + + // Only generate the compare once in presence of multiple LastPrivate + // clauses. + if (cmpCreated) + continue; + cmpCreated = true; + + mlir::Location loc = op->getLoc(); + mlir::Operation *lastOper = op->getRegion(0).back().getTerminator(); + firOpBuilder.setInsertionPoint(lastOper); + + mlir::Value iv = op->getRegion(0).front().getArguments()[0]; + mlir::Value ub = + mlir::dyn_cast(op).getUpperBound()[0]; + mlir::Value step = mlir::dyn_cast(op).getStep()[0]; + + // v = iv + step + // cmp = step < 0 ? v < ub : v > ub + mlir::Value v = firOpBuilder.create(loc, iv, step); + mlir::Value zero = + firOpBuilder.createIntegerConstant(loc, step.getType(), 0); + mlir::Value negativeStep = firOpBuilder.create( + loc, mlir::arith::CmpIPredicate::slt, step, zero); + mlir::Value vLT = firOpBuilder.create( + loc, mlir::arith::CmpIPredicate::slt, v, ub); + mlir::Value vGT = firOpBuilder.create( + loc, mlir::arith::CmpIPredicate::sgt, v, ub); + mlir::Value cmpOp = firOpBuilder.create( + loc, negativeStep, vLT, vGT); + + auto ifOp = firOpBuilder.create(loc, cmpOp, /*else*/ false); + firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + assert(loopIV && "loopIV was not set"); + firOpBuilder.create(op->getLoc(), v, loopIV); + lastPrivIP = firOpBuilder.saveInsertionPoint(); + } else { + TODO(converter.getCurrentLocation(), + "lastprivate clause in constructs other than " + "simd/worksharing-loop"); } } firOpBuilder.restoreInsertionPoint(localInsPt); @@ -474,14 +2612,12 @@ void DataSharingProcessor::collectSymbols( } void DataSharingProcessor::collectDefaultSymbols() { - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { - if (const auto &defaultClause = - std::get_if(&clause.u)) { - if (defaultClause->v.v == - Fortran::parser::OmpDefaultClause::Type::Private) + for (const omp::Clause &clause : clauses) { + if (const auto *defaultClause = + std::get_if(&clause.u)) { + if (defaultClause->v == omp::clause::Default::Type::Private) collectSymbols(Fortran::semantics::Symbol::Flag::OmpPrivate); - else if (defaultClause->v.v == - Fortran::parser::OmpDefaultClause::Type::Firstprivate) + else if (defaultClause->v == omp::clause::Default::Type::Firstprivate) collectSymbols(Fortran::semantics::Symbol::Flag::OmpFirstPrivate); } } @@ -548,13 +2684,12 @@ void DataSharingProcessor::defaultPrivatize() { /// methods that relate to clauses that can impact the lowering of that /// construct. class ClauseProcessor { - using ClauseTy = Fortran::parser::OmpClause; - public: ClauseProcessor(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, const Fortran::parser::OmpClauseList &clauses) - : converter(converter), semaCtx(semaCtx), clauses(clauses) {} + : converter(converter), semaCtx(semaCtx), + clauses(omp::makeList(clauses, semaCtx)) {} // 'Unique' clauses: They can appear at most once in the clause list. bool @@ -602,9 +2737,8 @@ class ClauseProcessor { llvm::SmallVectorImpl &dependOperands) const; bool processEnter(llvm::SmallVectorImpl &result) const; - bool - processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Value &result) const; + bool processIf(omp::clause::If::DirectiveNameModifier directiveName, + mlir::Value &result) const; bool processLink(llvm::SmallVectorImpl &result) const; @@ -654,7 +2788,7 @@ class ClauseProcessor { llvm::omp::Directive directive) const; private: - using ClauseIterator = std::list::const_iterator; + using ClauseIterator = omp::List::const_iterator; /// Utility to find a clause within a range in the clause list. template @@ -673,8 +2807,8 @@ class ClauseProcessor { template const T * findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const { - ClauseIterator it = findClause(clauses.v.begin(), clauses.v.end()); - if (it != clauses.v.end()) { + ClauseIterator it = findClause(clauses.begin(), clauses.end()); + if (it != clauses.end()) { if (source) *source = &it->source; return &std::get(it->u); @@ -686,15 +2820,15 @@ class ClauseProcessor { /// if at least one instance was found. template bool findRepeatableClause( - std::function + std::function callbackFn) const { bool found = false; - ClauseIterator nextIt, endIt = clauses.v.end(); - for (ClauseIterator it = clauses.v.begin(); it != endIt; it = nextIt) { + ClauseIterator nextIt, endIt = clauses.end(); + for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) { nextIt = findClause(it, endIt); if (nextIt != endIt) { - callbackFn(&std::get(nextIt->u), nextIt->source); + callbackFn(std::get(nextIt->u), nextIt->source); found = true; ++nextIt; } @@ -714,7 +2848,7 @@ class ClauseProcessor { Fortran::lower::AbstractConverter &converter; Fortran::semantics::SemanticsContext &semaCtx; - const Fortran::parser::OmpClauseList &clauses; + omp::List clauses; }; //===----------------------------------------------------------------------===// @@ -750,9 +2884,9 @@ class ReductionProcessor { IEOR }; static ReductionIdentifier - getReductionType(const Fortran::parser::ProcedureDesignator &pd) { + getReductionType(const omp::clause::ProcedureDesignator &pd) { auto redType = llvm::StringSwitch>( - getRealName(pd).ToString()) + getRealName(pd.v.sym).ToString()) .Case("max", ReductionIdentifier::MAX) .Case("min", ReductionIdentifier::MIN) .Case("iand", ReductionIdentifier::IAND) @@ -764,35 +2898,33 @@ class ReductionProcessor { } static ReductionIdentifier getReductionType( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) { + omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: return ReductionIdentifier::ADD; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract: + case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: return ReductionIdentifier::SUBTRACT; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: return ReductionIdentifier::MULTIPLY; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: return ReductionIdentifier::AND; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: return ReductionIdentifier::EQV; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: return ReductionIdentifier::OR; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: return ReductionIdentifier::NEQV; default: llvm_unreachable("unexpected intrinsic operator in reduction"); } } - static bool supportedIntrinsicProcReduction( - const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - if (!name->symbol->GetUltimate().attrs().test( - Fortran::semantics::Attr::INTRINSIC)) + static bool + supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd) { + Fortran::semantics::Symbol *sym = pd.v.sym; + if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) return false; - auto redType = llvm::StringSwitch(getRealName(name).ToString()) + auto redType = llvm::StringSwitch(getRealName(sym).ToString()) .Case("max", true) .Case("min", true) .Case("iand", true) @@ -803,15 +2935,13 @@ class ReductionProcessor { } static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::Name *name) { - return name->symbol->GetUltimate().name(); + getRealName(const Fortran::semantics::Symbol *symbol) { + return symbol->GetUltimate().name(); } static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - return getRealName(name); + getRealName(const omp::clause::ProcedureDesignator &pd) { + return getRealName(pd.v.sym); } static std::string getReductionName(llvm::StringRef name, mlir::Type ty) { @@ -821,25 +2951,25 @@ class ReductionProcessor { .str(); } - static std::string getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type ty) { + static std::string + getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type ty) { std::string reductionName; switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: reductionName = "add_reduction"; break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: reductionName = "multiply_reduction"; break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: return "and_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: return "eqv_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: return "or_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: return "neqv_reduction"; default: reductionName = "other_reduction"; @@ -1083,7 +3213,7 @@ class ReductionProcessor { static void addReductionDecl(mlir::Location currentLocation, Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, + const omp::clause::Reduction &reduction, llvm::SmallVectorImpl &reductionVars, llvm::SmallVectorImpl &reductionDeclSymbols, llvm::SmallVectorImpl @@ -1091,13 +3221,12 @@ class ReductionProcessor { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::omp::ReductionDeclareOp decl; const auto &redOperator{ - std::get(reduction.t)}; - const auto &objectList{ - std::get(reduction.t)}; + std::get(reduction.t)}; + const auto &objectList{std::get(reduction.t)}; if (const auto &redDefinedOp = - std::get_if(&redOperator.u)) { + std::get_if(&redOperator.u)) { const auto &intrinsicOp{ - std::get( + std::get( redDefinedOp->u)}; ReductionIdentifier redId = getReductionType(intrinsicOp); switch (redId) { @@ -1113,10 +3242,41 @@ class ReductionProcessor { "Reduction of some intrinsic operators is not supported"); break; } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.sym) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp()) + symVal = declOp.getBase(); + mlir::Type redType = + symVal.getType().cast().getEleTy(); + reductionVars.push_back(symVal); + if (redType.isa()) + decl = createReductionDecl( + firOpBuilder, + getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId, + redType, currentLocation); + else if (redType.isIntOrIndexOrFloat()) { + decl = createReductionDecl(firOpBuilder, + getReductionName(intrinsicOp, redType), + redId, redType, currentLocation); + } else { + TODO(currentLocation, "Reduction of some types is not supported"); + } + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } else if (const auto *reductionIntrinsic = + std::get_if( + &redOperator.u)) { + if (ReductionProcessor::supportedIntrinsicProcReduction( + *reductionIntrinsic)) { + ReductionProcessor::ReductionIdentifier redId = + ReductionProcessor::getReductionType(*reductionIntrinsic); + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.sym) { if (reductionSymbols) reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); @@ -1125,118 +3285,72 @@ class ReductionProcessor { mlir::Type redType = symVal.getType().cast().getEleTy(); reductionVars.push_back(symVal); - if (redType.isa()) - decl = createReductionDecl( - firOpBuilder, - getReductionName(intrinsicOp, firOpBuilder.getI1Type()), - redId, redType, currentLocation); - else if (redType.isIntOrIndexOrFloat()) { - decl = createReductionDecl(firOpBuilder, - getReductionName(intrinsicOp, redType), - redId, redType, currentLocation); - } else { - TODO(currentLocation, "Reduction of some types is not supported"); - } + assert(redType.isIntOrIndexOrFloat() && + "Unsupported reduction type"); + decl = createReductionDecl( + firOpBuilder, + getReductionName(getRealName(*reductionIntrinsic).ToString(), + redType), + redId, redType, currentLocation); reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( firOpBuilder.getContext(), decl.getSymName())); } } } - } else if (const auto *reductionIntrinsic = - std::get_if( - &redOperator.u)) { - if (ReductionProcessor::supportedIntrinsicProcReduction( - *reductionIntrinsic)) { - ReductionProcessor::ReductionIdentifier redId = - ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - if (reductionSymbols) - reductionSymbols->push_back(symbol); - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp()) - symVal = declOp.getBase(); - mlir::Type redType = - symVal.getType().cast().getEleTy(); - reductionVars.push_back(symVal); - assert(redType.isIntOrIndexOrFloat() && - "Unsupported reduction type"); - decl = createReductionDecl( - firOpBuilder, - getReductionName(getRealName(*reductionIntrinsic).ToString(), - redType), - redId, redType, currentLocation); - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } } } }; static mlir::omp::ScheduleModifier -translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) { - switch (m.v) { - case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic: +translateScheduleModifier(const omp::clause::Schedule::ModType &m) { + switch (m) { + case omp::clause::Schedule::ModType::Monotonic: return mlir::omp::ScheduleModifier::monotonic; - case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic: + case omp::clause::Schedule::ModType::Nonmonotonic: return mlir::omp::ScheduleModifier::nonmonotonic; - case Fortran::parser::OmpScheduleModifierType::ModType::Simd: + case omp::clause::Schedule::ModType::Simd: return mlir::omp::ScheduleModifier::simd; } return mlir::omp::ScheduleModifier::none; } static mlir::omp::ScheduleModifier -getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) { - const auto &modifier = - std::get>(x.t); +getScheduleModifier(const omp::clause::Schedule &clause) { + using ScheduleModifier = omp::clause::Schedule::ScheduleModifier; + const auto &modifier = std::get>(clause.t); // The input may have the modifier any order, so we look for one that isn't // SIMD. If modifier is not set at all, fall down to the bottom and return // "none". if (modifier) { - const auto &modType1 = - std::get(modifier->t); - if (modType1.v.v == - Fortran::parser::OmpScheduleModifierType::ModType::Simd) { - const auto &modType2 = std::get< - std::optional>( - modifier->t); - if (modType2 && - modType2->v.v != - Fortran::parser::OmpScheduleModifierType::ModType::Simd) - return translateScheduleModifier(modType2->v); - + using ModType = omp::clause::Schedule::ModType; + const auto &modType1 = std::get(modifier->t); + if (modType1 == ModType::Simd) { + const auto &modType2 = std::get>(modifier->t); + if (modType2 && *modType2 != ModType::Simd) + return translateScheduleModifier(*modType2); return mlir::omp::ScheduleModifier::none; } - return translateScheduleModifier(modType1.v); + return translateScheduleModifier(modType1); } return mlir::omp::ScheduleModifier::none; } static mlir::omp::ScheduleModifier -getSimdModifier(const Fortran::parser::OmpScheduleClause &x) { - const auto &modifier = - std::get>(x.t); +getSimdModifier(const omp::clause::Schedule &clause) { + using ScheduleModifier = omp::clause::Schedule::ScheduleModifier; + const auto &modifier = std::get>(clause.t); // Either of the two possible modifiers in the input can be the SIMD modifier, // so look in either one, and return simd if we find one. Not found = return // "none". if (modifier) { - const auto &modType1 = - std::get(modifier->t); - if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd) + using ModType = omp::clause::Schedule::ModType; + const auto &modType1 = std::get(modifier->t); + if (modType1 == ModType::Simd) return mlir::omp::ScheduleModifier::simd; - const auto &modType2 = std::get< - std::optional>( - modifier->t); - if (modType2 && modType2->v.v == - Fortran::parser::OmpScheduleModifierType::ModType::Simd) + const auto &modType2 = std::get>(modifier->t); + if (modType2 && *modType2 == ModType::Simd) return mlir::omp::ScheduleModifier::simd; } return mlir::omp::ScheduleModifier::none; @@ -1244,7 +3358,7 @@ getSimdModifier(const Fortran::parser::OmpScheduleClause &x) { static void genAllocateClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpAllocateClause &ompAllocateClause, + const omp::clause::Allocate &clause, llvm::SmallVectorImpl &allocatorOperands, llvm::SmallVectorImpl &allocateOperands) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -1252,21 +3366,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext stmtCtx; mlir::Value allocatorOperand; - const Fortran::parser::OmpObjectList &ompObjectList = - std::get(ompAllocateClause.t); - const auto &allocateModifier = std::get< - std::optional>( - ompAllocateClause.t); + const omp::ObjectList &objectList = std::get(clause.t); + const auto &modifier = + std::get>(clause.t); // If the allocate modifier is present, check if we only use the allocator // submodifier. ALIGN in this context is unimplemented const bool onlyAllocator = - allocateModifier && - std::holds_alternative< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); + modifier && + std::holds_alternative( + modifier->u); - if (allocateModifier && !onlyAllocator) { + if (modifier && !onlyAllocator) { TODO(currentLocation, "OmpAllocateClause ALIGN modifier"); } @@ -1274,37 +3385,34 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter, // to list of allocators, otherwise, add default allocator to // list of allocators. if (onlyAllocator) { - const auto &allocatorValue = std::get< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); - allocatorOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx)); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); + const auto &value = + std::get(modifier->u); + mlir::Value operand = + fir::getBase(converter.genExprValue(value.v, stmtCtx)); + allocatorOperands.append(objectList.size(), operand); } else { - allocatorOperand = firOpBuilder.createIntegerConstant( + mlir::Value operand = firOpBuilder.createIntegerConstant( currentLocation, firOpBuilder.getI32Type(), 1); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); + allocatorOperands.append(objectList.size(), operand); } - genObjectList(ompObjectList, converter, allocateOperands); + genObjectList(objectList, converter, allocateOperands); } -static mlir::omp::ClauseProcBindKindAttr genProcBindKindAttr( - fir::FirOpBuilder &firOpBuilder, - const Fortran::parser::OmpClause::ProcBind *procBindClause) { +static mlir::omp::ClauseProcBindKindAttr +genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder, + const omp::clause::ProcBind &clause) { mlir::omp::ClauseProcBindKind procBindKind; - switch (procBindClause->v.v) { - case Fortran::parser::OmpProcBindClause::Type::Master: + switch (clause.v) { + case omp::clause::ProcBind::Type::Master: procBindKind = mlir::omp::ClauseProcBindKind::Master; break; - case Fortran::parser::OmpProcBindClause::Type::Close: + case omp::clause::ProcBind::Type::Close: procBindKind = mlir::omp::ClauseProcBindKind::Close; break; - case Fortran::parser::OmpProcBindClause::Type::Spread: + case omp::clause::ProcBind::Type::Spread: procBindKind = mlir::omp::ClauseProcBindKind::Spread; break; - case Fortran::parser::OmpProcBindClause::Type::Primary: + case omp::clause::ProcBind::Type::Primary: procBindKind = mlir::omp::ClauseProcBindKind::Primary; break; } @@ -1314,20 +3422,17 @@ static mlir::omp::ClauseProcBindKindAttr genProcBindKindAttr( static mlir::omp::ClauseTaskDependAttr genDependKindAttr(fir::FirOpBuilder &firOpBuilder, - const Fortran::parser::OmpClause::Depend *dependClause) { + const omp::clause::Depend &clause) { mlir::omp::ClauseTaskDepend pbKind; - switch ( - std::get( - std::get(dependClause->v.u) - .t) - .v) { - case Fortran::parser::OmpDependenceType::Type::In: + const auto &inOut = std::get(clause.u); + switch (std::get(inOut.t)) { + case omp::clause::Depend::Type::In: pbKind = mlir::omp::ClauseTaskDepend::taskdependin; break; - case Fortran::parser::OmpDependenceType::Type::Out: + case omp::clause::Depend::Type::Out: pbKind = mlir::omp::ClauseTaskDepend::taskdependout; break; - case Fortran::parser::OmpDependenceType::Type::Inout: + case omp::clause::Depend::Type::Inout: pbKind = mlir::omp::ClauseTaskDepend::taskdependinout; break; default: @@ -1338,45 +3443,41 @@ genDependKindAttr(fir::FirOpBuilder &firOpBuilder, pbKind); } -static mlir::Value getIfClauseOperand( - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClause::If *ifClause, - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Location clauseLocation) { +static mlir::Value +getIfClauseOperand(Fortran::lower::AbstractConverter &converter, + const omp::clause::If &clause, + omp::clause::If::DirectiveNameModifier directiveName, + mlir::Location clauseLocation) { // Only consider the clause if it's intended for the given directive. - auto &directive = std::get< - std::optional>( - ifClause->v.t); + auto &directive = + std::get>(clause.t); if (directive && directive.value() != directiveName) return nullptr; Fortran::lower::StatementContext stmtCtx; fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto &expr = std::get(ifClause->v.t); mlir::Value ifVal = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); + converter.genExprValue(std::get(clause.t), stmtCtx)); return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(), ifVal); } static void addUseDeviceClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpObjectList &useDeviceClause, + const omp::ObjectList &objects, llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSymbols) { - genObjectList(useDeviceClause, converter, operands); + genObjectList(objects, converter, operands); for (mlir::Value &operand : operands) { checkMapType(operand.getLoc(), operand.getType()); useDeviceTypes.push_back(operand.getType()); useDeviceLocs.push_back(operand.getLoc()); } - for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - useDeviceSymbols.push_back(sym); - } + for (const omp::Object &object : objects) + useDeviceSymbols.push_back(object.sym); } //===----------------------------------------------------------------------===// @@ -1402,9 +3503,8 @@ bool ClauseProcessor::processCollapse( } std::int64_t collapseValue = 1l; - if (auto *collapseClause = findUniqueClause()) { - const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); - collapseValue = Fortran::evaluate::ToInt64(*expr).value(); + if (auto *clause = findUniqueClause()) { + collapseValue = Fortran::evaluate::ToInt64(clause->v).value(); found = true; } @@ -1443,19 +3543,19 @@ bool ClauseProcessor::processCollapse( } bool ClauseProcessor::processDefault() const { - if (auto *defaultClause = findUniqueClause()) { + if (auto *clause = findUniqueClause()) { // Private, Firstprivate, Shared, None - switch (defaultClause->v.v) { - case Fortran::parser::OmpDefaultClause::Type::Shared: - case Fortran::parser::OmpDefaultClause::Type::None: + switch (clause->v) { + case omp::clause::Default::Type::Shared: + case omp::clause::Default::Type::None: // Default clause with shared or none do not require any handling since // Shared is the default behavior in the IR and None is only required // for semantic checks. break; - case Fortran::parser::OmpDefaultClause::Type::Private: + case omp::clause::Default::Type::Private: // TODO Support default(private) break; - case Fortran::parser::OmpDefaultClause::Type::Firstprivate: + case omp::clause::Default::Type::Firstprivate: // TODO Support default(firstprivate) break; } @@ -1467,20 +3567,17 @@ bool ClauseProcessor::processDefault() const { bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { const Fortran::parser::CharBlock *source = nullptr; - if (auto *deviceClause = findUniqueClause(&source)) { + if (auto *clause = findUniqueClause(&source)) { mlir::Location clauseLocation = converter.genLocation(*source); - if (auto deviceModifier = std::get< - std::optional>( - deviceClause->v.t)) { - if (deviceModifier == - Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) { + if (auto deviceModifier = + std::get>( + clause->t)) { + if (deviceModifier == omp::clause::Device::DeviceModifier::Ancestor) { TODO(clauseLocation, "OMPD_target Device Modifier Ancestor"); } } - if (const auto *deviceExpr = Fortran::semantics::GetExpr( - std::get(deviceClause->v.t))) { - result = fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx)); - } + const auto &deviceExpr = std::get(clause->t); + result = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx)); return true; } return false; @@ -1488,16 +3585,16 @@ bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx, bool ClauseProcessor::processDeviceType( mlir::omp::DeclareTargetDeviceType &result) const { - if (auto *deviceTypeClause = findUniqueClause()) { + if (auto *clause = findUniqueClause()) { // Case: declare target ... device_type(any | host | nohost) - switch (deviceTypeClause->v.v) { - case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: + switch (clause->v) { + case omp::clause::DeviceType::Type::Nohost: result = mlir::omp::DeclareTargetDeviceType::nohost; break; - case Fortran::parser::OmpDeviceTypeClause::Type::Host: + case omp::clause::DeviceType::Type::Host: result = mlir::omp::DeclareTargetDeviceType::host; break; - case Fortran::parser::OmpDeviceTypeClause::Type::Any: + case omp::clause::DeviceType::Type::Any: result = mlir::omp::DeclareTargetDeviceType::any; break; } @@ -1509,12 +3606,12 @@ bool ClauseProcessor::processDeviceType( bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { const Fortran::parser::CharBlock *source = nullptr; - if (auto *finalClause = findUniqueClause(&source)) { + if (auto *clause = findUniqueClause(&source)) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::Location clauseLocation = converter.genLocation(*source); - mlir::Value finalVal = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(finalClause->v), stmtCtx)); + mlir::Value finalVal = + fir::getBase(converter.genExprValue(clause->v, stmtCtx)); result = firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(), finalVal); return true; @@ -1523,10 +3620,9 @@ bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx, } bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const { - if (auto *hintClause = findUniqueClause()) { + if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - const auto *expr = Fortran::semantics::GetExpr(hintClause->v); - int64_t hintValue = *Fortran::evaluate::ToInt64(*expr); + int64_t hintValue = *Fortran::evaluate::ToInt64(clause->v); result = firOpBuilder.getI64IntegerAttr(hintValue); return true; } @@ -1534,20 +3630,19 @@ bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const { } bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); + return markClauseOccurrence(result); } bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); + return markClauseOccurrence(result); } bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { // TODO Get lower and upper bounds for num_teams when parser is updated to // accept both. - if (auto *numTeamsClause = findUniqueClause()) { - result = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numTeamsClause->v), stmtCtx)); + if (auto *clause = findUniqueClause()) { + result = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); return true; } return false; @@ -1555,22 +3650,20 @@ bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx, bool ClauseProcessor::processNumThreads( Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { - if (auto *numThreadsClause = findUniqueClause()) { + if (auto *clause = findUniqueClause()) { // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. - result = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); + result = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); return true; } return false; } bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const { - if (auto *orderedClause = findUniqueClause()) { + if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); int64_t orderedClauseValue = 0l; - if (orderedClause->v.has_value()) { - const auto *expr = Fortran::semantics::GetExpr(orderedClause->v); - orderedClauseValue = *Fortran::evaluate::ToInt64(*expr); + if (clause->v.has_value()) { + orderedClauseValue = *Fortran::evaluate::ToInt64(*clause->v); } result = firOpBuilder.getI64IntegerAttr(orderedClauseValue); return true; @@ -1580,9 +3673,8 @@ bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const { bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { - if (auto *priorityClause = findUniqueClause()) { - result = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx)); + if (auto *clause = findUniqueClause()) { + result = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); return true; } return false; @@ -1590,20 +3682,19 @@ bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx, bool ClauseProcessor::processProcBind( mlir::omp::ClauseProcBindKindAttr &result) const { - if (auto *procBindClause = findUniqueClause()) { + if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - result = genProcBindKindAttr(firOpBuilder, procBindClause); + result = genProcBindKindAttr(firOpBuilder, *clause); return true; } return false; } bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const { - if (auto *safelenClause = findUniqueClause()) { + if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - const auto *expr = Fortran::semantics::GetExpr(safelenClause->v); const std::optional safelenVal = - Fortran::evaluate::ToInt64(*expr); + Fortran::evaluate::ToInt64(clause->v); result = firOpBuilder.getI64IntegerAttr(*safelenVal); return true; } @@ -1614,41 +3705,38 @@ bool ClauseProcessor::processSchedule( mlir::omp::ClauseScheduleKindAttr &valAttr, mlir::omp::ScheduleModifierAttr &modifierAttr, mlir::UnitAttr &simdModifierAttr) const { - if (auto *scheduleClause = findUniqueClause()) { + if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::MLIRContext *context = firOpBuilder.getContext(); - const Fortran::parser::OmpScheduleClause &scheduleType = scheduleClause->v; - const auto &scheduleClauseKind = - std::get( - scheduleType.t); + const auto &scheduleType = + std::get(clause->t); mlir::omp::ClauseScheduleKind scheduleKind; - switch (scheduleClauseKind) { - case Fortran::parser::OmpScheduleClause::ScheduleType::Static: + switch (scheduleType) { + case omp::clause::Schedule::ScheduleType::Static: scheduleKind = mlir::omp::ClauseScheduleKind::Static; break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic: + case omp::clause::Schedule::ScheduleType::Dynamic: scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic; break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Guided: + case omp::clause::Schedule::ScheduleType::Guided: scheduleKind = mlir::omp::ClauseScheduleKind::Guided; break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Auto: + case omp::clause::Schedule::ScheduleType::Auto: scheduleKind = mlir::omp::ClauseScheduleKind::Auto; break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime: + case omp::clause::Schedule::ScheduleType::Runtime: scheduleKind = mlir::omp::ClauseScheduleKind::Runtime; break; } - mlir::omp::ScheduleModifier scheduleModifier = - getScheduleModifier(scheduleClause->v); + mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause); if (scheduleModifier != mlir::omp::ScheduleModifier::none) modifierAttr = mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier); - if (getSimdModifier(scheduleClause->v) != mlir::omp::ScheduleModifier::none) + if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none) simdModifierAttr = firOpBuilder.getUnitAttr(); valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind); @@ -1659,25 +3747,19 @@ bool ClauseProcessor::processSchedule( bool ClauseProcessor::processScheduleChunk( Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { - if (auto *scheduleClause = findUniqueClause()) { - if (const auto &chunkExpr = - std::get>( - scheduleClause->v.t)) { - if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) { - result = fir::getBase(converter.genExprValue(*expr, stmtCtx)); - } - } + if (auto *clause = findUniqueClause()) { + if (const auto &chunkExpr = std::get(clause->t)) + result = fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx)); return true; } return false; } bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const { - if (auto *simdlenClause = findUniqueClause()) { + if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - const auto *expr = Fortran::semantics::GetExpr(simdlenClause->v); const std::optional simdlenVal = - Fortran::evaluate::ToInt64(*expr); + Fortran::evaluate::ToInt64(clause->v); result = firOpBuilder.getI64IntegerAttr(*simdlenVal); return true; } @@ -1686,16 +3768,15 @@ bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const { bool ClauseProcessor::processThreadLimit( Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { - if (auto *threadLmtClause = findUniqueClause()) { - result = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx)); + if (auto *clause = findUniqueClause()) { + result = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); return true; } return false; } bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); + return markClauseOccurrence(result); } //===----------------------------------------------------------------------===// @@ -1705,10 +3786,10 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { bool ClauseProcessor::processAllocate( llvm::SmallVectorImpl &allocatorOperands, llvm::SmallVectorImpl &allocateOperands) const { - return findRepeatableClause( - [&](const ClauseTy::Allocate *allocateClause, + return findRepeatableClause( + [&](const omp::clause::Allocate &clause, const Fortran::parser::CharBlock &) { - genAllocateClause(converter, allocateClause->v, allocatorOperands, + genAllocateClause(converter, clause, allocatorOperands, allocateOperands); }); } @@ -1725,12 +3806,12 @@ bool ClauseProcessor::processCopyin() const { if (converter.isPresentShallowLookup(*sym)) converter.copyHostAssociateVar(*sym, copyAssignIP); }; - bool hasCopyin = findRepeatableClause( - [&](const ClauseTy::Copyin *copyinClause, + bool hasCopyin = findRepeatableClause( + [&](const omp::clause::Copyin &clause, const Fortran::parser::CharBlock &) { - const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + for (const omp::Object &object : clause.v) { + Fortran::semantics::Symbol *sym = object.sym; + assert(sym && "Expecting symbol"); if (const auto *commonDetails = sym->detailsIf()) { for (const auto &mem : commonDetails->objects()) @@ -1763,38 +3844,30 @@ bool ClauseProcessor::processDepend( llvm::SmallVectorImpl &dependOperands) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause( - [&](const ClauseTy::Depend *dependClause, + return findRepeatableClause( + [&](const omp::clause::Depend &clause, const Fortran::parser::CharBlock &) { - const std::list &depVal = - std::get>( - std::get( - dependClause->v.u) - .t); + assert(std::holds_alternative(clause.u) && + "Only InOut is handled at the moment"); + const auto &inOut = std::get(clause.u); + const auto &objects = std::get(inOut.t); + mlir::omp::ClauseTaskDependAttr dependTypeOperand = - genDependKindAttr(firOpBuilder, dependClause); - dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(), - dependTypeOperand); - for (const Fortran::parser::Designator &ompObject : depVal) { - Fortran::semantics::Symbol *sym = nullptr; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::DataRef &designator) { - if (const Fortran::parser::Name *name = - std::get_if(&designator.u)) { - sym = name->symbol; - } else if (std::get_if>( - &designator.u)) { - TODO(converter.getCurrentLocation(), - "array sections not supported for task depend"); - } - }, - [&](const Fortran::parser::Substring &designator) { - TODO(converter.getCurrentLocation(), - "substring not supported for task depend"); - }}, - (ompObject).u); + genDependKindAttr(firOpBuilder, clause); + dependTypeOperands.append(objects.size(), dependTypeOperand); + + for (const omp::Object &object : objects) { + assert(object.dsg && "Expecting designator"); + + if (Fortran::evaluate::ExtractSubstring(*object.dsg)) { + TODO(converter.getCurrentLocation(), + "substring not supported for task depend"); + } else if (Fortran::evaluate::IsArrayElement(*object.dsg)) { + TODO(converter.getCurrentLocation(), + "array sections not supported for task depend"); + } + + Fortran::semantics::Symbol *sym = object.sym; const mlir::Value variable = converter.getSymbolAddress(*sym); dependOperands.push_back(variable); } @@ -1802,14 +3875,14 @@ bool ClauseProcessor::processDepend( } bool ClauseProcessor::processIf( - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, + omp::clause::If::DirectiveNameModifier directiveName, mlir::Value &result) const { bool found = false; - findRepeatableClause( - [&](const ClauseTy::If *ifClause, + findRepeatableClause( + [&](const omp::clause::If &clause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); - mlir::Value operand = getIfClauseOperand(converter, ifClause, + mlir::Value operand = getIfClauseOperand(converter, clause, directiveName, clauseLocation); // Assume that, at most, a single 'if' clause will be applicable to the // given directive. @@ -1823,12 +3896,11 @@ bool ClauseProcessor::processIf( bool ClauseProcessor::processLink( llvm::SmallVectorImpl &result) const { - return findRepeatableClause( - [&](const ClauseTy::Link *linkClause, - const Fortran::parser::CharBlock &) { + return findRepeatableClause( + [&](const omp::clause::Link &clause, const Fortran::parser::CharBlock &) { // Case: declare target link(var1, var2)... gatherFuncAndVarSyms( - linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result); + clause.v, mlir::omp::DeclareTargetCaptureClause::link, result); }); } @@ -1865,65 +3937,61 @@ bool ClauseProcessor::processMap( llvm::SmallVectorImpl *mapSymbols) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause( - [&](const ClauseTy::Map *mapClause, + return findRepeatableClause( + [&](const omp::clause::Map &clause, const Fortran::parser::CharBlock &source) { + using Map = omp::clause::Map; mlir::Location clauseLocation = converter.genLocation(source); - const auto &oMapType = - std::get>( - mapClause->v.t); + const auto &oMapType = std::get>(clause.t); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; // If the map type is specified, then process it else Tofrom is the // default. if (oMapType) { - const Fortran::parser::OmpMapType::Type &mapType = - std::get(oMapType->t); + const Map::MapType::Type &mapType = + std::get(oMapType->t); switch (mapType) { - case Fortran::parser::OmpMapType::Type::To: + case Map::MapType::Type::To: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; break; - case Fortran::parser::OmpMapType::Type::From: + case Map::MapType::Type::From: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; break; - case Fortran::parser::OmpMapType::Type::Tofrom: + case Map::MapType::Type::Tofrom: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; break; - case Fortran::parser::OmpMapType::Type::Alloc: - case Fortran::parser::OmpMapType::Type::Release: + case Map::MapType::Type::Alloc: + case Map::MapType::Type::Release: // alloc and release is the default map_type for the Target Data // Ops, i.e. if no bits for map_type is supplied then alloc/release // is implicitly assumed based on the target directive. Default // value for Target Data and Enter Data is alloc and for Exit Data // it is release. break; - case Fortran::parser::OmpMapType::Type::Delete: + case Map::MapType::Type::Delete: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; } - if (std::get>( - oMapType->t)) + if (std::get>(oMapType->t)) mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; } else { mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; } - for (const Fortran::parser::OmpObject &ompObject : - std::get(mapClause->v.t).v) { + for (const omp::Object &object : std::get(clause.t)) { llvm::SmallVector bounds; std::stringstream asFortran; Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, - mlir::omp::DataBoundsType>( - converter, firOpBuilder, semaCtx, stmtCtx, ompObject, - clauseLocation, asFortran, bounds, treatIndexAsSection); + mlir::omp::DataBoundsOp, mlir::omp::DataBoundsType>( + converter, firOpBuilder, semaCtx, stmtCtx, *object.sym, + object.dsg, clauseLocation, asFortran, bounds, + treatIndexAsSection); - auto origSymbol = - converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); + auto origSymbol = converter.getSymbolAddress(*object.sym); mlir::Value symAddr = info.addr; if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) symAddr = origSymbol; @@ -1946,7 +4014,7 @@ bool ClauseProcessor::processMap( mapSymLocs->push_back(symAddr.getLoc()); if (mapSymbols) - mapSymbols->push_back(getOmpObjectSymbol(ompObject)); + mapSymbols->push_back(object.sym); } }); } @@ -1957,43 +4025,41 @@ bool ClauseProcessor::processReduction( llvm::SmallVectorImpl &reductionDeclSymbols, llvm::SmallVectorImpl *reductionSymbols) const { - return findRepeatableClause( - [&](const ClauseTy::Reduction *reductionClause, + return findRepeatableClause( + [&](const omp::clause::Reduction &clause, const Fortran::parser::CharBlock &) { ReductionProcessor rp; - rp.addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols, - reductionSymbols); + rp.addReductionDecl(currentLocation, converter, clause, reductionVars, + reductionDeclSymbols, reductionSymbols); }); } bool ClauseProcessor::processSectionsReduction( mlir::Location currentLocation) const { - return findRepeatableClause( - [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) { + return findRepeatableClause( + [&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) { TODO(currentLocation, "OMPC_Reduction"); }); } bool ClauseProcessor::processTo( llvm::SmallVectorImpl &result) const { - return findRepeatableClause( - [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) { + return findRepeatableClause( + [&](const omp::clause::To &clause, const Fortran::parser::CharBlock &) { // Case: declare target to(func, var1, var2)... - gatherFuncAndVarSyms(toClause->v, + gatherFuncAndVarSyms(clause.v, mlir::omp::DeclareTargetCaptureClause::to, result); }); } bool ClauseProcessor::processEnter( llvm::SmallVectorImpl &result) const { - return findRepeatableClause( - [&](const ClauseTy::Enter *enterClause, + return findRepeatableClause( + [&](const omp::clause::Enter &clause, const Fortran::parser::CharBlock &) { // Case: declare target enter(func, var1, var2)... - gatherFuncAndVarSyms(enterClause->v, - mlir::omp::DeclareTargetCaptureClause::enter, - result); + gatherFuncAndVarSyms( + clause.v, mlir::omp::DeclareTargetCaptureClause::enter, result); }); } @@ -2003,11 +4069,11 @@ bool ClauseProcessor::processUseDeviceAddr( llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSymbols) const { - return findRepeatableClause( - [&](const ClauseTy::UseDeviceAddr *devAddrClause, + return findRepeatableClause( + [&](const omp::clause::UseDeviceAddr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devAddrClause->v, operands, - useDeviceTypes, useDeviceLocs, useDeviceSymbols); + addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, + useDeviceLocs, useDeviceSymbols); }); } @@ -2017,10 +4083,10 @@ bool ClauseProcessor::processUseDevicePtr( llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSymbols) const { - return findRepeatableClause( - [&](const ClauseTy::UseDevicePtr *devPtrClause, + return findRepeatableClause( + [&](const omp::clause::UseDevicePtr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes, + addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, useDeviceLocs, useDeviceSymbols); }); } @@ -2030,31 +4096,30 @@ bool ClauseProcessor::processMotionClauses( Fortran::lower::StatementContext &stmtCtx, llvm::SmallVectorImpl &mapOperands) { return findRepeatableClause( - [&](const T *motionClause, const Fortran::parser::CharBlock &source) { + [&](const T &clause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - static_assert(std::is_same_v || - std::is_same_v); + static_assert(std::is_same_v || + std::is_same_v); // TODO Support motion modifiers: present, mapper, iterator. constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - std::is_same_v + std::is_same_v ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) { + for (const omp::Object &object : clause.v) { llvm::SmallVector bounds; std::stringstream asFortran; Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, - mlir::omp::DataBoundsType>( - converter, firOpBuilder, semaCtx, stmtCtx, ompObject, - clauseLocation, asFortran, bounds, treatIndexAsSection); + mlir::omp::DataBoundsOp, mlir::omp::DataBoundsType>( + converter, firOpBuilder, semaCtx, stmtCtx, *object.sym, + object.dsg, clauseLocation, asFortran, bounds, + treatIndexAsSection); - auto origSymbol = - converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); + auto origSymbol = converter.getSymbolAddress(*object.sym); mlir::Value symAddr = info.addr; if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) symAddr = origSymbol; @@ -2078,19 +4143,17 @@ bool ClauseProcessor::processMotionClauses( template void ClauseProcessor::processTODO(mlir::Location currentLocation, llvm::omp::Directive directive) const { - auto checkUnhandledClause = [&](const auto *x) { + auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) { if (!x) return; TODO(currentLocation, - "Unhandled clause " + - llvm::StringRef(Fortran::parser::ParseTreeDumper::GetNodeName(*x)) - .upper() + + "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() + " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() + " construct"); }; - for (ClauseIterator it = clauses.v.begin(); it != clauses.v.end(); ++it) - (checkUnhandledClause(std::get_if(&it->u)), ...); + for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it) + (checkUnhandledClause(it->id, std::get_if(&it->u)), ...); } //===----------------------------------------------------------------------===// @@ -2399,7 +4462,7 @@ static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) { std::optional tempDsp; if (privatize) { if (!info.dsp) { - tempDsp.emplace(info.converter, *info.clauses, info.eval); + tempDsp.emplace(info.converter, info.semaCtx, *info.clauses, info.eval); tempDsp->processStep1(); } } @@ -2599,7 +4662,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector reductionSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel, + cp.processIf(omp::clause::If::DirectiveNameModifier::Parallel, ifClauseOperand); cp.processNumThreads(stmtCtx, numThreadsClauseOperand); cp.processProcBind(procBindKindAttr); @@ -2669,8 +4732,8 @@ genSingleOp(Fortran::lower::AbstractConverter &converter, ClauseProcessor cp(converter, semaCtx, beginClauseList); cp.processAllocate(allocatorOperands, allocateOperands); - cp.processTODO( - currentLocation, llvm::omp::Directive::OMPD_single); + cp.processTODO(currentLocation, + llvm::omp::Directive::OMPD_single); ClauseProcessor(converter, semaCtx, endClauseList).processNowait(nowaitAttr); @@ -2695,8 +4758,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, dependOperands; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task, - ifClauseOperand); + cp.processIf(omp::clause::If::DirectiveNameModifier::Task, ifClauseOperand); cp.processAllocate(allocatorOperands, allocateOperands); cp.processDefault(); cp.processFinal(stmtCtx, finalClauseOperand); @@ -2704,10 +4766,9 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, cp.processMergeable(mergeableAttr); cp.processPriority(stmtCtx, priorityClauseOperand); cp.processDepend(dependTypeOperands, dependOperands); - cp.processTODO( - currentLocation, llvm::omp::Directive::OMPD_task); + cp.processTODO(currentLocation, + llvm::omp::Directive::OMPD_task); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) @@ -2732,7 +4793,7 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector allocateOperands, allocatorOperands; ClauseProcessor cp(converter, semaCtx, clauseList); cp.processAllocate(allocatorOperands, allocateOperands); - cp.processTODO( + cp.processTODO( currentLocation, llvm::omp::Directive::OMPD_taskgroup); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) @@ -2757,7 +4818,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector useDeviceSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData, + cp.processIf(omp::clause::If::DirectiveNameModifier::TargetData, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs, @@ -2788,19 +4849,16 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, mlir::UnitAttr nowaitAttr; llvm::SmallVector mapOperands; - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName; + omp::clause::If::DirectiveNameModifier directiveName; llvm::omp::Directive directive; if constexpr (std::is_same_v) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData; + directiveName = omp::clause::If::DirectiveNameModifier::TargetEnterData; directive = llvm::omp::Directive::OMPD_target_enter_data; } else if constexpr (std::is_same_v) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData; + directiveName = omp::clause::If::DirectiveNameModifier::TargetExitData; directive = llvm::omp::Directive::OMPD_target_exit_data; } else if constexpr (std::is_same_v) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate; + directiveName = omp::clause::If::DirectiveNameModifier::TargetUpdate; directive = llvm::omp::Directive::OMPD_target_update; } else { return nullptr; @@ -2812,17 +4870,14 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, cp.processNowait(nowaitAttr); if constexpr (std::is_same_v) { - cp.processMotionClauses(stmtCtx, - mapOperands); - cp.processMotionClauses(stmtCtx, - mapOperands); + cp.processMotionClauses(stmtCtx, mapOperands); + cp.processMotionClauses(stmtCtx, mapOperands); } else { cp.processMap(currentLocation, directive, stmtCtx, mapOperands); } - cp.processTODO(currentLocation, - directive); + cp.processTODO(currentLocation, directive); return firOpBuilder.create(currentLocation, ifClauseOperand, deviceOperand, nullptr, mlir::ValueRange(), @@ -2999,23 +5054,17 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector mapSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target, - ifClauseOperand); + cp.processIf(omp::clause::If::DirectiveNameModifier::Target, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); cp.processThreadLimit(stmtCtx, threadLimitOperand); cp.processNowait(nowaitAttr); cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols); - cp.processTODO( + cp.processTODO( currentLocation, llvm::omp::Directive::OMPD_target); // 5.8.1 Implicit Data-Mapping Attribute Rules @@ -3113,14 +5162,13 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector reductionDeclSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams, - ifClauseOperand); + cp.processIf(omp::clause::If::DirectiveNameModifier::Teams, ifClauseOperand); cp.processAllocate(allocatorOperands, allocateOperands); cp.processDefault(); cp.processNumTeams(stmtCtx, numTeamsClauseOperand); cp.processThreadLimit(stmtCtx, threadLimitClauseOperand); - cp.processTODO( - currentLocation, llvm::omp::Directive::OMPD_teams); + cp.processTODO(currentLocation, + llvm::omp::Directive::OMPD_teams); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) @@ -3153,8 +5201,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( if (const auto *objectList{ Fortran::parser::Unwrap(spec.u)}) { + omp::ObjectList objects{omp::makeList(*objectList, semaCtx)}; // Case: declare target(func, var1, var2) - gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to, + gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to, symbolAndClause); } else if (const auto *clauseList{ Fortran::parser::Unwrap( @@ -3171,7 +5220,7 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( cp.processEnter(symbolAndClause); cp.processLink(symbolAndClause); cp.processDeviceType(deviceType); - cp.processTODO( + cp.processTODO( converter.getCurrentLocation(), llvm::omp::Directive::OMPD_declare_target); } @@ -3230,8 +5279,7 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter, break; case llvm::omp::Directive::OMPD_taskwait: ClauseProcessor(converter, semaCtx, opClauseList) - .processTODO( + .processTODO( currentLocation, llvm::omp::Directive::OMPD_taskwait); firOpBuilder.create(currentLocation); break; @@ -3268,7 +5316,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter, if (const auto &ompObjectList = std::get>( flushConstruct.t)) - genObjectList(*ompObjectList, converter, operandRange); + genObjectList2(*ompObjectList, converter, operandRange); const auto &memOrderClause = std::get>>( flushConstruct.t); @@ -3360,7 +5408,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpClauseList &loopOpClauseList, mlir::Location loc) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - DataSharingProcessor dsp(converter, loopOpClauseList, eval); + DataSharingProcessor dsp(converter, semaCtx, loopOpClauseList, eval); dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; @@ -3378,15 +5426,12 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, loopVarTypeSize); cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); cp.processReduction(loc, reductionVars, reductionDeclSymbols); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd, - ifClauseOperand); + cp.processIf(omp::clause::If::DirectiveNameModifier::Simd, ifClauseOperand); cp.processSimdlen(simdlenClauseOperand); cp.processSafelen(safelenClauseOperand); - cp.processTODO(loc, ompDirective); + cp.processTODO(loc, ompDirective); convertLoopBounds(converter, loc, lowerBound, upperBound, step, loopVarTypeSize); @@ -3420,7 +5465,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - DataSharingProcessor dsp(converter, beginClauseList, eval); + DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval); dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; @@ -3441,8 +5486,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter, loopVarTypeSize); cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); cp.processReduction(loc, reductionVars, reductionDeclSymbols); - cp.processTODO(loc, ompDirective); + cp.processTODO(loc, ompDirective); convertLoopBounds(converter, loc, lowerBound, upperBound, step, loopVarTypeSize); @@ -3502,11 +5546,9 @@ static void createSimdWsLoop( const Fortran::parser::OmpClauseList &beginClauseList, const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { ClauseProcessor cp(converter, semaCtx, beginClauseList); - cp.processTODO< - Fortran::parser::OmpClause::Aligned, Fortran::parser::OmpClause::Allocate, - Fortran::parser::OmpClause::Linear, Fortran::parser::OmpClause::Safelen, - Fortran::parser::OmpClause::Simdlen, Fortran::parser::OmpClause::Order>( - loc, ompDirective); + cp.processTODO(loc, ompDirective); // TODO: Add support for vectorization - add vectorization hints inside loop // body. // OpenMP standard does not specify the length of vector instructions. @@ -3526,6 +5568,10 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { const auto &beginLoopDirective = std::get(loopConstruct.t); + // Test call + splitCompositeConstruct(converter.getFirOpBuilder().getModule(), semaCtx, + eval, std::get<0>(beginLoopDirective.t).v, + std::get<1>(beginLoopDirective.t)); const auto &loopOpClauseList = std::get(beginLoopDirective.t); mlir::Location currentLocation = @@ -4173,106 +6219,101 @@ void Fortran::lower::genOpenMPReduction( const Fortran::parser::OmpClauseList &clauseList) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - for (const Fortran::parser::OmpClause &clause : clauseList.v) { + omp::List clauses{omp::makeList(clauseList, semaCtx)}; + + for (const omp::Clause &clause : clauses) { if (const auto &reductionClause = - std::get_if(&clause.u)) { - const auto &redOperator{std::get( - reductionClause->v.t)}; - const auto &objectList{ - std::get(reductionClause->v.t)}; + std::get_if(&clause.u)) { + const auto &redOperator{ + std::get(reductionClause->t)}; + const auto &objectList{std::get(reductionClause->t)}; if (const auto *reductionOp = - std::get_if(&redOperator.u)) { + std::get_if(&redOperator.u)) { const auto &intrinsicOp{ - std::get( + std::get( reductionOp->u)}; switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: break; default: continue; } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp()) - reductionVal = declOp.getBase(); - mlir::Type reductionType = - reductionVal.getType().cast().getEleTy(); - if (!reductionType.isa()) { - if (!reductionType.isIntOrIndexOrFloat()) - continue; - } - for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { - if (auto loadOp = mlir::dyn_cast( - reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - if (reductionType.isa()) { - mlir::Operation *reductionOp = findReductionChain(loadVal); - fir::ConvertOp convertOp = - getConvertFromReductionOp(reductionOp, loadVal); - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal, &convertOp); - removeStoreOp(reductionOp, reductionVal); - } else if (mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal)) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.sym) { + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = reductionVal.getDefiningOp()) + reductionVal = declOp.getBase(); + mlir::Type reductionType = + reductionVal.getType().cast().getEleTy(); + if (!reductionType.isa()) { + if (!reductionType.isIntOrIndexOrFloat()) + continue; + } + for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { + if (auto loadOp = + mlir::dyn_cast(reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + if (reductionType.isa()) { + mlir::Operation *reductionOp = findReductionChain(loadVal); + fir::ConvertOp convertOp = + getConvertFromReductionOp(reductionOp, loadVal); + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal, &convertOp); + removeStoreOp(reductionOp, reductionVal); + } else if (mlir::Operation *reductionOp = + findReductionChain(loadVal, &reductionVal)) { + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); } } } } } } else if (const auto *reductionIntrinsic = - std::get_if( + std::get_if( &redOperator.u)) { if (!ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) continue; ReductionProcessor::ReductionIdentifier redId = ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp()) - reductionVal = declOp.getBase(); - for (const mlir::OpOperand &reductionValUse : - reductionVal.getUses()) { - if (auto loadOp = mlir::dyn_cast( - reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - // Max is lowered as a compare -> select. - // Match the pattern here. - mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal); - if (reductionOp == nullptr) - continue; - - if (redId == ReductionProcessor::ReductionIdentifier::MAX || - redId == ReductionProcessor::ReductionIdentifier::MIN) { - assert(mlir::isa(reductionOp) && - "Selection Op not found in reduction intrinsic"); - mlir::Operation *compareOp = - getCompareFromReductionOp(reductionOp, loadVal); - updateReduction(compareOp, firOpBuilder, loadVal, - reductionVal); - } - if (redId == ReductionProcessor::ReductionIdentifier::IOR || - redId == ReductionProcessor::ReductionIdentifier::IEOR || - redId == ReductionProcessor::ReductionIdentifier::IAND) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.sym) { + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = reductionVal.getDefiningOp()) + reductionVal = declOp.getBase(); + for (const mlir::OpOperand &reductionValUse : + reductionVal.getUses()) { + if (auto loadOp = + mlir::dyn_cast(reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + // Max is lowered as a compare -> select. + // Match the pattern here. + mlir::Operation *reductionOp = + findReductionChain(loadVal, &reductionVal); + if (reductionOp == nullptr) + continue; + + if (redId == ReductionProcessor::ReductionIdentifier::MAX || + redId == ReductionProcessor::ReductionIdentifier::MIN) { + assert(mlir::isa(reductionOp) && + "Selection Op not found in reduction intrinsic"); + mlir::Operation *compareOp = + getCompareFromReductionOp(reductionOp, loadVal); + updateReduction(compareOp, firOpBuilder, loadVal, + reductionVal); + } + if (redId == ReductionProcessor::ReductionIdentifier::IOR || + redId == ReductionProcessor::ReductionIdentifier::IEOR || + redId == ReductionProcessor::ReductionIdentifier::IAND) { + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); } } } diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp index c9358c83e795c..bdae1731260de 100644 --- a/flang/tools/bbc/bbc.cpp +++ b/flang/tools/bbc/bbc.cpp @@ -359,7 +359,6 @@ static mlir::LogicalResult convertFortranSourceToMLIR( semanticsContext.targetCharacteristics(), parsing.allCooked(), targetTriple, kindMap, loweringOptions, {}, semanticsContext.languageFeatures(), targetMachine); - burnside.lower(parseTree, semanticsContext); mlir::ModuleOp mlirModule = burnside.getModule(); if (enableOpenMP) { if (enableOpenMPGPU && !enableOpenMPDevice) { @@ -375,6 +374,7 @@ static mlir::LogicalResult convertFortranSourceToMLIR( setOffloadModuleInterfaceAttributes(mlirModule, offloadModuleOpts); setOpenMPVersionAttribute(mlirModule, setOpenMPVersion); } + burnside.lower(parseTree, semanticsContext); std::error_code ec; std::string outputName = outputFilename; if (!outputName.size()) diff --git a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td index 31578710365b2..24eb54e75c96b 100644 --- a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td +++ b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td @@ -152,6 +152,10 @@ class Directive { // List of clauses that are required. list requiredClauses = []; + // List of leaf constituent directives in the order in which they appear + // in the combined/composite directive. + list leafs = []; + // Set directive used by default when unknown. bit isDefault = false; } diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index 1481328bf483b..534ab58985b57 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -773,6 +773,7 @@ def OMP_TargetParallel : Directive<"target parallel"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Target, OMP_Parallel]; } def OMP_TargetParallelFor : Directive<"target parallel for"> { let allowedClauses = [ @@ -805,6 +806,7 @@ def OMP_TargetParallelFor : Directive<"target parallel for"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Target, OMP_Parallel, OMP_For]; } def OMP_TargetParallelDo : Directive<"target parallel do"> { let allowedClauses = [ @@ -835,6 +837,7 @@ def OMP_TargetParallelDo : Directive<"target parallel do"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Target, OMP_Parallel, OMP_Do]; } def OMP_TargetUpdate : Directive<"target update"> { let allowedClauses = [ @@ -848,6 +851,11 @@ def OMP_TargetUpdate : Directive<"target update"> { VersionedClause ]; } +def OMP_masked : Directive<"masked"> { + let allowedOnceClauses = [ + VersionedClause + ]; +} def OMP_ParallelFor : Directive<"parallel for"> { let allowedClauses = [ VersionedClause, @@ -868,6 +876,7 @@ def OMP_ParallelFor : Directive<"parallel for"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Parallel, OMP_For]; } def OMP_ParallelDo : Directive<"parallel do"> { let allowedClauses = [ @@ -889,6 +898,7 @@ def OMP_ParallelDo : Directive<"parallel do"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Parallel, OMP_Do]; } def OMP_ParallelForSimd : Directive<"parallel for simd"> { let allowedClauses = [ @@ -914,6 +924,7 @@ def OMP_ParallelForSimd : Directive<"parallel for simd"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Parallel, OMP_For, OMP_Simd]; } def OMP_ParallelDoSimd : Directive<"parallel do simd"> { let allowedClauses = [ @@ -940,6 +951,7 @@ def OMP_ParallelDoSimd : Directive<"parallel do simd"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Parallel, OMP_Do, OMP_Simd]; } def OMP_ParallelMaster : Directive<"parallel master"> { let allowedClauses = [ @@ -955,6 +967,7 @@ def OMP_ParallelMaster : Directive<"parallel master"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Parallel, OMP_Master]; } def OMP_ParallelMasked : Directive<"parallel masked"> { let allowedClauses = [ @@ -971,6 +984,7 @@ def OMP_ParallelMasked : Directive<"parallel masked"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Parallel, OMP_masked]; } def OMP_ParallelSections : Directive<"parallel sections"> { let allowedClauses = [ @@ -989,6 +1003,7 @@ def OMP_ParallelSections : Directive<"parallel sections"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Parallel, OMP_Sections]; } def OMP_ForSimd : Directive<"for simd"> { let allowedClauses = [ @@ -1009,6 +1024,7 @@ def OMP_ForSimd : Directive<"for simd"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_For, OMP_Simd]; } def OMP_DoSimd : Directive<"do simd"> { let allowedClauses = [ @@ -1029,6 +1045,7 @@ def OMP_DoSimd : Directive<"do simd"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Do, OMP_Simd]; } def OMP_CancellationPoint : Directive<"cancellation point"> {} def OMP_DeclareReduction : Directive<"declare reduction"> {} @@ -1106,6 +1123,7 @@ def OMP_TaskLoopSimd : Directive<"taskloop simd"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_TaskLoop, OMP_Simd]; } def OMP_Distribute : Directive<"distribute"> { let allowedClauses = [ @@ -1158,6 +1176,7 @@ def OMP_DistributeParallelFor : Directive<"distribute parallel for"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Distribute, OMP_Parallel, OMP_For]; } def OMP_DistributeParallelDo : Directive<"distribute parallel do"> { let allowedClauses = [ @@ -1181,6 +1200,7 @@ def OMP_DistributeParallelDo : Directive<"distribute parallel do"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Distribute, OMP_Parallel, OMP_Do]; } def OMP_DistributeParallelForSimd : Directive<"distribute parallel for simd"> { let allowedClauses = [ @@ -1206,6 +1226,7 @@ def OMP_DistributeParallelForSimd : Directive<"distribute parallel for simd"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Distribute, OMP_Parallel, OMP_For, OMP_Simd]; } def OMP_DistributeParallelDoSimd : Directive<"distribute parallel do simd"> { let allowedClauses = [ @@ -1230,6 +1251,7 @@ def OMP_DistributeParallelDoSimd : Directive<"distribute parallel do simd"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Distribute, OMP_Parallel, OMP_Do, OMP_Simd]; } def OMP_DistributeSimd : Directive<"distribute simd"> { let allowedClauses = [ @@ -1256,6 +1278,7 @@ def OMP_DistributeSimd : Directive<"distribute simd"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Distribute, OMP_Simd]; } def OMP_TargetParallelForSimd : Directive<"target parallel for simd"> { @@ -1293,6 +1316,7 @@ def OMP_TargetParallelForSimd : Directive<"target parallel for simd"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Target, OMP_Parallel, OMP_For, OMP_Simd]; } def OMP_TargetParallelDoSimd : Directive<"target parallel do simd"> { let allowedClauses = [ @@ -1324,6 +1348,7 @@ def OMP_TargetParallelDoSimd : Directive<"target parallel do simd"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Target, OMP_Parallel, OMP_Do, OMP_Simd]; } def OMP_TargetSimd : Directive<"target simd"> { let allowedClauses = [ @@ -1358,6 +1383,7 @@ def OMP_TargetSimd : Directive<"target simd"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Target, OMP_Simd]; } def OMP_TeamsDistribute : Directive<"teams distribute"> { let allowedClauses = [ @@ -1377,6 +1403,7 @@ def OMP_TeamsDistribute : Directive<"teams distribute"> { let allowedOnceClauses = [ VersionedClause ]; + let leafs = [OMP_Teams, OMP_Distribute]; } def OMP_TeamsDistributeSimd : Directive<"teams distribute simd"> { let allowedClauses = [ @@ -1402,6 +1429,7 @@ def OMP_TeamsDistributeSimd : Directive<"teams distribute simd"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Teams, OMP_Distribute, OMP_Simd]; } def OMP_TeamsDistributeParallelForSimd : @@ -1430,6 +1458,7 @@ def OMP_TeamsDistributeParallelForSimd : VersionedClause, VersionedClause, ]; + let leafs = [OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_For, OMP_Simd]; } def OMP_TeamsDistributeParallelDoSimd : Directive<"teams distribute parallel do simd"> { @@ -1458,6 +1487,7 @@ def OMP_TeamsDistributeParallelDoSimd : VersionedClause, VersionedClause ]; + let leafs = [OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_Do, OMP_Simd]; } def OMP_TeamsDistributeParallelFor : Directive<"teams distribute parallel for"> { @@ -1481,6 +1511,7 @@ def OMP_TeamsDistributeParallelFor : VersionedClause, VersionedClause, ]; + let leafs = [OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_For]; } def OMP_TeamsDistributeParallelDo : Directive<"teams distribute parallel do"> { @@ -1507,6 +1538,7 @@ let allowedOnceClauses = [ VersionedClause, VersionedClause ]; + let leafs = [OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_Do]; } def OMP_TargetTeams : Directive<"target teams"> { let allowedClauses = [ @@ -1534,6 +1566,7 @@ def OMP_TargetTeams : Directive<"target teams"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Target, OMP_Teams]; } def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> { let allowedClauses = [ @@ -1562,6 +1595,7 @@ def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Target, OMP_Teams, OMP_Distribute]; } def OMP_TargetTeamsDistributeParallelFor : @@ -1596,6 +1630,7 @@ def OMP_TargetTeamsDistributeParallelFor : let allowedOnceClauses = [ VersionedClause, ]; + let leafs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_For]; } def OMP_TargetTeamsDistributeParallelDo : Directive<"target teams distribute parallel do"> { @@ -1630,6 +1665,7 @@ def OMP_TargetTeamsDistributeParallelDo : VersionedClause, VersionedClause ]; + let leafs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_Do]; } def OMP_TargetTeamsDistributeParallelForSimd : Directive<"target teams distribute parallel for simd"> { @@ -1668,6 +1704,7 @@ def OMP_TargetTeamsDistributeParallelForSimd : let allowedOnceClauses = [ VersionedClause, ]; + let leafs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_For, OMP_Simd]; } def OMP_TargetTeamsDistributeParallelDoSimd : Directive<"target teams distribute parallel do simd"> { @@ -1706,6 +1743,7 @@ def OMP_TargetTeamsDistributeParallelDoSimd : VersionedClause, VersionedClause ]; + let leafs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_Do, OMP_Simd]; } def OMP_TargetTeamsDistributeSimd : Directive<"target teams distribute simd"> { @@ -1740,6 +1778,7 @@ def OMP_TargetTeamsDistributeSimd : VersionedClause, VersionedClause ]; + let leafs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd]; } def OMP_Allocate : Directive<"allocate"> { let allowedOnceClauses = [ @@ -1781,6 +1820,7 @@ def OMP_MasterTaskloop : Directive<"master taskloop"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Master, OMP_TaskLoop]; } def OMP_MaskedTaskloop : Directive<"masked taskloop"> { let allowedClauses = [ @@ -1803,6 +1843,7 @@ def OMP_MaskedTaskloop : Directive<"masked taskloop"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_masked, OMP_TaskLoop]; } def OMP_ParallelMasterTaskloop : Directive<"parallel master taskloop"> { @@ -1828,6 +1869,7 @@ def OMP_ParallelMasterTaskloop : VersionedClause, VersionedClause, ]; + let leafs = [OMP_Parallel, OMP_Master, OMP_TaskLoop]; } def OMP_ParallelMaskedTaskloop : Directive<"parallel masked taskloop"> { @@ -1854,6 +1896,7 @@ def OMP_ParallelMaskedTaskloop : VersionedClause, VersionedClause, ]; + let leafs = [OMP_Parallel, OMP_masked, OMP_TaskLoop]; } def OMP_MasterTaskloopSimd : Directive<"master taskloop simd"> { let allowedClauses = [ @@ -1881,6 +1924,7 @@ def OMP_MasterTaskloopSimd : Directive<"master taskloop simd"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Master, OMP_TaskLoop, OMP_Simd]; } def OMP_MaskedTaskloopSimd : Directive<"masked taskloop simd"> { let allowedClauses = [ @@ -1909,6 +1953,7 @@ def OMP_MaskedTaskloopSimd : Directive<"masked taskloop simd"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_masked, OMP_TaskLoop, OMP_Simd]; } def OMP_ParallelMasterTaskloopSimd : Directive<"parallel master taskloop simd"> { @@ -1940,6 +1985,7 @@ def OMP_ParallelMasterTaskloopSimd : VersionedClause, VersionedClause, ]; + let leafs = [OMP_Parallel, OMP_Master, OMP_TaskLoop, OMP_Simd]; } def OMP_ParallelMaskedTaskloopSimd : Directive<"parallel masked taskloop simd"> { @@ -1972,6 +2018,7 @@ def OMP_ParallelMaskedTaskloopSimd : VersionedClause, VersionedClause, ]; + let leafs = [OMP_Parallel, OMP_masked, OMP_TaskLoop, OMP_Simd]; } def OMP_Depobj : Directive<"depobj"> { let allowedClauses = [ @@ -2003,6 +2050,7 @@ def OMP_scope : Directive<"scope"> { VersionedClause ]; } +def OMP_Workshare : Directive<"workshare"> {} def OMP_ParallelWorkshare : Directive<"parallel workshare"> { let allowedClauses = [ VersionedClause, @@ -2018,8 +2066,8 @@ def OMP_ParallelWorkshare : Directive<"parallel workshare"> { VersionedClause, VersionedClause ]; + let leafs = [OMP_Parallel, OMP_Workshare]; } -def OMP_Workshare : Directive<"workshare"> {} def OMP_EndDo : Directive<"end do"> { let allowedOnceClauses = [ VersionedClause @@ -2069,11 +2117,6 @@ def OMP_dispatch : Directive<"dispatch"> { VersionedClause ]; } -def OMP_masked : Directive<"masked"> { - let allowedOnceClauses = [ - VersionedClause - ]; -} def OMP_loop : Directive<"loop"> { let allowedClauses = [ VersionedClause, @@ -2104,6 +2147,7 @@ def OMP_teams_loop : Directive<"teams loop"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Teams, OMP_loop]; } def OMP_target_teams_loop : Directive<"target teams loop"> { let allowedClauses = [ @@ -2133,6 +2177,7 @@ def OMP_target_teams_loop : Directive<"target teams loop"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Target, OMP_Teams, OMP_loop]; } def OMP_parallel_loop : Directive<"parallel loop"> { let allowedClauses = [ @@ -2154,6 +2199,7 @@ def OMP_parallel_loop : Directive<"parallel loop"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Parallel, OMP_loop]; } def OMP_target_parallel_loop : Directive<"target parallel loop"> { let allowedClauses = [ @@ -2185,11 +2231,13 @@ def OMP_target_parallel_loop : Directive<"target parallel loop"> { VersionedClause, VersionedClause, ]; + let leafs = [OMP_Target, OMP_Parallel, OMP_loop]; } def OMP_Metadirective : Directive<"metadirective"> { let allowedClauses = [VersionedClause]; let allowedOnceClauses = [VersionedClause]; } + def OMP_Unknown : Directive<"unknown"> { let isDefault = true; } diff --git a/llvm/include/llvm/TableGen/DirectiveEmitter.h b/llvm/include/llvm/TableGen/DirectiveEmitter.h index c86018715a48a..f655e584f891e 100644 --- a/llvm/include/llvm/TableGen/DirectiveEmitter.h +++ b/llvm/include/llvm/TableGen/DirectiveEmitter.h @@ -121,6 +121,10 @@ class Directive : public BaseRecord { std::vector getRequiredClauses() const { return Def->getValueAsListOfDefs("requiredClauses"); } + + std::vector getLeafConstructs() const { + return Def->getValueAsListOfDefs("leafs"); + } }; // Wrapper class that contains Clause's information defined in DirectiveBase.td diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp index b6aee665f8ee0..7cb2a5cbe9595 100644 --- a/llvm/utils/TableGen/DirectiveEmitter.cpp +++ b/llvm/utils/TableGen/DirectiveEmitter.cpp @@ -186,6 +186,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { if (DirLang.hasEnableBitmaskEnumInNamespace()) OS << "\n#include \"llvm/ADT/BitmaskEnum.h\"\n"; + OS << "#include \"llvm/ADT/SmallVector.h\"\n"; OS << "\n"; OS << "namespace llvm {\n"; @@ -231,6 +232,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { OS << "bool isAllowedClauseForDirective(Directive D, " << "Clause C, unsigned Version);\n"; OS << "\n"; + OS << "const llvm::SmallVector &getLeafConstructs(Directive D);\n"; if (EnumHelperFuncs.length() > 0) { OS << EnumHelperFuncs; OS << "\n"; @@ -435,6 +437,78 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang, OS << "}\n"; // End of function isAllowedClauseForDirective } +// Generate the getLeafConstructs function implementation. +static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang, + raw_ostream &OS) { + auto getQualifiedName = [&](StringRef Formatted) -> std::string { + return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + + "::Directive::" + DirLang.getDirectivePrefix() + Formatted) + .str(); + }; + + // For each list of leafs, generate a static local object, then + // return a reference to that object for a given directive, e.g. + // + // static ListTy leafConstructs_A_B = { A, B }; + // static ListTy leafConstructs_C_D_E = { C, D, E }; + // switch (Dir) { + // case A_B: + // return leafConstructs_A_B; + // case C_D_E: + // return leafConstructs_C_D_E; + + // Map from a record that defines a directive to the name of the + // local object with the list of its leafs. + DenseMap ListNames; + + std::string DirectiveTypeName = + std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive"; + std::string DirectiveListTypeName = + std::string("llvm::SmallVector<") + DirectiveTypeName + ">"; + + // const Container &llvm::::GetLeafConstructs(llvm::::Directive Dir) + OS << "const " << DirectiveListTypeName + << " &llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs(" + << DirectiveTypeName << " Dir) "; + OS << "{\n"; + + // Generate the locals. + for (Record *R : DirLang.getDirectives()) { + Directive Dir{R}; + + std::vector LeafConstructs = Dir.getLeafConstructs(); + if (LeafConstructs.empty()) + continue; + + std::string ListName = "leafConstructs_" + Dir.getFormattedName(); + OS << " static " << DirectiveListTypeName << ' ' << ListName << " {\n"; + for (Record *L : LeafConstructs) { + Directive LeafDir{L}; + OS << " " << getQualifiedName(LeafDir.getFormattedName()) << ",\n"; + } + OS << " };\n"; + ListNames.insert(std::make_pair(R, std::move(ListName))); + } + + OS << " static " << DirectiveListTypeName << " nothing {};\n"; + + OS << '\n'; + OS << " switch (Dir) {\n"; + for (Record *R : DirLang.getDirectives()) { + auto F = ListNames.find(R); + if (F == ListNames.end()) + continue; + + Directive Dir{R}; + OS << " case " << getQualifiedName(Dir.getFormattedName()) << ":\n"; + OS << " return " << F->second << ";\n"; + } + OS << " default:\n"; + OS << " return nothing;\n"; + OS << " } // switch (Dir)\n"; + OS << "}\n"; +} + // Generate a simple enum set with the give clauses. static void GenerateClauseSet(const std::vector &Clauses, raw_ostream &OS, StringRef ClauseSetPrefix, @@ -876,6 +950,9 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang, // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version) GenerateIsAllowedClause(DirLang, OS); + + // getLeafConstructs(Directive D) + GenerateGetLeafConstructs(DirLang, OS); } // Generate the implemenation section for the enumeration in the directive