Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][Lower] Convert OMP Map and related functions to evaluate::Expr #81626

Merged
merged 7 commits into from
Mar 20, 2024
8 changes: 8 additions & 0 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ inline Expr<SomeType> AsGenericExpr(Expr<SomeType> &&x) { return std::move(x); }
std::optional<Expr<SomeType>> AsGenericExpr(DataRef &&);
std::optional<Expr<SomeType>> AsGenericExpr(const Symbol &);

// Propagate std::optional from input to output.
template <typename A>
std::optional<Expr<SomeType>> AsGenericExpr(std::optional<A> &&x) {
if (!x)
return std::nullopt;
return AsGenericExpr(std::move(*x));
}

template <typename A>
common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(
A &&x) {
Expand Down
389 changes: 234 additions & 155 deletions flang/lib/Lower/DirectivesCommon.h

Large diffs are not rendered by default.

54 changes: 35 additions & 19 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
Fortran::parser::GetLastName(arrayElement->base);
return *name.symbol;
}
if (const auto *component =
Fortran::parser::Unwrap<Fortran::parser::StructureComponent>(
*designator)) {
return *component->component.symbol;
}
} else if (const auto *name =
std::get_if<Fortran::parser::Name>(&accObject.u)) {
return *name->symbol;
Expand All @@ -286,17 +291,20 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
mlir::acc::DataClause dataClause, bool structured,
bool implicit, bool setDeclareAttr = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds,
/*treatIndexAsSection=*/true);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds,
/*treatIndexAsSection=*/true);

// If the input value is optional and is not a descriptor, we use the
// rawInput directly.
Expand All @@ -321,16 +329,19 @@ static void genDeclareDataOperandOperations(
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
mlir::acc::DataClause dataClause, bool structured, bool implicit) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
EntryOp op = createDataEntryOp<EntryOp>(
builder, operandLocation, info.addr, asFortran, bounds, structured,
implicit, dataClause, info.addr.getType());
Expand All @@ -339,8 +350,7 @@ static void genDeclareDataOperandOperations(
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
modBuilder.setInsertionPointAfter(builder.getFunction());
std::string prefix =
converter.mangleName(getSymbolFromAccObject(accObject));
std::string prefix = converter.mangleName(symbol);
createDeclareAllocFuncWithArg<EntryOp>(
modBuilder, builder, operandLocation, info.addr.getType(), prefix,
asFortran, dataClause);
Expand Down Expand Up @@ -770,16 +780,19 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
llvm::SmallVector<mlir::Attribute> &privatizations) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
RecipeOp recipe;
mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
Expand Down Expand Up @@ -1340,16 +1353,19 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
const auto &op =
std::get<Fortran::parser::AccReductionOperator>(objectList.t);
mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objects.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);

mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
Expand Down
44 changes: 20 additions & 24 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,65 +818,61 @@ bool ClauseProcessor::processMap(
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause2<ClauseTy::Map>(
[&](const ClauseTy::Map *mapClause,
return findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause,
const Fortran::parser::CharBlock &source) {
using Map = omp::clause::Map;
mlir::Location clauseLocation = converter.genLocation(source);
const auto &oMapType =
std::get<std::optional<Fortran::parser::OmpMapType>>(
mapClause->v.t);
const auto &oMapType = std::get<std::optional<Map::MapType>>(clause.t);
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
// If the map type is specified, then process it else Tofrom is the
// default.
if (oMapType) {
const Fortran::parser::OmpMapType::Type &mapType =
std::get<Fortran::parser::OmpMapType::Type>(oMapType->t);
const Map::MapType::Type &mapType =
std::get<Map::MapType::Type>(oMapType->t);
switch (mapType) {
case Fortran::parser::OmpMapType::Type::To:
case Map::MapType::Type::To:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
break;
case Fortran::parser::OmpMapType::Type::From:
case Map::MapType::Type::From:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
break;
case Fortran::parser::OmpMapType::Type::Tofrom:
case Map::MapType::Type::Tofrom:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
break;
case Fortran::parser::OmpMapType::Type::Alloc:
case Fortran::parser::OmpMapType::Type::Release:
case Map::MapType::Type::Alloc:
case Map::MapType::Type::Release:
// alloc and release is the default map_type for the Target Data
// Ops, i.e. if no bits for map_type is supplied then alloc/release
// is implicitly assumed based on the target directive. Default
// value for Target Data and Enter Data is alloc and for Exit Data
// it is release.
break;
case Fortran::parser::OmpMapType::Type::Delete:
case Map::MapType::Type::Delete:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
}

if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
oMapType->t))
if (std::get<std::optional<Map::MapType::Always>>(oMapType->t))
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
} else {
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
}

for (const Fortran::parser::OmpObject &ompObject :
std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) {
for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;

Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::OmpObject, mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
clauseLocation, asFortran, bounds, treatIndexAsSection);
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
object.ref(), clauseLocation, asFortran, bounds,
treatIndexAsSection);

