Skip to content

Commit 50c7cf9

Browse files
committed
Clean up code and add TODO to address review comments.
1 parent be3bb13 commit 50c7cf9

File tree

1 file changed

+89
-90
lines changed

1 file changed

+89
-90
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 89 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3565,72 +3565,73 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
35653565
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
35663566
}
35673567

3568-
static bool
3568+
static ReductionProcessor::GenCombinerCBTy
35693569
processReductionCombiner(lower::AbstractConverter &converter,
35703570
lower::SymMap &symTable,
35713571
semantics::SemanticsContext &semaCtx,
3572-
const parser::OmpReductionSpecifier &specifier,
3573-
ReductionProcessor::GenCombinerCBTy &genCombinerCB) {
3572+
const parser::OmpReductionSpecifier &specifier) {
3573+
ReductionProcessor::GenCombinerCBTy genCombinerCB;
35743574
const auto &combinerExpression =
35753575
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
35763576
.value();
35773577
const parser::OmpStylizedInstance &combinerInstance =
35783578
combinerExpression.v.front();
35793579
const parser::OmpStylizedInstance::Instance &instance =
35803580
std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
3581-
if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
3582-
auto &expr = std::get<parser::Expr>(as->t);
3583-
genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
3584-
mlir::Type type, mlir::Value lhs, mlir::Value rhs,
3585-
bool isByRef) {
3586-
const auto &evalExpr = makeExpr(expr, semaCtx);
3587-
lower::SymMapScope scope(symTable);
3588-
const std::list<parser::OmpStylizedDeclaration> &declList =
3589-
std::get<std::list<parser::OmpStylizedDeclaration>>(
3590-
combinerInstance.t);
3591-
for (const parser::OmpStylizedDeclaration &decl : declList) {
3592-
auto &name = std::get<parser::ObjectName>(decl.var.t);
3593-
mlir::Value addr = lhs;
3594-
mlir::Type type = lhs.getType();
3595-
bool isRhs = name.ToString() == std::string("omp_in");
3596-
if (isRhs) {
3597-
addr = rhs;
3598-
type = rhs.getType();
3599-
}
36003581

3601-
assert(name.symbol && "Reduction object name does not have a symbol");
3602-
if (!fir::conformsWithPassByRef(type)) {
3603-
addr = builder.createTemporary(loc, type);
3604-
fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
3605-
}
3606-
fir::FortranVariableFlagsEnum extraFlags = {};
3607-
fir::FortranVariableFlagsAttr attributes =
3608-
Fortran::lower::translateSymbolAttributes(builder.getContext(),
3609-
*name.symbol, extraFlags);
3610-
auto declareOp = hlfir::DeclareOp::create(
3611-
builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
3612-
0, attributes);
3613-
symTable.addVariableDefinition(*name.symbol, declareOp);
3582+
const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
3583+
if (!as) {
3584+
TODO(converter.getCurrentLocation(),
3585+
"A combiner that is a subroutine call is not yet supported");
3586+
}
3587+
auto &expr = std::get<parser::Expr>(as->t);
3588+
genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
3589+
mlir::Type type, mlir::Value lhs, mlir::Value rhs,
3590+
bool isByRef) {
3591+
const auto &evalExpr = makeExpr(expr, semaCtx);
3592+
lower::SymMapScope scope(symTable);
3593+
const std::list<parser::OmpStylizedDeclaration> &declList =
3594+
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
3595+
for (const parser::OmpStylizedDeclaration &decl : declList) {
3596+
auto &name = std::get<parser::ObjectName>(decl.var.t);
3597+
mlir::Value addr = lhs;
3598+
mlir::Type type = lhs.getType();
3599+
bool isRhs = name.ToString() == std::string("omp_in");
3600+
if (isRhs) {
3601+
addr = rhs;
3602+
type = rhs.getType();
36143603
}
36153604

3616-
lower::StatementContext stmtCtx;
3617-
mlir::Value result = fir::getBase(
3618-
convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
3619-
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
3620-
if (lhs.getType() == refType.getElementType())
3621-
result = fir::LoadOp::create(builder, loc, result);
3622-
stmtCtx.finalizeAndPop();
3623-
if (isByRef) {
3624-
fir::StoreOp::create(builder, loc, result, lhs);
3625-
mlir::omp::YieldOp::create(builder, loc, lhs);
3626-
} else {
3627-
mlir::omp::YieldOp::create(builder, loc, result);
3605+
assert(name.symbol && "Reduction object name does not have a symbol");
3606+
if (!fir::conformsWithPassByRef(type)) {
3607+
addr = builder.createTemporary(loc, type);
3608+
fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
36283609
}
3610+
fir::FortranVariableFlagsEnum extraFlags = {};
3611+
fir::FortranVariableFlagsAttr attributes =
3612+
Fortran::lower::translateSymbolAttributes(builder.getContext(),
3613+
*name.symbol, extraFlags);
3614+
auto declareOp =
3615+
hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
3616+
{}, nullptr, nullptr, 0, attributes);
3617+
symTable.addVariableDefinition(*name.symbol, declareOp);
3618+
}
36293619

3630-
return result;
3631-
};
3632-
}
3633-
return true;
3620+
lower::StatementContext stmtCtx;
3621+
mlir::Value result = fir::getBase(
3622+
convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
3623+
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
3624+
if (lhs.getType() == refType.getElementType())
3625+
result = fir::LoadOp::create(builder, loc, result);
3626+
stmtCtx.finalizeAndPop();
3627+
if (isByRef) {
3628+
fir::StoreOp::create(builder, loc, result, lhs);
3629+
mlir::omp::YieldOp::create(builder, loc, lhs);
3630+
} else {
3631+
mlir::omp::YieldOp::create(builder, loc, result);
3632+
}
3633+
};
3634+
return genCombinerCB;
36343635
}
36353636

