diff --git a/flang/include/flang/Common/Fortran.h b/flang/include/flang/Common/Fortran.h index 4007bfc7994f9..1d3a85e250073 100644 --- a/flang/include/flang/Common/Fortran.h +++ b/flang/include/flang/Common/Fortran.h @@ -87,6 +87,10 @@ ENUM_CLASS(CUDASubprogramAttrs, Host, Device, HostDevice, Global, Grid_Global) // CUDA data attributes; mutually exclusive ENUM_CLASS(CUDADataAttr, Constant, Device, Managed, Pinned, Shared, Texture) +// OpenACC device types +ENUM_CLASS( + OpenACCDeviceType, Star, Default, Nvidia, Radeon, Host, Multicore, None) + // OpenMP atomic_default_mem_order clause allowed values ENUM_CLASS(OmpAtomicDefaultMemOrderType, SeqCst, AcqRel, Relaxed) diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h index 1defbf132327c..d067a7273540f 100644 --- a/flang/include/flang/Parser/dump-parse-tree.h +++ b/flang/include/flang/Parser/dump-parse-tree.h @@ -48,6 +48,7 @@ class ParseTreeDumper { NODE(std, uint64_t) NODE_ENUM(common, CUDADataAttr) NODE_ENUM(common, CUDASubprogramAttrs) + NODE_ENUM(common, OpenACCDeviceType) NODE(format, ControlEditDesc) NODE(format::ControlEditDesc, Kind) NODE(format, DerivedTypeDataEditDesc) @@ -101,7 +102,7 @@ class ParseTreeDumper { NODE(parser, AccSelfClause) NODE(parser, AccStandaloneDirective) NODE(parser, AccDeviceTypeExpr) - NODE_ENUM(parser::AccDeviceTypeExpr, Device) + NODE(parser, AccDeviceTypeExprList) NODE(parser, AccTileExpr) NODE(parser, AccTileExprList) diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h index 71195f2bb9ddc..e9bfb728a2bef 100644 --- a/flang/include/flang/Parser/parse-tree.h +++ b/flang/include/flang/Parser/parse-tree.h @@ -4072,8 +4072,8 @@ struct AccWaitArgument { }; struct AccDeviceTypeExpr { - ENUM_CLASS(Device, Star, Default, Nvidia, Radeon, Host, Multicore) - WRAPPER_CLASS_BOILERPLATE(AccDeviceTypeExpr, Device); + WRAPPER_CLASS_BOILERPLATE( + AccDeviceTypeExpr, Fortran::common::OpenACCDeviceType); CharBlock source; }; diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h index f6f195b6bb95b..1f5d84e2b6b37 100644 --- a/flang/include/flang/Semantics/symbol.h +++ b/flang/include/flang/Semantics/symbol.h @@ -112,7 +112,8 @@ class WithBindName { bool isExplicitBindName_{false}; }; -class OpenACCRoutineInfo { +// Device type specific OpenACC routine information +class OpenACCRoutineDeviceTypeInfo { public: bool isSeq() const { return isSeq_; } void set_isSeq(bool value = true) { isSeq_ = value; } @@ -124,12 +125,12 @@ class OpenACCRoutineInfo { void set_isGang(bool value = true) { isGang_ = value; } unsigned gangDim() const { return gangDim_; } void set_gangDim(unsigned value) { gangDim_ = value; } - bool isNohost() const { return isNohost_; } - void set_isNohost(bool value = true) { isNohost_ = value; } const std::string *bindName() const { return bindName_ ? &*bindName_ : nullptr; } void set_bindName(std::string &&name) { bindName_ = std::move(name); } + void set_dType(Fortran::common::OpenACCDeviceType dType) { dType_ = dType; } + Fortran::common::OpenACCDeviceType dType() const { return dType_; } private: bool isSeq_{false}; @@ -137,8 +138,28 @@ class OpenACCRoutineInfo { bool isWorker_{false}; bool isGang_{false}; unsigned gangDim_{0}; - bool isNohost_{false}; std::optional bindName_; + Fortran::common::OpenACCDeviceType dType_{ + Fortran::common::OpenACCDeviceType::None}; +}; + +// OpenACC routine information. Device independent info are stored on the +// OpenACCRoutineInfo instance while device dependent info are stored +// as objects in the OpenACCRoutineDeviceTypeInfo list. +class OpenACCRoutineInfo : public OpenACCRoutineDeviceTypeInfo { +public: + bool isNohost() const { return isNohost_; } + void set_isNohost(bool value = true) { isNohost_ = value; } + std::list &deviceTypeInfos() { + return deviceTypeInfos_; + } + void add_deviceTypeInfo(OpenACCRoutineDeviceTypeInfo &info) { + deviceTypeInfos_.push_back(info); + } + +private: + std::list deviceTypeInfos_; + bool isNohost_{false}; }; // A subroutine or function definition, or a subprogram interface defined diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index d24c369d81bed..db9ed72bc8725 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -1483,20 +1483,22 @@ genAsyncClause(Fortran::lower::AbstractConverter &converter, } static mlir::acc::DeviceType -getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) { +getDeviceType(Fortran::common::OpenACCDeviceType device) { switch (device) { - case Fortran::parser::AccDeviceTypeExpr::Device::Star: + case Fortran::common::OpenACCDeviceType::Star: return mlir::acc::DeviceType::Star; - case Fortran::parser::AccDeviceTypeExpr::Device::Default: + case Fortran::common::OpenACCDeviceType::Default: return mlir::acc::DeviceType::Default; - case Fortran::parser::AccDeviceTypeExpr::Device::Nvidia: + case Fortran::common::OpenACCDeviceType::Nvidia: return mlir::acc::DeviceType::Nvidia; - case Fortran::parser::AccDeviceTypeExpr::Device::Radeon: + case Fortran::common::OpenACCDeviceType::Radeon: return mlir::acc::DeviceType::Radeon; - case Fortran::parser::AccDeviceTypeExpr::Device::Host: + case Fortran::common::OpenACCDeviceType::Host: return mlir::acc::DeviceType::Host; - case Fortran::parser::AccDeviceTypeExpr::Device::Multicore: + case Fortran::common::OpenACCDeviceType::Multicore: return mlir::acc::DeviceType::Multicore; + case Fortran::common::OpenACCDeviceType::None: + return mlir::acc::DeviceType::None; } return mlir::acc::DeviceType::None; } diff --git a/flang/lib/Parser/openacc-parsers.cpp b/flang/lib/Parser/openacc-parsers.cpp index 5b9267e0e17c6..946b33d0084a9 100644 --- a/flang/lib/Parser/openacc-parsers.cpp +++ b/flang/lib/Parser/openacc-parsers.cpp @@ -54,13 +54,13 @@ TYPE_PARSER(construct(scalarIntExpr) || TYPE_PARSER(construct(nonemptyList(Parser{}))) TYPE_PARSER(sourced(construct( - first("*" >> pure(AccDeviceTypeExpr::Device::Star), - "DEFAULT" >> pure(AccDeviceTypeExpr::Device::Default), - "NVIDIA" >> pure(AccDeviceTypeExpr::Device::Nvidia), - "ACC_DEVICE_NVIDIA" >> pure(AccDeviceTypeExpr::Device::Nvidia), - "RADEON" >> pure(AccDeviceTypeExpr::Device::Radeon), - "HOST" >> pure(AccDeviceTypeExpr::Device::Host), - "MULTICORE" >> pure(AccDeviceTypeExpr::Device::Multicore))))) + first("*" >> pure(Fortran::common::OpenACCDeviceType::Star), + "DEFAULT" >> pure(Fortran::common::OpenACCDeviceType::Default), + "NVIDIA" >> pure(Fortran::common::OpenACCDeviceType::Nvidia), + "ACC_DEVICE_NVIDIA" >> pure(Fortran::common::OpenACCDeviceType::Nvidia), + "RADEON" >> pure(Fortran::common::OpenACCDeviceType::Radeon), + "HOST" >> pure(Fortran::common::OpenACCDeviceType::Host), + "MULTICORE" >> pure(Fortran::common::OpenACCDeviceType::Multicore))))) TYPE_PARSER( construct(nonemptyList(Parser{}))) diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp index 70b6bbf8b557a..e87125580cc1d 100644 --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -491,31 +491,51 @@ void ModFileWriter::PutDECStructure( static const Attrs subprogramPrefixAttrs{Attr::ELEMENTAL, Attr::IMPURE, Attr::MODULE, Attr::NON_RECURSIVE, Attr::PURE, Attr::RECURSIVE}; +static void PutOpenACCDeviceTypeRoutineInfo( + llvm::raw_ostream &os, const OpenACCRoutineDeviceTypeInfo &info) { + if (info.isSeq()) { + os << " seq"; + } + if (info.isGang()) { + os << " gang"; + if (info.gangDim() > 0) { + os << "(dim: " << info.gangDim() << ")"; + } + } + if (info.isVector()) { + os << " vector"; + } + if (info.isWorker()) { + os << " worker"; + } + if (info.bindName()) { + os << " bind(" << *info.bindName() << ")"; + } +} + static void PutOpenACCRoutineInfo( llvm::raw_ostream &os, const SubprogramDetails &details) { for (auto info : details.openACCRoutineInfos()) { os << "!$acc routine"; - if (info.isSeq()) { - os << " seq"; - } - if (info.isGang()) { - os << " gang"; - if (info.gangDim() > 0) { - os << "(dim: " << info.gangDim() << ")"; - } - } - if (info.isVector()) { - os << " vector"; - } - if (info.isWorker()) { - os << " worker"; - } + + PutOpenACCDeviceTypeRoutineInfo(os, info); + if (info.isNohost()) { os << " nohost"; } - if (info.bindName()) { - os << " bind(" << *info.bindName() << ")"; + + for (auto dtype : info.deviceTypeInfos()) { + os << " device_type("; + if (dtype.dType() == common::OpenACCDeviceType::Star) { + os << "*"; + } else { + os << parser::ToLowerCaseLetters(common::EnumToString(dtype.dType())); + } + os << ")"; + + PutOpenACCDeviceTypeRoutineInfo(os, dtype); } + os << "\n"; } } diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index da6c865ad56a3..d7b13631ab4df 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -945,25 +945,45 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol( const auto &clauses = std::get(x.t); for (const Fortran::parser::AccClause &clause : clauses.v) { if (std::get_if(&clause.u)) { - info.set_isSeq(); + if (info.deviceTypeInfos().empty()) { + info.set_isSeq(); + } else { + info.deviceTypeInfos().back().set_isSeq(); + } } else if (const auto *gangClause = std::get_if(&clause.u)) { - info.set_isGang(); + if (info.deviceTypeInfos().empty()) { + info.set_isGang(); + } else { + info.deviceTypeInfos().back().set_isGang(); + } if (gangClause->v) { const Fortran::parser::AccGangArgList &x = *gangClause->v; for (const Fortran::parser::AccGangArg &gangArg : x.v) { if (const auto *dim = std::get_if(&gangArg.u)) { if (const auto v{EvaluateInt64(context_, dim->v)}) { - info.set_gangDim(*v); + if (info.deviceTypeInfos().empty()) { + info.set_gangDim(*v); + } else { + info.deviceTypeInfos().back().set_gangDim(*v); + } } } } } } else if (std::get_if(&clause.u)) { - info.set_isVector(); + if (info.deviceTypeInfos().empty()) { + info.set_isVector(); + } else { + info.deviceTypeInfos().back().set_isVector(); + } } else if (std::get_if(&clause.u)) { - info.set_isWorker(); + if (info.deviceTypeInfos().empty()) { + info.set_isWorker(); + } else { + info.deviceTypeInfos().back().set_isWorker(); + } } else if (std::get_if(&clause.u)) { info.set_isNohost(); } else if (const auto *bindClause = @@ -971,7 +991,12 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol( if (const auto *name = std::get_if(&bindClause->v.u)) { if (Symbol *sym = ResolveFctName(*name)) { - info.set_bindName(sym->name().ToString()); + if (info.deviceTypeInfos().empty()) { + info.set_bindName(sym->name().ToString()); + } else { + info.deviceTypeInfos().back().set_bindName( + sym->name().ToString()); + } } else { context_.Say((*name).source, "No function or subroutine declared for '%s'"_err_en_US, @@ -986,8 +1011,19 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol( std::string str{std::get(charConst->t)}; std::stringstream bindName; bindName << "\"" << str << "\""; - info.set_bindName(bindName.str()); + if (info.deviceTypeInfos().empty()) { + info.set_bindName(bindName.str()); + } else { + info.deviceTypeInfos().back().set_bindName(bindName.str()); + } } + } else if (const auto *dType = + std::get_if( + &clause.u)) { + const parser::AccDeviceTypeExprList &deviceTypeExprList = dType->v; + OpenACCRoutineDeviceTypeInfo dtypeInfo; + dtypeInfo.set_dType(deviceTypeExprList.v.front().v); + info.add_deviceTypeInfo(dtypeInfo); } } symbol.get().add_openACCRoutineInfo(info); diff --git a/flang/test/Semantics/OpenACC/acc-module.f90 b/flang/test/Semantics/OpenACC/acc-module.f90 index f552816d69882..7f034d8ae54f0 100644 --- a/flang/test/Semantics/OpenACC/acc-module.f90 +++ b/flang/test/Semantics/OpenACC/acc-module.f90 @@ -60,6 +60,13 @@ subroutine sub9() subroutine sub10() end subroutine + subroutine sub11() + !$acc routine device_type(nvidia) gang device_type(*) seq + end subroutine + + subroutine sub12() + !$acc routine device_type(host) bind(sub7) device_type(multicore) bind(sub8) + end subroutine end module !Expect: acc_mod.mod @@ -107,4 +114,10 @@ subroutine sub10() ! subroutine sub10() ! !$acc routine seq ! end +! subroutinesub11() +! !$acc routine device_type(nvidia) gang device_type(*) seq +! end +! subroutinesub12() +! !$acc routine device_type(host) bind(sub7) device_type(multicore) bind(sub8) +! end ! end