auto origSymbol =
converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
auto origSymbol = converter.getSymbolAddress(*object.id());
mlir::Value symAddr = info.addr;
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
symAddr = origSymbol;
Expand All @@ -899,7 +895,7 @@ bool ClauseProcessor::processMap(
mapSymLocs->push_back(symAddr.getLoc());

if (mapSymbols)
mapSymbols->push_back(getOmpObjectSymbol(ompObject));
mapSymbols->push_back(object.id());
}
});
}
Expand Down
59 changes: 11 additions & 48 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ class ClauseProcessor {
/// Utility to find a clause within a range in the clause list.
template <typename T>
static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
template <typename T>
static ClauseIterator2 findClause2(ClauseIterator2 begin,
ClauseIterator2 end);

/// Return the first instance of the given clause found in the clause list or
/// `nullptr` if not present. If more than one instance is expected, use
Expand All @@ -179,10 +176,6 @@ class ClauseProcessor {
bool findRepeatableClause(
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
callbackFn) const;
template <typename T>
bool findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const;

/// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
template <typename T>
Expand All @@ -198,32 +191,31 @@ template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
return findRepeatableClause2<T>(
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
return findRepeatableClause<T>(
[&](const T &clause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> ||
std::is_same_v<T, ClauseProcessor::ClauseTy::From>);
static_assert(std::is_same_v<T, omp::clause::To> ||
std::is_same_v<T, omp::clause::From>);

// TODO Support motion modifiers: present, mapper, iterator.
constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
std::is_same_v<T, ClauseProcessor::ClauseTy::To>
std::is_same_v<T, omp::clause::To>
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;

for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
for (const omp::Object &object : clause.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::OmpObject, mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
clauseLocation, asFortran, bounds, treatIndexAsSection);
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
object.ref(), clauseLocation, asFortran, bounds,
treatIndexAsSection);

auto origSymbol =
converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
auto origSymbol = converter.getSymbolAddress(*object.id());
mlir::Value symAddr = info.addr;
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
symAddr = origSymbol;
Expand Down Expand Up @@ -273,17 +265,6 @@ ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
return end;
}

template <typename T>
ClauseProcessor::ClauseIterator2
ClauseProcessor::findClause2(ClauseIterator2 begin, ClauseIterator2 end) {
for (ClauseIterator2 it = begin; it != end; ++it) {
if (std::get_if<T>(&it->u))
return it;
}

return end;
}

template <typename T>
const T *ClauseProcessor::findUniqueClause(
const Fortran::parser::CharBlock **source) const {
Expand Down Expand Up @@ -314,24 +295,6 @@ bool ClauseProcessor::findRepeatableClause(
return found;
}

template <typename T>
bool ClauseProcessor::findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
ClauseIterator2 nextIt, endIt = clauses2.v.end();
for (ClauseIterator2 it = clauses2.v.begin(); it != endIt; it = nextIt) {
nextIt = findClause2<T>(it, endIt);

if (nextIt != endIt) {
callbackFn(&std::get<T>(nextIt->u), nextIt->source);
found = true;
++nextIt;
}
}
return found;
}

template <typename T>
bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
if (findUniqueClause<T>()) {
Expand Down
7 changes: 2 additions & 5 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,11 +930,8 @@ static OpTy genTargetEnterExitDataUpdateOp(
cp.processNowait(nowaitAttr);

if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) {
cp.processMotionClauses<Fortran::parser::OmpClause::To>(stmtCtx,
mapOperands);
cp.processMotionClauses<Fortran::parser::OmpClause::From>(stmtCtx,
mapOperands);

cp.processMotionClauses<clause::To>(stmtCtx, mapOperands);
cp.processMotionClauses<clause::From>(stmtCtx, mapOperands);
} else {
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
}
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/OpenACC/acc-bounds.f90
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ subroutine acc_optional_data3(a, n)
! CHECK: fir.result %c0{{.*}} : index
! CHECK: }
! CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[STRIDE]] : index) startIdx(%c1 : index) {strideInBytes = true}
! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%14) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref<!fir.array<?xf32>>) {

end module
Loading
Loading