36363637
// Getting the type from a symbol compared to a DeclSpec is simpler since we do
@@ -3657,45 +3658,43 @@ static void genOMP(
36573658
lower::AbstractConverter &converter, lower::SymMap &symTable,
36583659
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
36593660
const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
3660-
if (!semaCtx.langOptions().OpenMPSimd) {
3661-
const parser::OmpArgumentList &args{
3662-
declareReductionConstruct.v.Arguments()};
3663-
const parser::OmpArgument &arg{args.v.front()};
3664-
const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
3665-
3666-
if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
3667-
TODO(converter.getCurrentLocation(),
3668-
"multiple types in declare reduction is not yet supported");
3669-
3670-
mlir::Type reductionType = getReductionType(converter, specifier);
3671-
ReductionProcessor::GenCombinerCBTy genCombinerCB;
3672-
processReductionCombiner(converter, symTable, semaCtx, specifier,
3673-
genCombinerCB);
3674-
const parser::OmpClauseList &initializer =
3675-
declareReductionConstruct.v.Clauses();
3676-
if (initializer.v.size() > 0) {
3677-
List<Clause> clauses = makeClauses(initializer, semaCtx);
3678-
ReductionProcessor::GenInitValueCBTy genInitValueCB;
3679-
ClauseProcessor cp(converter, semaCtx, clauses);
3680-
const parser::OmpClause::Initializer &iclause{
3681-
std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
3682-
cp.processInitializer(symTable, iclause, genInitValueCB);
3683-
const auto &identifier =
3684-
std::get<parser::OmpReductionIdentifier>(specifier.t);
3685-
const auto &designator =
3686-
std::get<parser::ProcedureDesignator>(identifier.u);
3687-
const auto &reductionName = std::get<parser::Name>(designator.u);
3688-
bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
3689-
ReductionProcessor::createDeclareReductionHelper<
3690-
mlir::omp::DeclareReductionOp>(
3691-
converter, reductionName.ToString(), reductionType,
3692-
converter.getCurrentLocation(), isByRef, genCombinerCB,
3693-
genInitValueCB);
3694-
} else {
3695-
TODO(converter.getCurrentLocation(),
3696-
"declare reduction without an initializer clause is not yet "
3697-
"supported");
3698-
}
3661+
if (semaCtx.langOptions().OpenMPSimd)
3662+
return;
3663+
3664+
const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
3665+
const parser::OmpArgument &arg{args.v.front()};
3666+
const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
3667+
3668+
if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
3669+
TODO(converter.getCurrentLocation(),
3670+
"multiple types in declare reduction is not yet supported");
3671+
3672+
mlir::Type reductionType = getReductionType(converter, specifier);
3673+
ReductionProcessor::GenCombinerCBTy genCombinerCB =
3674+
processReductionCombiner(converter, symTable, semaCtx, specifier);
3675+
const parser::OmpClauseList &initializer =
3676+
declareReductionConstruct.v.Clauses();
3677+
if (initializer.v.size() > 0) {
3678+
List<Clause> clauses = makeClauses(initializer, semaCtx);
3679+
ReductionProcessor::GenInitValueCBTy genInitValueCB;
3680+
ClauseProcessor cp(converter, semaCtx, clauses);
3681+
const parser::OmpClause::Initializer &iclause{
3682+
std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
3683+
cp.processInitializer(symTable, iclause, genInitValueCB);
3684+
const auto &identifier =
3685+
std::get<parser::OmpReductionIdentifier>(specifier.t);
3686+
const auto &designator =
3687+
std::get<parser::ProcedureDesignator>(identifier.u);
3688+
const auto &reductionName = std::get<parser::Name>(designator.u);
3689+
bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
3690+
ReductionProcessor::createDeclareReductionHelper<
3691+
mlir::omp::DeclareReductionOp>(
3692+
converter, reductionName.ToString(), reductionType,
3693+
converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
3694+
} else {
3695+
TODO(converter.getCurrentLocation(),
3696+
"declare reduction without an initializer clause is not yet "
3697+
"supported");
36993698
}
37003699
}
37013700

0 commit comments

Comments
 (0)