@@ -3565,72 +3565,73 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
35653565 TODO (converter.getCurrentLocation (), " OmpDeclareVariantDirective" );
35663566}
35673567
3568- static bool
3568+ static ReductionProcessor::GenCombinerCBTy
35693569processReductionCombiner (lower::AbstractConverter &converter,
35703570 lower::SymMap &symTable,
35713571 semantics::SemanticsContext &semaCtx,
3572- const parser::OmpReductionSpecifier &specifier,
3573- ReductionProcessor::GenCombinerCBTy & genCombinerCB) {
3572+ const parser::OmpReductionSpecifier &specifier) {
3573+ ReductionProcessor::GenCombinerCBTy genCombinerCB;
35743574 const auto &combinerExpression =
35753575 std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t )
35763576 .value ();
35773577 const parser::OmpStylizedInstance &combinerInstance =
35783578 combinerExpression.v .front ();
35793579 const parser::OmpStylizedInstance::Instance &instance =
35803580 std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t );
3581- if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u )) {
3582- auto &expr = std::get<parser::Expr>(as->t );
3583- genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
3584- mlir::Type type, mlir::Value lhs, mlir::Value rhs,
3585- bool isByRef) {
3586- const auto &evalExpr = makeExpr (expr, semaCtx);
3587- lower::SymMapScope scope (symTable);
3588- const std::list<parser::OmpStylizedDeclaration> &declList =
3589- std::get<std::list<parser::OmpStylizedDeclaration>>(
3590- combinerInstance.t );
3591- for (const parser::OmpStylizedDeclaration &decl : declList) {
3592- auto &name = std::get<parser::ObjectName>(decl.var .t );
3593- mlir::Value addr = lhs;
3594- mlir::Type type = lhs.getType ();
3595- bool isRhs = name.ToString () == std::string (" omp_in" );
3596- if (isRhs) {
3597- addr = rhs;
3598- type = rhs.getType ();
3599- }
36003581
3601- assert (name.symbol && " Reduction object name does not have a symbol" );
3602- if (!fir::conformsWithPassByRef (type)) {
3603- addr = builder.createTemporary (loc, type);
3604- fir::StoreOp::create (builder, loc, isRhs ? rhs : lhs, addr);
3605- }
3606- fir::FortranVariableFlagsEnum extraFlags = {};
3607- fir::FortranVariableFlagsAttr attributes =
3608- Fortran::lower::translateSymbolAttributes (builder.getContext (),
3609- *name.symbol , extraFlags);
3610- auto declareOp = hlfir::DeclareOp::create (
3611- builder, loc, addr, name.ToString (), nullptr , {}, nullptr , nullptr ,
3612- 0 , attributes);
3613- symTable.addVariableDefinition (*name.symbol , declareOp);
3582+ const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u );
3583+ if (!as) {
3584+ TODO (converter.getCurrentLocation (),
3585+ " A combiner that is a subroutine call is not yet supported" );
3586+ }
3587+ auto &expr = std::get<parser::Expr>(as->t );
3588+ genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
3589+ mlir::Type type, mlir::Value lhs, mlir::Value rhs,
3590+ bool isByRef) {
3591+ const auto &evalExpr = makeExpr (expr, semaCtx);
3592+ lower::SymMapScope scope (symTable);
3593+ const std::list<parser::OmpStylizedDeclaration> &declList =
3594+ std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t );
3595+ for (const parser::OmpStylizedDeclaration &decl : declList) {
3596+ auto &name = std::get<parser::ObjectName>(decl.var .t );
3597+ mlir::Value addr = lhs;
3598+ mlir::Type type = lhs.getType ();
3599+ bool isRhs = name.ToString () == std::string (" omp_in" );
3600+ if (isRhs) {
3601+ addr = rhs;
3602+ type = rhs.getType ();
36143603 }
36153604
3616- lower::StatementContext stmtCtx;
3617- mlir::Value result = fir::getBase (
3618- convertExprToValue (loc, converter, evalExpr, symTable, stmtCtx));
3619- if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType ()))
3620- if (lhs.getType () == refType.getElementType ())
3621- result = fir::LoadOp::create (builder, loc, result);
3622- stmtCtx.finalizeAndPop ();
3623- if (isByRef) {
3624- fir::StoreOp::create (builder, loc, result, lhs);
3625- mlir::omp::YieldOp::create (builder, loc, lhs);
3626- } else {
3627- mlir::omp::YieldOp::create (builder, loc, result);
3605+ assert (name.symbol && " Reduction object name does not have a symbol" );
3606+ if (!fir::conformsWithPassByRef (type)) {
3607+ addr = builder.createTemporary (loc, type);
3608+ fir::StoreOp::create (builder, loc, isRhs ? rhs : lhs, addr);
36283609 }
3610+ fir::FortranVariableFlagsEnum extraFlags = {};
3611+ fir::FortranVariableFlagsAttr attributes =
3612+ Fortran::lower::translateSymbolAttributes (builder.getContext (),
3613+ *name.symbol , extraFlags);
3614+ auto declareOp =
3615+ hlfir::DeclareOp::create (builder, loc, addr, name.ToString (), nullptr ,
3616+ {}, nullptr , nullptr , 0 , attributes);
3617+ symTable.addVariableDefinition (*name.symbol , declareOp);
3618+ }
36293619
3630- return result;
3631- };
3632- }
3633- return true ;
3620+ lower::StatementContext stmtCtx;
3621+ mlir::Value result = fir::getBase (
3622+ convertExprToValue (loc, converter, evalExpr, symTable, stmtCtx));
3623+ if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType ()))
3624+ if (lhs.getType () == refType.getElementType ())
3625+ result = fir::LoadOp::create (builder, loc, result);
3626+ stmtCtx.finalizeAndPop ();
3627+ if (isByRef) {
3628+ fir::StoreOp::create (builder, loc, result, lhs);
3629+ mlir::omp::YieldOp::create (builder, loc, lhs);
3630+ } else {
3631+ mlir::omp::YieldOp::create (builder, loc, result);
3632+ }
3633+ };
3634+ return genCombinerCB;
36343635}
36353636
36363637// Getting the type from a symbol compared to a DeclSpec is simpler since we do
@@ -3657,45 +3658,43 @@ static void genOMP(
36573658 lower::AbstractConverter &converter, lower::SymMap &symTable,
36583659 semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
36593660 const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
3660- if (!semaCtx.langOptions ().OpenMPSimd ) {
3661- const parser::OmpArgumentList &args{
3662- declareReductionConstruct.v .Arguments ()};
3663- const parser::OmpArgument &arg{args.v .front ()};
3664- const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u );
3665-
3666- if (std::get<parser::OmpTypeNameList>(specifier.t ).v .size () > 1 )
3667- TODO (converter.getCurrentLocation (),
3668- " multiple types in declare reduction is not yet supported" );
3669-
3670- mlir::Type reductionType = getReductionType (converter, specifier);
3671- ReductionProcessor::GenCombinerCBTy genCombinerCB;
3672- processReductionCombiner (converter, symTable, semaCtx, specifier,
3673- genCombinerCB);
3674- const parser::OmpClauseList &initializer =
3675- declareReductionConstruct.v .Clauses ();
3676- if (initializer.v .size () > 0 ) {
3677- List<Clause> clauses = makeClauses (initializer, semaCtx);
3678- ReductionProcessor::GenInitValueCBTy genInitValueCB;
3679- ClauseProcessor cp (converter, semaCtx, clauses);
3680- const parser::OmpClause::Initializer &iclause{
3681- std::get<parser::OmpClause::Initializer>(initializer.v .front ().u )};
3682- cp.processInitializer (symTable, iclause, genInitValueCB);
3683- const auto &identifier =
3684- std::get<parser::OmpReductionIdentifier>(specifier.t );
3685- const auto &designator =
3686- std::get<parser::ProcedureDesignator>(identifier.u );
3687- const auto &reductionName = std::get<parser::Name>(designator.u );
3688- bool isByRef = ReductionProcessor::doReductionByRef (reductionType);
3689- ReductionProcessor::createDeclareReductionHelper<
3690- mlir::omp::DeclareReductionOp>(
3691- converter, reductionName.ToString (), reductionType,
3692- converter.getCurrentLocation (), isByRef, genCombinerCB,
3693- genInitValueCB);
3694- } else {
3695- TODO (converter.getCurrentLocation (),
3696- " declare reduction without an initializer clause is not yet "
3697- " supported" );
3698- }
3661+ if (semaCtx.langOptions ().OpenMPSimd )
3662+ return ;
3663+
3664+ const parser::OmpArgumentList &args{declareReductionConstruct.v .Arguments ()};
3665+ const parser::OmpArgument &arg{args.v .front ()};
3666+ const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u );
3667+
3668+ if (std::get<parser::OmpTypeNameList>(specifier.t ).v .size () > 1 )
3669+ TODO (converter.getCurrentLocation (),
3670+ " multiple types in declare reduction is not yet supported" );
3671+
3672+ mlir::Type reductionType = getReductionType (converter, specifier);
3673+ ReductionProcessor::GenCombinerCBTy genCombinerCB =
3674+ processReductionCombiner (converter, symTable, semaCtx, specifier);
3675+ const parser::OmpClauseList &initializer =
3676+ declareReductionConstruct.v .Clauses ();
3677+ if (initializer.v .size () > 0 ) {
3678+ List<Clause> clauses = makeClauses (initializer, semaCtx);
3679+ ReductionProcessor::GenInitValueCBTy genInitValueCB;
3680+ ClauseProcessor cp (converter, semaCtx, clauses);
3681+ const parser::OmpClause::Initializer &iclause{
3682+ std::get<parser::OmpClause::Initializer>(initializer.v .front ().u )};
3683+ cp.processInitializer (symTable, iclause, genInitValueCB);
3684+ const auto &identifier =
3685+ std::get<parser::OmpReductionIdentifier>(specifier.t );
3686+ const auto &designator =
3687+ std::get<parser::ProcedureDesignator>(identifier.u );
3688+ const auto &reductionName = std::get<parser::Name>(designator.u );
3689+ bool isByRef = ReductionProcessor::doReductionByRef (reductionType);
3690+ ReductionProcessor::createDeclareReductionHelper<
3691+ mlir::omp::DeclareReductionOp>(
3692+ converter, reductionName.ToString (), reductionType,
3693+ converter.getCurrentLocation (), isByRef, genCombinerCB, genInitValueCB);
3694+ } else {
3695+ TODO (converter.getCurrentLocation (),
3696+ " declare reduction without an initializer clause is not yet "
3697+ " supported" );
36993698 }
37003699}
37013700
0 commit comments