diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index 6267231d7fbe2..f6a61ba3a528e 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -685,441 +685,276 @@ static void checkMapType(mlir::Location location, mlir::Type type) { TODO(location, "OMPD_target_data MapOperand BoxType"); } -class ReductionProcessor { -public: - enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR }; - static IntrinsicProc - getReductionType(const Fortran::parser::ProcedureDesignator &pd) { - auto redType = llvm::StringSwitch>( - getRealName(pd).ToString()) - .Case("max", IntrinsicProc::MAX) - .Case("min", IntrinsicProc::MIN) - .Case("iand", IntrinsicProc::IAND) - .Case("ior", IntrinsicProc::IOR) - .Case("ieor", IntrinsicProc::IEOR) - .Default(std::nullopt); - assert(redType && "Invalid Reduction"); - return *redType; - } - - static bool supportedIntrinsicProcReduction( - const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - auto redType = llvm::StringSwitch>( - getRealName(name).ToString()) - .Case("max", IntrinsicProc::MAX) - .Case("min", IntrinsicProc::MIN) - .Case("iand", IntrinsicProc::IAND) - .Case("ior", IntrinsicProc::IOR) - .Case("ieor", IntrinsicProc::IEOR) - .Default(std::nullopt); - if (redType) - return true; - return false; - } - - static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::Name *name) { - return name->symbol->GetUltimate().name(); - } - - static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - return getRealName(name); - } - - static std::string getReductionName(llvm::StringRef name, mlir::Type ty) { - return (llvm::Twine(name) + - (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + - llvm::Twine(ty.getIntOrFloatBitWidth())) - .str(); - } - - static std::string getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type ty) { - std::string reductionName; - - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - reductionName = "add_reduction"; - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - reductionName = "multiply_reduction"; - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - return "and_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - return "eqv_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - return "or_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: - return "neqv_reduction"; - default: - reductionName = "other_reduction"; - break; - } - - return getReductionName(reductionName, ty); - } +static std::string getReductionName(llvm::StringRef name, mlir::Type ty) { + return (llvm::Twine(name) + + (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + + llvm::Twine(ty.getIntOrFloatBitWidth())) + .str(); +} - /// This function returns the identity value of the operator \p - /// reductionOpName. For example: - /// 0 + x = x, - /// 1 * x = x - static int getOperationIdentity( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Location loc) { - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: - return 0; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - return 1; - default: - TODO(loc, "Reduction of some intrinsic operators is not supported"); - } - } +static std::string getReductionName( + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type ty) { + std::string reductionName; - static mlir::Value getIntrinsicProcInitValue( - mlir::Location loc, mlir::Type type, - const Fortran::parser::ProcedureDesignator &procDesignator, - fir::FirOpBuilder &builder) { - assert((fir::isa_integer(type) || fir::isa_real(type) || - type.isa()) && - "only integer, logical and real types are currently supported"); - switch (getReductionType(procDesignator)) { - case IntrinsicProc::MAX: { - if (auto ty = type.dyn_cast()) { - const llvm::fltSemantics &sem = ty.getFloatSemantics(); - return builder.createRealConstant( - loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); - } - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, minInt); - } - case IntrinsicProc::MIN: { - if (auto ty = type.dyn_cast()) { - const llvm::fltSemantics &sem = ty.getFloatSemantics(); - return builder.createRealConstant( - loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); - } - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, maxInt); - } - case IntrinsicProc::IOR: { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, zeroInt); - } - case IntrinsicProc::IEOR: { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, zeroInt); - } - case IntrinsicProc::IAND: { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, allOnInt); - } - } - llvm_unreachable("Unknown Reduction Intrinsic"); + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + reductionName = "add_reduction"; + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + reductionName = "multiply_reduction"; + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + return "and_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + return "eqv_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + return "or_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + return "neqv_reduction"; + default: + reductionName = "other_reduction"; + break; } - static mlir::Value getIntrinsicOpInitValue( - mlir::Location loc, mlir::Type type, - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - fir::FirOpBuilder &builder) { + return getReductionName(reductionName, ty); +} + +/// This function returns the identity value of the operator \p reductionOpName. +/// For example: +/// 0 + x = x, +/// 1 * x = x +static int getOperationIdentity(llvm::StringRef reductionOpName, + mlir::Location loc) { + if (reductionOpName.contains("add") || reductionOpName.contains("or") || + reductionOpName.contains("neqv")) + return 0; + if (reductionOpName.contains("multiply") || reductionOpName.contains("and") || + reductionOpName.contains("eqv")) + return 1; + TODO(loc, "Reduction of some intrinsic operators is not supported"); +} + +static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type, + llvm::StringRef reductionOpName, + fir::FirOpBuilder &builder) { + assert((fir::isa_integer(type) || fir::isa_real(type) || + type.isa()) && + "only integer, logical and real types are currently supported"); + if (reductionOpName.contains("max")) { + if (auto ty = type.dyn_cast()) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant( + loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); + } + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, minInt); + } else if (reductionOpName.contains("min")) { + if (auto ty = type.dyn_cast()) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant( + loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); + } + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, maxInt); + } else if (reductionOpName.contains("ior")) { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, zeroInt); + } else if (reductionOpName.contains("ieor")) { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, zeroInt); + } else if (reductionOpName.contains("iand")) { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, allOnInt); + } else { if (type.isa()) return builder.create( loc, type, - builder.getFloatAttr(type, - (double)getOperationIdentity(intrinsicOp, loc))); + builder.getFloatAttr( + type, (double)getOperationIdentity(reductionOpName, loc))); if (type.isa()) { mlir::Value intConst = builder.create( loc, builder.getI1Type(), builder.getIntegerAttr(builder.getI1Type(), - getOperationIdentity(intrinsicOp, loc))); + getOperationIdentity(reductionOpName, loc))); return builder.createConvert(loc, type, intConst); } return builder.create( loc, type, - builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc))); - } - - template - static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, - mlir::Type type, mlir::Location loc, - mlir::Value op1, mlir::Value op2) { - assert(type.isIntOrIndexOrFloat() && - "only integer and float types are currently supported"); - if (type.isIntOrIndex()) - return builder.create(loc, op1, op2); - return builder.create(loc, op1, op2); - } - - /// Creates an OpenMP reduction declaration and inserts it into the provided - /// symbol table. The declaration has a constant initializer with the neutral - /// value `initValue`, and the reduction combiner carried over from `reduce`. - /// TODO: Generalize this for non-integer types, add atomic region. - static mlir::omp::ReductionDeclareOp createReductionDecl( - fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, - const Fortran::parser::ProcedureDesignator &procDesignator, - mlir::Type type, mlir::Location loc) { - mlir::OpBuilder::InsertionGuard guard(builder); - mlir::ModuleOp module = builder.getModule(); - - auto decl = - module.lookupSymbol(reductionOpName); - if (decl) - return decl; + builder.getIntegerAttr(type, + getOperationIdentity(reductionOpName, loc))); + } +} + +template +static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, + mlir::Type type, mlir::Location loc, + mlir::Value op1, mlir::Value op2) { + assert(type.isIntOrIndexOrFloat() && + "only integer and float types are currently supported"); + if (type.isIntOrIndex()) + return builder.create(loc, op1, op2); + return builder.create(loc, op1, op2); +} + +static mlir::omp::ReductionDeclareOp +createMinimalReductionDecl(fir::FirOpBuilder &builder, + llvm::StringRef reductionOpName, mlir::Type type, + mlir::Location loc) { + mlir::ModuleOp module = builder.getModule(); + mlir::OpBuilder modBuilder(module.getBodyRegion()); + + mlir::omp::ReductionDeclareOp decl = + modBuilder.create(loc, reductionOpName, + type); + builder.createBlock(&decl.getInitializerRegion(), + decl.getInitializerRegion().end(), {type}, {loc}); + builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); + mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder); + builder.create(loc, init); + + builder.createBlock(&decl.getReductionRegion(), + decl.getReductionRegion().end(), {type, type}, + {loc, loc}); + + return decl; +} + +/// Creates an OpenMP reduction declaration and inserts it into the provided +/// symbol table. The declaration has a constant initializer with the neutral +/// value `initValue`, and the reduction combiner carried over from `reduce`. +/// TODO: Generalize this for non-integer types, add atomic region. +static mlir::omp::ReductionDeclareOp +createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, + const Fortran::parser::ProcedureDesignator &procDesignator, + mlir::Type type, mlir::Location loc) { + mlir::OpBuilder::InsertionGuard guard(builder); + mlir::ModuleOp module = builder.getModule(); + + auto decl = + module.lookupSymbol(reductionOpName); + if (decl) + return decl; - mlir::OpBuilder modBuilder(module.getBodyRegion()); + decl = createMinimalReductionDecl(builder, reductionOpName, type, loc); + builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); + mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); + mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - decl = modBuilder.create( - loc, reductionOpName, type); - builder.createBlock(&decl.getInitializerRegion(), - decl.getInitializerRegion().end(), {type}, {loc}); - builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); - mlir::Value init = - getIntrinsicProcInitValue(loc, type, procDesignator, builder); - builder.create(loc, init); - - builder.createBlock(&decl.getReductionRegion(), - decl.getReductionRegion().end(), {type, type}, - {loc, loc}); - - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); - mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); - mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - - mlir::Value reductionOp; - switch (getReductionType(procDesignator)) { - case IntrinsicProc::MAX: + mlir::Value reductionOp; + if (const auto *name{ + Fortran::parser::Unwrap(procDesignator)}) { + if (name->source == "max") { reductionOp = getReductionOperation( builder, type, loc, op1, op2); - break; - case IntrinsicProc::MIN: + } else if (name->source == "min") { reductionOp = getReductionOperation( builder, type, loc, op1, op2); - break; - case IntrinsicProc::IOR: + } else if (name->source == "ior") { assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create(loc, op1, op2); - break; - case IntrinsicProc::IEOR: + } else if (name->source == "ieor") { assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create(loc, op1, op2); - break; - case IntrinsicProc::IAND: + } else if (name->source == "iand") { assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create(loc, op1, op2); - break; - default: - llvm_unreachable( - "Reduction of some intrinsic operators is not supported"); + } else { + TODO(loc, "Reduction of some intrinsic operators is not supported"); } - - builder.create(loc, reductionOp); - return decl; } - /// Creates an OpenMP reduction declaration and inserts it into the provided - /// symbol table. The declaration has a constant initializer with the neutral - /// value `initValue`, and the reduction combiner carried over from `reduce`. - /// TODO: Generalize this for non-integer types, add atomic region. - static mlir::omp::ReductionDeclareOp createReductionDecl( - fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type type, mlir::Location loc) { - mlir::OpBuilder::InsertionGuard guard(builder); - mlir::ModuleOp module = builder.getModule(); - - auto decl = - module.lookupSymbol(reductionOpName); - if (decl) - return decl; - - mlir::OpBuilder modBuilder(module.getBodyRegion()); - - decl = modBuilder.create( - loc, reductionOpName, type); - builder.createBlock(&decl.getInitializerRegion(), - decl.getInitializerRegion().end(), {type}, {loc}); - builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); - mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder); - builder.create(loc, init); + builder.create(loc, reductionOp); + return decl; +} - builder.createBlock(&decl.getReductionRegion(), - decl.getReductionRegion().end(), {type, type}, - {loc, loc}); +/// Creates an OpenMP reduction declaration and inserts it into the provided +/// symbol table. The declaration has a constant initializer with the neutral +/// value `initValue`, and the reduction combiner carried over from `reduce`. +/// TODO: Generalize this for non-integer types, add atomic region. +static mlir::omp::ReductionDeclareOp createReductionDecl( + fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type type, mlir::Location loc) { + mlir::OpBuilder::InsertionGuard guard(builder); + mlir::ModuleOp module = builder.getModule(); - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); - mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); - mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); + auto decl = + module.lookupSymbol(reductionOpName); + if (decl) + return decl; - mlir::Value reductionOp; - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - reductionOp = - getReductionOperation( - builder, type, loc, op1, op2); - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - reductionOp = - getReductionOperation( - builder, type, loc, op1, op2); - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + decl = createMinimalReductionDecl(builder, reductionOpName, type, loc); + builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); + mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); + mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - mlir::Value andiOp = - builder.create(loc, op1I1, op2I1); + mlir::Value reductionOp; + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + reductionOp = + getReductionOperation( + builder, type, loc, op1, op2); + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + reductionOp = + getReductionOperation( + builder, type, loc, op1, op2); + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - reductionOp = builder.createConvert(loc, type, andiOp); - break; - } - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + mlir::Value andiOp = builder.create(loc, op1I1, op2I1); - mlir::Value oriOp = builder.create(loc, op1I1, op2I1); + reductionOp = builder.createConvert(loc, type, andiOp); + break; + } + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - reductionOp = builder.createConvert(loc, type, oriOp); - break; - } - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + mlir::Value oriOp = builder.create(loc, op1I1, op2I1); - mlir::Value cmpiOp = builder.create( - loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); + reductionOp = builder.createConvert(loc, type, oriOp); + break; + } + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - reductionOp = builder.createConvert(loc, type, cmpiOp); - break; - } - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + mlir::Value cmpiOp = builder.create( + loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); - mlir::Value cmpiOp = builder.create( - loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); + reductionOp = builder.createConvert(loc, type, cmpiOp); + break; + } + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - reductionOp = builder.createConvert(loc, type, cmpiOp); - break; - } - default: - TODO(loc, "Reduction of some intrinsic operators is not supported"); - } + mlir::Value cmpiOp = builder.create( + loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); - builder.create(loc, reductionOp); - return decl; + reductionOp = builder.createConvert(loc, type, cmpiOp); + break; } - - /// Creates a reduction declaration and associates it with an OpenMP block - /// directive. - static void addReductionDecl( - mlir::Location currentLocation, - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, - llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionDeclSymbols) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::omp::ReductionDeclareOp decl; - const auto &redOperator{ - std::get(reduction.t)}; - const auto &objectList{ - std::get(reduction.t)}; - if (const auto &redDefinedOp = - std::get_if(&redOperator.u)) { - const auto &intrinsicOp{ - std::get( - redDefinedOp->u)}; - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: - break; - - default: - TODO(currentLocation, - "Reduction of some intrinsic operators is not supported"); - break; - } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp()) - symVal = declOp.getBase(); - mlir::Type redType = - symVal.getType().cast().getEleTy(); - reductionVars.push_back(symVal); - if (redType.isa()) - decl = createReductionDecl( - firOpBuilder, - getReductionName(intrinsicOp, firOpBuilder.getI1Type()), - intrinsicOp, redType, currentLocation); - else if (redType.isIntOrIndexOrFloat()) { - decl = createReductionDecl(firOpBuilder, - getReductionName(intrinsicOp, redType), - intrinsicOp, redType, currentLocation); - } else { - TODO(currentLocation, "Reduction of some types is not supported"); - } - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } else if (const auto *reductionIntrinsic = - std::get_if( - &redOperator.u)) { - if (ReductionProcessor::supportedIntrinsicProcReduction( - *reductionIntrinsic)) { - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp()) - symVal = declOp.getBase(); - mlir::Type redType = - symVal.getType().cast().getEleTy(); - reductionVars.push_back(symVal); - assert(redType.isIntOrIndexOrFloat() && - "Unsupported reduction type"); - decl = createReductionDecl( - firOpBuilder, - getReductionName(getRealName(*reductionIntrinsic).ToString(), - redType), - *reductionIntrinsic, redType, currentLocation); - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } - } + default: + TODO(loc, "Reduction of some intrinsic operators is not supported"); } -}; + + builder.create(loc, reductionOp); + return decl; +} static mlir::omp::ScheduleModifier translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) { @@ -1302,6 +1137,101 @@ static mlir::Value getIfClauseOperand( ifVal); } +/// Creates a reduction declaration and associates it with an OpenMP block +/// directive. +static void +addReductionDecl(mlir::Location currentLocation, + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpReductionClause &reduction, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::omp::ReductionDeclareOp decl; + const auto &redOperator{ + std::get(reduction.t)}; + const auto &objectList{std::get(reduction.t)}; + if (const auto &redDefinedOp = + std::get_if(&redOperator.u)) { + const auto &intrinsicOp{ + std::get( + redDefinedOp->u)}; + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + break; + + default: + TODO(currentLocation, + "Reduction of some intrinsic operators is not supported"); + break; + } + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap(ompObject)}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp()) + symVal = declOp.getBase(); + mlir::Type redType = + symVal.getType().cast().getEleTy(); + reductionVars.push_back(symVal); + if (redType.isa()) + decl = createReductionDecl( + firOpBuilder, + getReductionName(intrinsicOp, firOpBuilder.getI1Type()), + intrinsicOp, redType, currentLocation); + else if (redType.isIntOrIndexOrFloat()) { + decl = createReductionDecl(firOpBuilder, + getReductionName(intrinsicOp, redType), + intrinsicOp, redType, currentLocation); + } else { + TODO(currentLocation, "Reduction of some types is not supported"); + } + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } + } else if (const auto *reductionIntrinsic = + std::get_if( + &redOperator.u)) { + if (const auto *name{Fortran::parser::Unwrap( + reductionIntrinsic)}) { + if ((name->source != "max") && (name->source != "min") && + (name->source != "ior") && (name->source != "ieor") && + (name->source != "iand")) { + TODO(currentLocation, + "Reduction of intrinsic procedures is not supported"); + } + std::string intrinsicOp = name->ToString(); + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap(ompObject)}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp()) + symVal = declOp.getBase(); + mlir::Type redType = + symVal.getType().cast().getEleTy(); + reductionVars.push_back(symVal); + assert(redType.isIntOrIndexOrFloat() && + "Unsupported reduction type"); + decl = createReductionDecl( + firOpBuilder, getReductionName(intrinsicOp, redType), + *reductionIntrinsic, redType, currentLocation); + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } + } + } +} + static void addUseDeviceClause(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpObjectList &useDeviceClause, @@ -1898,9 +1828,8 @@ bool ClauseProcessor::processReduction( return findRepeatableClause( [&](const ClauseTy::Reduction *reductionClause, const Fortran::parser::CharBlock &) { - ReductionProcessor rp; - rp.addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols); + addReductionDecl(currentLocation, converter, reductionClause->v, + reductionVars, reductionDeclSymbols); }); } @@ -3736,50 +3665,48 @@ void Fortran::lower::genOpenMPReduction( } else if (const auto *reductionIntrinsic = std::get_if( &redOperator.u)) { - if (!ReductionProcessor::supportedIntrinsicProcReduction( - *reductionIntrinsic)) - continue; - ReductionProcessor::IntrinsicProc redIntrinsicProc = - ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp()) - reductionVal = declOp.getBase(); - for (const mlir::OpOperand &reductionValUse : - reductionVal.getUses()) { - if (auto loadOp = mlir::dyn_cast( - reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - // Max is lowered as a compare -> select. - // Match the pattern here. - mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal); - if (reductionOp == nullptr) - continue; - - if (redIntrinsicProc == - ReductionProcessor::IntrinsicProc::MAX || - redIntrinsicProc == - ReductionProcessor::IntrinsicProc::MIN) { - assert(mlir::isa(reductionOp) && - "Selection Op not found in reduction intrinsic"); - mlir::Operation *compareOp = - getCompareFromReductionOp(reductionOp, loadVal); - updateReduction(compareOp, firOpBuilder, loadVal, - reductionVal); - } - if (redIntrinsicProc == - ReductionProcessor::IntrinsicProc::IOR || - redIntrinsicProc == - ReductionProcessor::IntrinsicProc::IEOR || - redIntrinsicProc == - ReductionProcessor::IntrinsicProc::IAND) { + if (const auto *name{Fortran::parser::Unwrap( + reductionIntrinsic)}) { + std::string redName = name->ToString(); + if ((name->source != "max") && (name->source != "min") && + (name->source != "ior") && (name->source != "ieor") && + (name->source != "iand")) { + continue; + } + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + if (const auto *name{Fortran::parser::Unwrap( + ompObject)}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = + reductionVal.getDefiningOp()) + reductionVal = declOp.getBase(); + for (const mlir::OpOperand &reductionValUse : + reductionVal.getUses()) { + if (auto loadOp = mlir::dyn_cast( + reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + // Max is lowered as a compare -> select. + // Match the pattern here. + mlir::Operation *reductionOp = + findReductionChain(loadVal, &reductionVal); + if (reductionOp == nullptr) + continue; + + if (redName == "max" || redName == "min") { + assert(mlir::isa(reductionOp) && + "Selection Op not found in reduction intrinsic"); + mlir::Operation *compareOp = + getCompareFromReductionOp(reductionOp, loadVal); + updateReduction(compareOp, firOpBuilder, loadVal, + reductionVal); + } + if (redName == "ior" || redName == "ieor" || + redName == "iand") { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); + } } } }