diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index 9a32062440abc0..8c872a0579c8ed 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -148,6 +148,14 @@ inline Expr AsGenericExpr(Expr &&x) { return std::move(x); } std::optional> AsGenericExpr(DataRef &&); std::optional> AsGenericExpr(const Symbol &); +// Propagate std::optional from input to output. +template +std::optional> AsGenericExpr(std::optional &&x) { + if (!x) + return std::nullopt; + return AsGenericExpr(std::move(*x)); +} + template common::IfNoLvalue::category>>, A> AsCategoryExpr( A &&x) { diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h index 8d560db34e05bf..6daa72b84d90d4 100644 --- a/flang/lib/Lower/DirectivesCommon.h +++ b/flang/lib/Lower/DirectivesCommon.h @@ -808,6 +808,75 @@ genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, return bounds; } +namespace detail { +template // +static T &&AsRvalueRef(T &&t) { + return std::move(t); +} +template // +static T AsRvalueRef(T &t) { + return t; +} +template // +static T AsRvalueRef(const T &t) { + return t; +} + +// Helper class for stripping enclosing parentheses and a conversion that +// preserves type category. This is used for triplet elements, which are +// always of type integer(kind=8). The lower/upper bounds are converted to +// an "index" type, which is 64-bit, so the explicit conversion to kind=8 +// (if present) is not needed. When it's present, though, it causes generated +// names to contain "int(..., kind=8)". +struct PeelConvert { + template + static Fortran::semantics::MaybeExpr visit_with_category( + const Fortran::evaluate::Expr> + &expr) { + return std::visit( + [](auto &&s) { return visit_with_category(s); }, + expr.u); + } + template + static Fortran::semantics::MaybeExpr visit_with_category( + const Fortran::evaluate::Convert, + Category> &expr) { + return AsGenericExpr(AsRvalueRef(expr.left())); + } + template + static Fortran::semantics::MaybeExpr visit_with_category(const T &) { + return std::nullopt; // + } + template + static Fortran::semantics::MaybeExpr visit_with_category(const T &) { + return std::nullopt; // + } + + template + static Fortran::semantics::MaybeExpr + visit(const Fortran::evaluate::Expr> + &expr) { + return std::visit([](auto &&s) { return visit_with_category(s); }, + expr.u); + } + static Fortran::semantics::MaybeExpr + visit(const Fortran::evaluate::Expr &expr) { + return std::visit([](auto &&s) { return visit(s); }, expr.u); + } + template // + static Fortran::semantics::MaybeExpr visit(const T &) { + return std::nullopt; + } +}; + +static Fortran::semantics::SomeExpr +peelOuterConvert(Fortran::semantics::SomeExpr &expr) { + if (auto peeled = PeelConvert::visit(expr)) + return *peeled; + return expr; +} +} // namespace detail + /// Generate bounds operations for an array section when subscripts are /// provided. template @@ -815,7 +884,7 @@ llvm::SmallVector genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext &stmtCtx, - const std::list &subscripts, + const std::vector &subscripts, std::stringstream &asFortran, fir::ExtendedValue &dataExv, bool dataExvIsAssumedSize, AddrAndBoundsInfo &info, bool treatIndexAsSection = false) { @@ -828,8 +897,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); const int dataExvRank = static_cast(dataExv.rank()); for (const auto &subscript : subscripts) { - const auto *triplet{ - std::get_if(&subscript.u)}; + const auto *triplet{std::get_if(&subscript.u)}; if (triplet || treatIndexAsSection) { if (dimension != 0) asFortran << ','; @@ -868,13 +936,17 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, strideInBytes = true; } - const Fortran::lower::SomeExpr *lower{nullptr}; + Fortran::semantics::MaybeExpr lower; if (triplet) { - if (const auto &tripletLb{std::get<0>(triplet->t)}) - lower = Fortran::semantics::GetExpr(*tripletLb); + lower = Fortran::evaluate::AsGenericExpr(triplet->lower()); } else { - const auto &index{std::get(subscript.u)}; - lower = Fortran::semantics::GetExpr(index); + // Case of IndirectSubscriptIntegerExpr + using IndirectSubscriptIntegerExpr = + Fortran::evaluate::IndirectSubscriptIntegerExpr; + using SubscriptInteger = Fortran::evaluate::SubscriptInteger; + Fortran::evaluate::Expr oneInt = + std::get(subscript.u).value(); + lower = Fortran::evaluate::AsGenericExpr(std::move(oneInt)); if (lower->Rank() > 0) { mlir::emitError( loc, "vector subscript cannot be used for an array section"); @@ -896,7 +968,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, fir::getBase(converter.genExprValue(loc, *lower, stmtCtx)); lb = builder.createConvert(loc, baseLb.getType(), lb); lbound = builder.create(loc, lb, baseLb); - asFortran << lower->AsFortran(); + asFortran << detail::peelOuterConvert(*lower).AsFortran(); } } else { // If the lower bound is not specified, then the section @@ -912,10 +984,11 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, extent = one; } else { asFortran << ':'; - const auto &upper{std::get<1>(triplet->t)}; + Fortran::semantics::MaybeExpr upper = + Fortran::evaluate::AsGenericExpr(triplet->upper()); if (upper) { - uval = Fortran::semantics::GetIntValue(upper); + uval = Fortran::evaluate::ToInt64(*upper); if (uval) { if (defaultLb) { ubound = builder.createIntegerConstant(loc, idxTy, *uval - 1); @@ -925,22 +998,21 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, } asFortran << *uval; } else { - const Fortran::lower::SomeExpr *uexpr = - Fortran::semantics::GetExpr(*upper); mlir::Value ub = - fir::getBase(converter.genExprValue(loc, *uexpr, stmtCtx)); + fir::getBase(converter.genExprValue(loc, *upper, stmtCtx)); ub = builder.createConvert(loc, baseLb.getType(), ub); ubound = builder.create(loc, ub, baseLb); - asFortran << uexpr->AsFortran(); + asFortran << detail::peelOuterConvert(*upper).AsFortran(); } } if (lower && upper) { if (lval && uval && *uval < *lval) { mlir::emitError(loc, "zero sized array section"); break; - } else if (std::get<2>(triplet->t)) { - const auto &strideExpr{std::get<2>(triplet->t)}; - if (strideExpr) { + } else { + // Stride is mandatory in evaluate::Triplet. Make sure it's 1. + auto val = Fortran::evaluate::ToInt64(triplet->GetStride()); + if (!val || *val != 1) { mlir::emitError(loc, "stride cannot be specified on " "an array section"); break; @@ -993,150 +1065,157 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, return bounds; } -template +namespace detail { +template // +std::optional getRef(Expr &&expr) { + if constexpr (std::is_same_v, + Fortran::evaluate::DataRef>) { + if (auto *ref = std::get_if(&expr.u)) + return *ref; + return std::nullopt; + } else { + auto maybeRef = Fortran::evaluate::ExtractDataRef(expr); + if (!maybeRef || !std::holds_alternative(maybeRef->u)) + return std::nullopt; + return std::get(maybeRef->u); + } +} +} // namespace detail + +template AddrAndBoundsInfo gatherDataOperandAddrAndBounds( Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::StatementContext &stmtCtx, const ObjectType &object, + semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + Fortran::semantics::SymbolRef symbol, + const Fortran::semantics::MaybeExpr &maybeDesignator, mlir::Location operandLocation, std::stringstream &asFortran, llvm::SmallVector &bounds, bool treatIndexAsSection = false) { + using namespace Fortran; + AddrAndBoundsInfo info; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (auto expr{Fortran::semantics::AnalyzeExpr(semanticsContext, - designator)}) { - if (((*expr).Rank() > 0 || treatIndexAsSection) && - Fortran::parser::Unwrap( - designator)) { - const auto *arrayElement = - Fortran::parser::Unwrap( - designator); - const auto *dataRef = - std::get_if(&designator.u); - fir::ExtendedValue dataExv; - bool dataExvIsAssumedSize = false; - if (Fortran::parser::Unwrap< - Fortran::parser::StructureComponent>( - arrayElement->base)) { - auto exprBase = Fortran::semantics::AnalyzeExpr( - semanticsContext, arrayElement->base); - dataExv = converter.genExprAddr(operandLocation, *exprBase, - stmtCtx); - info.addr = fir::getBase(dataExv); - info.rawInput = info.addr; - asFortran << (*exprBase).AsFortran(); - } else { - const Fortran::parser::Name &name = - Fortran::parser::GetLastName(*dataRef); - dataExvIsAssumedSize = Fortran::semantics::IsAssumedSizeArray( - name.symbol->GetUltimate()); - info = getDataOperandBaseAddr(converter, builder, - *name.symbol, operandLocation); - dataExv = converter.getSymbolExtendedValue(*name.symbol); - asFortran << name.ToString(); - } - - if (!arrayElement->subscripts.empty()) { - asFortran << '('; - bounds = genBoundsOps( - builder, operandLocation, converter, stmtCtx, - arrayElement->subscripts, asFortran, dataExv, - dataExvIsAssumedSize, info, treatIndexAsSection); - } - asFortran << ')'; - } else if (auto structComp = Fortran::parser::Unwrap< - Fortran::parser::StructureComponent>(designator)) { - fir::ExtendedValue compExv = - converter.genExprAddr(operandLocation, *expr, stmtCtx); - info.addr = fir::getBase(compExv); - info.rawInput = info.addr; - if (fir::unwrapRefType(info.addr.getType()) - .isa()) - bounds = genBaseBoundsOps( - builder, operandLocation, converter, compExv, - /*isAssumedSize=*/false); - asFortran << (*expr).AsFortran(); - - bool isOptional = Fortran::semantics::IsOptional( - *Fortran::parser::GetLastName(*structComp).symbol); - if (isOptional) - info.isPresent = builder.create( - operandLocation, builder.getI1Type(), info.rawInput); - - if (auto loadOp = mlir::dyn_cast_or_null( - info.addr.getDefiningOp())) { - if (fir::isAllocatableType(loadOp.getType()) || - fir::isPointerType(loadOp.getType())) - info.addr = builder.create(operandLocation, - info.addr); - info.rawInput = info.addr; - } - - // If the component is an allocatable or pointer the result of - // genExprAddr will be the result of a fir.box_addr operation or - // a fir.box_addr has been inserted just before. - // Retrieve the box so we handle it like other descriptor. - if (auto boxAddrOp = mlir::dyn_cast_or_null( - info.addr.getDefiningOp())) { - info.addr = boxAddrOp.getVal(); - info.rawInput = info.addr; - bounds = genBoundsOpsFromBox( - builder, operandLocation, converter, compExv, info); - } - } else { - if (Fortran::parser::Unwrap( - designator)) { - // Single array element. - const auto *arrayElement = - Fortran::parser::Unwrap( - designator); - (void)arrayElement; - fir::ExtendedValue compExv = - converter.genExprAddr(operandLocation, *expr, stmtCtx); - info.addr = fir::getBase(compExv); - info.rawInput = info.addr; - asFortran << (*expr).AsFortran(); - } else if (const auto *dataRef{ - std::get_if( - &designator.u)}) { - // Scalar or full array. - const Fortran::parser::Name &name = - Fortran::parser::GetLastName(*dataRef); - fir::ExtendedValue dataExv = - converter.getSymbolExtendedValue(*name.symbol); - info = getDataOperandBaseAddr(converter, builder, - *name.symbol, operandLocation); - if (fir::unwrapRefType(info.addr.getType()) - .isa()) { - bounds = genBoundsOpsFromBox( - builder, operandLocation, converter, dataExv, info); - } - bool dataExvIsAssumedSize = - Fortran::semantics::IsAssumedSizeArray( - name.symbol->GetUltimate()); - if (fir::unwrapRefType(info.addr.getType()) - .isa()) - bounds = genBaseBoundsOps( - builder, operandLocation, converter, dataExv, - dataExvIsAssumedSize); - asFortran << name.ToString(); - } else { // Unsupported - llvm::report_fatal_error( - "Unsupported type of OpenACC operand"); - } - } - } - }, - [&](const Fortran::parser::Name &name) { - info = getDataOperandBaseAddr(converter, builder, *name.symbol, - operandLocation); - asFortran << name.ToString(); - }}, - object.u); + + if (!maybeDesignator) { + info = getDataOperandBaseAddr(converter, builder, symbol, operandLocation); + asFortran << symbol->name().ToString(); + return info; + } + + semantics::SomeExpr designator = *maybeDesignator; + + if ((designator.Rank() > 0 || treatIndexAsSection) && + IsArrayElement(designator)) { + auto arrayRef = detail::getRef(designator); + // This shouldn't fail after IsArrayElement(designator). + assert(arrayRef && "Expecting ArrayRef"); + + fir::ExtendedValue dataExv; + bool dataExvIsAssumedSize = false; + + auto toMaybeExpr = [&](auto &&base) { + using BaseType = llvm::remove_cvref_t; + evaluate::ExpressionAnalyzer ea{semaCtx}; + + if constexpr (std::is_same_v) { + if (auto *ref = base.UnwrapSymbolRef()) + return ea.Designate(evaluate::DataRef{*ref}); + if (auto *ref = base.UnwrapComponent()) + return ea.Designate(evaluate::DataRef{*ref}); + llvm_unreachable("Unexpected NamedEntity"); + } else { + static_assert(std::is_same_v); + return ea.Designate(evaluate::DataRef{base}); + } + }; + + auto arrayBase = toMaybeExpr(arrayRef->base()); + assert(arrayBase); + + if (detail::getRef(*arrayBase)) { + dataExv = converter.genExprAddr(operandLocation, *arrayBase, stmtCtx); + info.addr = fir::getBase(dataExv); + info.rawInput = info.addr; + asFortran << arrayBase->AsFortran(); + } else { + const semantics::Symbol &sym = arrayRef->GetLastSymbol(); + dataExvIsAssumedSize = + Fortran::semantics::IsAssumedSizeArray(sym.GetUltimate()); + info = getDataOperandBaseAddr(converter, builder, sym, operandLocation); + dataExv = converter.getSymbolExtendedValue(sym); + asFortran << sym.name().ToString(); + } + + if (!arrayRef->subscript().empty()) { + asFortran << '('; + bounds = genBoundsOps( + builder, operandLocation, converter, stmtCtx, arrayRef->subscript(), + asFortran, dataExv, dataExvIsAssumedSize, info, treatIndexAsSection); + } + asFortran << ')'; + } else if (auto compRef = detail::getRef(designator)) { + fir::ExtendedValue compExv = + converter.genExprAddr(operandLocation, designator, stmtCtx); + info.addr = fir::getBase(compExv); + info.rawInput = info.addr; + if (fir::unwrapRefType(info.addr.getType()).isa()) + bounds = genBaseBoundsOps(builder, operandLocation, + converter, compExv, + /*isAssumedSize=*/false); + asFortran << designator.AsFortran(); + + if (semantics::IsOptional(compRef->GetLastSymbol())) { + info.isPresent = builder.create( + operandLocation, builder.getI1Type(), info.rawInput); + } + + if (auto loadOp = + mlir::dyn_cast_or_null(info.addr.getDefiningOp())) { + if (fir::isAllocatableType(loadOp.getType()) || + fir::isPointerType(loadOp.getType())) + info.addr = builder.create(operandLocation, info.addr); + info.rawInput = info.addr; + } + + // If the component is an allocatable or pointer the result of + // genExprAddr will be the result of a fir.box_addr operation or + // a fir.box_addr has been inserted just before. + // Retrieve the box so we handle it like other descriptor. + if (auto boxAddrOp = + mlir::dyn_cast_or_null(info.addr.getDefiningOp())) { + info.addr = boxAddrOp.getVal(); + info.rawInput = info.addr; + bounds = genBoundsOpsFromBox( + builder, operandLocation, converter, compExv, info); + } + } else { + if (detail::getRef(designator)) { + fir::ExtendedValue compExv = + converter.genExprAddr(operandLocation, designator, stmtCtx); + info.addr = fir::getBase(compExv); + info.rawInput = info.addr; + asFortran << designator.AsFortran(); + } else if (auto symRef = detail::getRef(designator)) { + // Scalar or full array. + fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*symRef); + info = + getDataOperandBaseAddr(converter, builder, *symRef, operandLocation); + if (fir::unwrapRefType(info.addr.getType()).isa()) { + bounds = genBoundsOpsFromBox( + builder, operandLocation, converter, dataExv, info); + } + bool dataExvIsAssumedSize = + Fortran::semantics::IsAssumedSizeArray(symRef->get().GetUltimate()); + if (fir::unwrapRefType(info.addr.getType()).isa()) + bounds = genBaseBoundsOps( + builder, operandLocation, converter, dataExv, dataExvIsAssumedSize); + asFortran << symRef->get().name().ToString(); + } else { // Unsupported + llvm::report_fatal_error("Unsupported type of OpenACC operand"); + } + } + return info; } - } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 6539de4d88304c..7b7e4a875cd8e8 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -269,6 +269,11 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) { Fortran::parser::GetLastName(arrayElement->base); return *name.symbol; } + if (const auto *component = + Fortran::parser::Unwrap( + *designator)) { + return *component->component.symbol; + } } else if (const auto *name = std::get_if(&accObject.u)) { return *name->symbol; @@ -286,17 +291,20 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, mlir::acc::DataClause dataClause, bool structured, bool implicit, bool setDeclareAttr = false) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objectList.v) { llvm::SmallVector bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds, - /*treatIndexAsSection=*/true); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds, + /*treatIndexAsSection=*/true); // If the input value is optional and is not a descriptor, we use the // rawInput directly. @@ -321,16 +329,19 @@ static void genDeclareDataOperandOperations( llvm::SmallVectorImpl &dataOperands, mlir::acc::DataClause dataClause, bool structured, bool implicit) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objectList.v) { llvm::SmallVector bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds); EntryOp op = createDataEntryOp( builder, operandLocation, info.addr, asFortran, bounds, structured, implicit, dataClause, info.addr.getType()); @@ -339,8 +350,7 @@ static void genDeclareDataOperandOperations( if (mlir::isa(fir::unwrapRefType(info.addr.getType()))) { mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion()); modBuilder.setInsertionPointAfter(builder.getFunction()); - std::string prefix = - converter.mangleName(getSymbolFromAccObject(accObject)); + std::string prefix = converter.mangleName(symbol); createDeclareAllocFuncWithArg( modBuilder, builder, operandLocation, info.addr.getType(), prefix, asFortran, dataClause); @@ -770,16 +780,19 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList, llvm::SmallVectorImpl &dataOperands, llvm::SmallVector &privatizations) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objectList.v) { llvm::SmallVector bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds); RecipeOp recipe; mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType()); if constexpr (std::is_same_v) { @@ -1340,16 +1353,19 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList, const auto &op = std::get(objectList.t); mlir::acc::ReductionOperator mlirOp = getReductionOperator(op); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objects.v) { llvm::SmallVector bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds); mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType()); if (auto seqTy = mlir::dyn_cast(reductionTy)) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index ae798b5c0a58fa..95faa0767e3656 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -818,65 +818,61 @@ bool ClauseProcessor::processMap( llvm::SmallVectorImpl *mapSymbols) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause2( - [&](const ClauseTy::Map *mapClause, + return findRepeatableClause( + [&](const omp::clause::Map &clause, const Fortran::parser::CharBlock &source) { + using Map = omp::clause::Map; mlir::Location clauseLocation = converter.genLocation(source); - const auto &oMapType = - std::get>( - mapClause->v.t); + const auto &oMapType = std::get>(clause.t); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; // If the map type is specified, then process it else Tofrom is the // default. if (oMapType) { - const Fortran::parser::OmpMapType::Type &mapType = - std::get(oMapType->t); + const Map::MapType::Type &mapType = + std::get(oMapType->t); switch (mapType) { - case Fortran::parser::OmpMapType::Type::To: + case Map::MapType::Type::To: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; break; - case Fortran::parser::OmpMapType::Type::From: + case Map::MapType::Type::From: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; break; - case Fortran::parser::OmpMapType::Type::Tofrom: + case Map::MapType::Type::Tofrom: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; break; - case Fortran::parser::OmpMapType::Type::Alloc: - case Fortran::parser::OmpMapType::Type::Release: + case Map::MapType::Type::Alloc: + case Map::MapType::Type::Release: // alloc and release is the default map_type for the Target Data // Ops, i.e. if no bits for map_type is supplied then alloc/release // is implicitly assumed based on the target directive. Default // value for Target Data and Enter Data is alloc and for Exit Data // it is release. break; - case Fortran::parser::OmpMapType::Type::Delete: + case Map::MapType::Type::Delete: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; } - if (std::get>( - oMapType->t)) + if (std::get>(oMapType->t)) mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; } else { mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; } - for (const Fortran::parser::OmpObject &ompObject : - std::get(mapClause->v.t).v) { + for (const omp::Object &object : std::get(clause.t)) { llvm::SmallVector bounds; std::stringstream asFortran; Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::OmpObject, mlir::omp::MapBoundsOp, - mlir::omp::MapBoundsType>( - converter, firOpBuilder, semaCtx, stmtCtx, ompObject, - clauseLocation, asFortran, bounds, treatIndexAsSection); + mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>( + converter, firOpBuilder, semaCtx, stmtCtx, *object.id(), + object.ref(), clauseLocation, asFortran, bounds, + treatIndexAsSection); - auto origSymbol = - converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); + auto origSymbol = converter.getSymbolAddress(*object.id()); mlir::Value symAddr = info.addr; if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) symAddr = origSymbol; @@ -899,7 +895,7 @@ bool ClauseProcessor::processMap( mapSymLocs->push_back(symAddr.getLoc()); if (mapSymbols) - mapSymbols->push_back(getOmpObjectSymbol(ompObject)); + mapSymbols->push_back(object.id()); } }); } diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index c2db0cfc3cb7bd..ffa8a5e0559385 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -162,9 +162,6 @@ class ClauseProcessor { /// Utility to find a clause within a range in the clause list. template static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end); - template - static ClauseIterator2 findClause2(ClauseIterator2 begin, - ClauseIterator2 end); /// Return the first instance of the given clause found in the clause list or /// `nullptr` if not present. If more than one instance is expected, use @@ -179,10 +176,6 @@ class ClauseProcessor { bool findRepeatableClause( std::function callbackFn) const; - template - bool findRepeatableClause2( - std::function - callbackFn) const; /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. template @@ -198,32 +191,31 @@ template bool ClauseProcessor::processMotionClauses( Fortran::lower::StatementContext &stmtCtx, llvm::SmallVectorImpl &mapOperands) { - return findRepeatableClause2( - [&](const T *motionClause, const Fortran::parser::CharBlock &source) { + return findRepeatableClause( + [&](const T &clause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - static_assert(std::is_same_v || - std::is_same_v); + static_assert(std::is_same_v || + std::is_same_v); // TODO Support motion modifiers: present, mapper, iterator. constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - std::is_same_v + std::is_same_v ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) { + for (const omp::Object &object : clause.v) { llvm::SmallVector bounds; std::stringstream asFortran; Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::OmpObject, mlir::omp::MapBoundsOp, - mlir::omp::MapBoundsType>( - converter, firOpBuilder, semaCtx, stmtCtx, ompObject, - clauseLocation, asFortran, bounds, treatIndexAsSection); + mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>( + converter, firOpBuilder, semaCtx, stmtCtx, *object.id(), + object.ref(), clauseLocation, asFortran, bounds, + treatIndexAsSection); - auto origSymbol = - converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); + auto origSymbol = converter.getSymbolAddress(*object.id()); mlir::Value symAddr = info.addr; if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) symAddr = origSymbol; @@ -273,17 +265,6 @@ ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) { return end; } -template -ClauseProcessor::ClauseIterator2 -ClauseProcessor::findClause2(ClauseIterator2 begin, ClauseIterator2 end) { - for (ClauseIterator2 it = begin; it != end; ++it) { - if (std::get_if(&it->u)) - return it; - } - - return end; -} - template const T *ClauseProcessor::findUniqueClause( const Fortran::parser::CharBlock **source) const { @@ -314,24 +295,6 @@ bool ClauseProcessor::findRepeatableClause( return found; } -template -bool ClauseProcessor::findRepeatableClause2( - std::function - callbackFn) const { - bool found = false; - ClauseIterator2 nextIt, endIt = clauses2.v.end(); - for (ClauseIterator2 it = clauses2.v.begin(); it != endIt; it = nextIt) { - nextIt = findClause2(it, endIt); - - if (nextIt != endIt) { - callbackFn(&std::get(nextIt->u), nextIt->source); - found = true; - ++nextIt; - } - } - return found; -} - template bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const { if (findUniqueClause()) { diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 5c4caa1de57382..d335129565b4e2 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -930,11 +930,8 @@ static OpTy genTargetEnterExitDataUpdateOp( cp.processNowait(nowaitAttr); if constexpr (std::is_same_v) { - cp.processMotionClauses(stmtCtx, - mapOperands); - cp.processMotionClauses(stmtCtx, - mapOperands); - + cp.processMotionClauses(stmtCtx, mapOperands); + cp.processMotionClauses(stmtCtx, mapOperands); } else { cp.processMap(currentLocation, directive, stmtCtx, mapOperands); } diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90 index df97cbcd187d2b..c275d4f1b1d5f2 100644 --- a/flang/test/Lower/OpenACC/acc-bounds.f90 +++ b/flang/test/Lower/OpenACC/acc-bounds.f90 @@ -184,7 +184,7 @@ subroutine acc_optional_data3(a, n) ! CHECK: fir.result %c0{{.*}} : index ! CHECK: } ! CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[STRIDE]] : index) startIdx(%c1 : index) {strideInBytes = true} -! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref>) bounds(%14) -> !fir.ref> {name = "a(1:n)"} +! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref>) bounds(%[[BOUNDS]]) -> !fir.ref> {name = "a(1:n)"} ! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref>) { end module diff --git a/flang/test/Lower/OpenACC/acc-enter-data.f90 b/flang/test/Lower/OpenACC/acc-enter-data.f90 index 2cf50c1b62f190..251edbf9c2dd0a 100644 --- a/flang/test/Lower/OpenACC/acc-enter-data.f90 +++ b/flang/test/Lower/OpenACC/acc-enter-data.f90 @@ -234,11 +234,13 @@ subroutine acc_enter_data_dummy(a, b, n, m) !$acc enter data create(b(n:m)) !CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[DECLB]]#0, %c0{{.*}} : (!fir.box>, index) -> (index, index, index) !CHECK: %[[LOAD_N:.*]] = fir.load %[[DECLN]]#0 : !fir.ref -!CHECK: %[[N_CONV:.*]] = fir.convert %[[LOAD_N]] : (i32) -> index -!CHECK: %[[LB:.*]] = arith.subi %[[N_CONV]], %[[N_IDX]] : index +!CHECK: %[[N_CONV1:.*]] = fir.convert %[[LOAD_N]] : (i32) -> i64 +!CHECK: %[[N_CONV2:.*]] = fir.convert %[[N_CONV1]] : (i64) -> index +!CHECK: %[[LB:.*]] = arith.subi %[[N_CONV2]], %[[N_IDX]] : index !CHECK: %[[LOAD_M:.*]] = fir.load %[[DECLM]]#0 : !fir.ref -!CHECK: %[[M_CONV:.*]] = fir.convert %[[LOAD_M]] : (i32) -> index -!CHECK: %[[UB:.*]] = arith.subi %[[M_CONV]], %[[N_IDX]] : index +!CHECK: %[[M_CONV1:.*]] = fir.convert %[[LOAD_M]] : (i32) -> i64 +!CHECK: %[[M_CONV2:.*]] = fir.convert %[[M_CONV1]] : (i64) -> index +!CHECK: %[[UB:.*]] = arith.subi %[[M_CONV2]], %[[N_IDX]] : index !CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[EXT_B]] : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[N_IDX]] : index) {strideInBytes = true} !CHECK: %[[ADDR:.*]] = fir.box_addr %[[DECLB]]#0 : (!fir.box>) -> !fir.ref> !CHECK: %[[CREATE1:.*]] = acc.create varPtr(%[[ADDR]] : !fir.ref>) bounds(%[[BOUND1]]) -> !fir.ref> {name = "b(n:m)", structured = false} @@ -248,8 +250,9 @@ subroutine acc_enter_data_dummy(a, b, n, m) !CHECK: %[[ONE:.*]] = arith.constant 1 : index !CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[DECLB]]#0, %c0_8 : (!fir.box>, index) -> (index, index, index) !CHECK: %[[LOAD_N:.*]] = fir.load %[[DECLN]]#0 : !fir.ref -!CHECK: %[[CONVERT_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> index -!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT_N]], %[[N_IDX]] : index +!CHECK: %[[CONVERT1_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> i64 +!CHECK: %[[CONVERT2_N:.*]] = fir.convert %[[CONVERT1_N]] : (i64) -> index +!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT2_N]], %[[N_IDX]] : index !CHECK: %[[UB:.*]] = arith.subi %[[EXT_B]], %c1{{.*}} : index !CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[EXT_B]] : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[N_IDX]] : index) {strideInBytes = true} !CHECK: %[[ADDR:.*]] = fir.box_addr %[[DECLB]]#0 : (!fir.box>) -> !fir.ref> @@ -424,8 +427,9 @@ subroutine acc_enter_data_assumed(a, b, n, m) !CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[DECLA]]#0, %[[C0]] : (!fir.box>, index) -> (index, index, index) !CHECK: %[[LOAD_N:.*]] = fir.load %[[DECLN]]#0 : !fir.ref -!CHECK: %[[CONVERT_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> index -!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT_N]], %[[ONE]] : index +!CHECK: %[[CONVERT1_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> i64 +!CHECK: %[[CONVERT2_N:.*]] = fir.convert %[[CONVERT1_N]] : (i64) -> index +!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT2_N]], %[[ONE]] : index !CHECK: %[[C0:.*]] = arith.constant 0 : index !CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[DECLA]]#1, %[[C0]] : (!fir.box>, index) -> (index, index, index) @@ -444,8 +448,9 @@ subroutine acc_enter_data_assumed(a, b, n, m) !CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[DECLA]]#0, %[[C0]] : (!fir.box>, index) -> (index, index, index) !CHECK: %[[LOAD_M:.*]] = fir.load %[[DECLM]]#0 : !fir.ref -!CHECK: %[[CONVERT_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> index -!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT_M]], %[[ONE]] : index +!CHECK: %[[CONVERT1_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> i64 +!CHECK: %[[CONVERT2_M:.*]] = fir.convert %[[CONVERT1_M]] : (i64) -> index +!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT2_M]], %[[ONE]] : index !CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[DECLA]]#1, %{{.*}} : (!fir.box>, index) -> (index, index, index) !CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[BASELB]] : index) upperbound(%[[UB]] : index) extent(%[[DIMS1]]#1 : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[ONE]] : index) {strideInBytes = true} @@ -460,12 +465,14 @@ subroutine acc_enter_data_assumed(a, b, n, m) !CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[DECLA]]#0, %[[C0]] : (!fir.box>, index) -> (index, index, index) !CHECK: %[[LOAD_N:.*]] = fir.load %[[DECLN]]#0 : !fir.ref -!CHECK: %[[CONVERT_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> index -!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT_N]], %[[ONE]] : index +!CHECK: %[[CONVERT1_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> i64 +!CHECK: %[[CONVERT2_N:.*]] = fir.convert %[[CONVERT1_N]] : (i64) -> index +!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT2_N]], %[[ONE]] : index !CHECK: %[[LOAD_M:.*]] = fir.load %[[DECLM]]#0 : !fir.ref -!CHECK: %[[CONVERT_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> index -!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT_M]], %[[ONE]] : index +!CHECK: %[[CONVERT1_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> i64 +!CHECK: %[[CONVERT2_M:.*]] = fir.convert %[[CONVERT1_M]] : (i64) -> index +!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT2_M]], %[[ONE]] : index !CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[DECLA]]#1, %{{.*}} : (!fir.box>, index) -> (index, index, index) !CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[DIMS1]]#1 : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[ONE]] : index) {strideInBytes = true} @@ -480,8 +487,9 @@ subroutine acc_enter_data_assumed(a, b, n, m) !CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[DECLB]]#0, %[[C0]] : (!fir.box>, index) -> (index, index, index) !CHECK: %[[LOAD_M:.*]] = fir.load %[[DECLM]]#0 : !fir.ref -!CHECK: %[[CONVERT_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> index -!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT_M]], %[[LB_C10_IDX]] : index +!CHECK: %[[CONVERT1_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> i64 +!CHECK: %[[CONVERT2_M:.*]] = fir.convert %[[CONVERT1_M]] : (i64) -> index +!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT2_M]], %[[LB_C10_IDX]] : index !CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[DECLB]]#1, %{{.*}} : (!fir.box>, index) -> (index, index, index) !CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[ZERO]] : index) upperbound(%[[UB]] : index) extent(%[[DIMS1]]#1 : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[LB_C10_IDX]] : index) {strideInBytes = true}