diff --git a/flang/include/flang/Common/visit.h b/flang/include/flang/Common/visit.h index d3136be3f6a1f..a6f04c1f27497 100644 --- a/flang/include/flang/Common/visit.h +++ b/flang/include/flang/Common/visit.h @@ -42,6 +42,7 @@ #ifndef FORTRAN_COMMON_VISIT_H_ #define FORTRAN_COMMON_VISIT_H_ +#include "flang/Runtime/api-attrs.h" #include #include @@ -50,7 +51,7 @@ namespace log2visit { template -inline RESULT Log2VisitHelper( +inline RT_API_ATTRS RESULT Log2VisitHelper( VISITOR &&visitor, std::size_t which, VARIANT &&...u) { if constexpr (LOW == HIGH) { return visitor(std::get(std::forward(u))...); @@ -67,7 +68,7 @@ inline RESULT Log2VisitHelper( } template -inline auto visit(VISITOR &&visitor, VARIANT &&...u) +inline RT_API_ATTRS auto visit(VISITOR &&visitor, VARIANT &&...u) -> decltype(visitor(std::get<0>(std::forward(u))...)) { using Result = decltype(visitor(std::get<0>(std::forward(u))...)); if constexpr (sizeof...(u) == 1) { diff --git a/flang/runtime/io-stmt.cpp b/flang/runtime/io-stmt.cpp index 075d7b5ae518a..c42af0aeeb9a3 100644 --- a/flang/runtime/io-stmt.cpp +++ b/flang/runtime/io-stmt.cpp @@ -467,69 +467,67 @@ int ExternalFormattedIoStatementState::EndIoStatement() { } Fortran::common::optional IoStatementState::GetNextDataEdit(int n) { - return common::visit( + return visitIo( [&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_); } bool IoStatementState::Emit( const char *data, std::size_t bytes, std::size_t elementBytes) { - return common::visit( + return visitIo( [=](auto &x) { return x.get().Emit(data, bytes, elementBytes); }, u_); } bool IoStatementState::Receive( char *data, std::size_t n, std::size_t elementBytes) { - return common::visit( + return visitIo( [=](auto &x) { return x.get().Receive(data, n, elementBytes); }, u_); } std::size_t IoStatementState::GetNextInputBytes(const char *&p) { - return common::visit( - [&](auto &x) { return x.get().GetNextInputBytes(p); }, u_); + return visitIo([&](auto &x) { return x.get().GetNextInputBytes(p); }, u_); } bool IoStatementState::AdvanceRecord(int n) { - return common::visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_); + return visitIo([=](auto &x) { return x.get().AdvanceRecord(n); }, u_); } void IoStatementState::BackspaceRecord() { - common::visit([](auto &x) { x.get().BackspaceRecord(); }, u_); + visitIo([](auto &x) { x.get().BackspaceRecord(); }, u_); } void IoStatementState::HandleRelativePosition(std::int64_t n) { - common::visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_); + visitIo([=](auto &x) { x.get().HandleRelativePosition(n); }, u_); } void IoStatementState::HandleAbsolutePosition(std::int64_t n) { - common::visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_); + visitIo([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_); } void IoStatementState::CompleteOperation() { - common::visit([](auto &x) { x.get().CompleteOperation(); }, u_); + visitIo([](auto &x) { x.get().CompleteOperation(); }, u_); } int IoStatementState::EndIoStatement() { - return common::visit([](auto &x) { return x.get().EndIoStatement(); }, u_); + return visitIo([](auto &x) { return x.get().EndIoStatement(); }, u_); } ConnectionState &IoStatementState::GetConnectionState() { - return common::visit( + return visitIo( [](auto &x) -> ConnectionState & { return x.get().GetConnectionState(); }, u_); } MutableModes &IoStatementState::mutableModes() { - return common::visit( + return visitIo( [](auto &x) -> MutableModes & { return x.get().mutableModes(); }, u_); } bool IoStatementState::BeginReadingRecord() { - return common::visit( - [](auto &x) { return x.get().BeginReadingRecord(); }, u_); + return visitIo([](auto &x) { return x.get().BeginReadingRecord(); }, u_); } IoErrorHandler &IoStatementState::GetIoErrorHandler() const { - return common::visit( + return visitIo( [](auto &x) -> IoErrorHandler & { return static_cast(x.get()); }, @@ -537,8 +535,7 @@ IoErrorHandler &IoStatementState::GetIoErrorHandler() const { } ExternalFileUnit *IoStatementState::GetExternalFileUnit() const { - return common::visit( - [](auto &x) { return x.get().GetExternalFileUnit(); }, u_); + return visitIo([](auto &x) { return x.get().GetExternalFileUnit(); }, u_); } Fortran::common::optional IoStatementState::GetCurrentChar( @@ -664,28 +661,26 @@ bool IoStatementState::CheckForEndOfRecord(std::size_t afterReading) { bool IoStatementState::Inquire( InquiryKeywordHash inquiry, char *out, std::size_t chars) { - return common::visit( + return visitIo( [&](auto &x) { return x.get().Inquire(inquiry, out, chars); }, u_); } bool IoStatementState::Inquire(InquiryKeywordHash inquiry, bool &out) { - return common::visit( - [&](auto &x) { return x.get().Inquire(inquiry, out); }, u_); + return visitIo([&](auto &x) { return x.get().Inquire(inquiry, out); }, u_); } bool IoStatementState::Inquire( InquiryKeywordHash inquiry, std::int64_t id, bool &out) { - return common::visit( + return visitIo( [&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_); } bool IoStatementState::Inquire(InquiryKeywordHash inquiry, std::int64_t &n) { - return common::visit( - [&](auto &x) { return x.get().Inquire(inquiry, n); }, u_); + return visitIo([&](auto &x) { return x.get().Inquire(inquiry, n); }, u_); } std::int64_t IoStatementState::InquirePos() { - return common::visit([&](auto &x) { return x.get().InquirePos(); }, u_); + return visitIo([&](auto &x) { return x.get().InquirePos(); }, u_); } void IoStatementState::GotChar(int n) { diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h index e00d54980aae5..4e17cee2becf8 100644 --- a/flang/runtime/io-stmt.h +++ b/flang/runtime/io-stmt.h @@ -16,15 +16,19 @@ #include "format.h" #include "internal-unit.h" #include "io-error.h" +#include "flang/Common/idioms.h" #include "flang/Common/optional.h" #include "flang/Common/reference-wrapper.h" -#include "flang/Common/visit.h" #include "flang/Runtime/descriptor.h" #include "flang/Runtime/io-api.h" #include #include #include +// I/O statement state classes that may be instantiated during execution +// on an offload device have this trait: +CLASS_TRAIT(AvailableOnDevice) + namespace Fortran::runtime::io { class ExternalFileUnit; @@ -52,15 +56,23 @@ template class ChildListIoStatementState; template class ChildUnformattedIoStatementState; struct InputStatementState {}; -struct OutputStatementState {}; +struct OutputStatementState { + using AvailableOnDevice = std::true_type; +}; template using IoDirectionState = std::conditional_t; // Common state for all kinds of formatted I/O template class FormattedIoStatementState {}; +template <> class FormattedIoStatementState { +public: + using AvailableOnDevice = std::true_type; +}; + template <> class FormattedIoStatementState { public: + using AvailableOnDevice = std::true_type; std::size_t GetEditDescriptorChars() const; void GotChar(int); @@ -113,10 +125,19 @@ class IoStatementState { // N.B.: this also works with base classes template A *get_if() const { - return common::visit( - [](auto &x) -> A * { + [[maybe_unused]] std::size_t index{u_.index()}; + return Fortran::common::visit( + [=](auto &x) -> A * { if constexpr (std::is_convertible_v) { - return &x.get(); +#if defined(RT_DEVICE_COMPILATION) + if constexpr (!AvailableOnDevice>) { + terminateOnDevice(__FILE__, __LINE__, index); + } else { +#endif + return &x.get(); +#if defined(RT_DEVICE_COMPILATION) + } +#endif } return nullptr; }, @@ -211,6 +232,40 @@ class IoStatementState { } private: +#if RT_DEVICE_COMPILATION + static RT_API_ATTRS void terminateOnDevice( + const char *sourceFile, int sourceLine, std::size_t index) { + // %zd is not supported by device printf. + Terminator{sourceFile, sourceLine}.Crash( + "Unexpected IO statement variant (index %d) during device execution", + static_cast(index)); + } +#endif + + // Define special visitor for the variants of IoStatementState. + // During the device code compilation the visitor only allows + // visiting those variants that have AvailableOnDevice trait + // are supported on the device. + template + static inline RT_API_ATTRS auto visitIo(VISITOR &&visitor, VARIANT &&u) + -> decltype(visitor(std::get<0>(std::forward(u)))) { + using Result = decltype(visitor(std::get<0>(std::forward(u)))); + [[maybe_unused]] std::size_t index{u.index()}; + return Fortran::common::visit( + [&](auto &x) -> Result { +#if defined(RT_DEVICE_COMPILATION) + if constexpr (!AvailableOnDevice>) { + terminateOnDevice(__FILE__, __LINE__, index); + } else { +#endif + return visitor(x); +#if defined(RT_DEVICE_COMPILATION) + } +#endif + }, + std::forward(u)); + } + std::variant, Fortran::common::reference_wrapper, Fortran::common::reference_wrapper, @@ -296,6 +351,7 @@ template <> class ListDirectedStatementState : public FormattedIoStatementState { public: + using AvailableOnDevice = std::true_type; bool EmitLeadingSpaceOrAdvance( IoStatementState &, std::size_t = 1, bool isCharacter = false); Fortran::common::optional GetNextDataEdit( @@ -314,6 +370,7 @@ template <> class ListDirectedStatementState : public FormattedIoStatementState { public: + using AvailableOnDevice = std::false_type; bool inNamelistSequence() const { return inNamelistSequence_; } int EndIoStatement(); @@ -351,6 +408,8 @@ template class InternalIoStatementState : public IoStatementBase, public IoDirectionState { public: + using AvailableOnDevice = std::conditional_t; using Buffer = std::conditional_t; InternalIoStatementState(Buffer, std::size_t, @@ -379,6 +438,8 @@ class InternalFormattedIoStatementState : public InternalIoStatementState, public FormattedIoStatementState { public: + using AvailableOnDevice = std::conditional_t; using CharType = CHAR; using typename InternalIoStatementState::Buffer; InternalFormattedIoStatementState(Buffer internal, std::size_t internalLength, @@ -407,6 +468,8 @@ template class InternalListIoStatementState : public InternalIoStatementState, public ListDirectedStatementState { public: + using AvailableOnDevice = std::conditional_t; using typename InternalIoStatementState::Buffer; InternalListIoStatementState(Buffer internal, std::size_t internalLength, const char *sourceFile = nullptr, int sourceLine = 0); @@ -424,6 +487,7 @@ class InternalListIoStatementState : public InternalIoStatementState, class ExternalIoStatementBase : public IoStatementBase { public: + using AvailableOnDevice = std::false_type; ExternalIoStatementBase( ExternalFileUnit &, const char *sourceFile = nullptr, int sourceLine = 0); ExternalFileUnit &unit() { return unit_; } @@ -444,6 +508,7 @@ template class ExternalIoStatementState : public ExternalIoStatementBase, public IoDirectionState { public: + using AvailableOnDevice = std::false_type; ExternalIoStatementState( ExternalFileUnit &, const char *sourceFile = nullptr, int sourceLine = 0); MutableModes &mutableModes() { return mutableModes_; } @@ -470,6 +535,7 @@ class ExternalFormattedIoStatementState : public ExternalIoStatementState, public FormattedIoStatementState { public: + using AvailableOnDevice = std::false_type; using CharType = CHAR; ExternalFormattedIoStatementState(ExternalFileUnit &, const CharType *format, std::size_t formatLength, const Descriptor *formatDescriptor = nullptr, @@ -489,6 +555,7 @@ template class ExternalListIoStatementState : public ExternalIoStatementState, public ListDirectedStatementState { public: + using AvailableOnDevice = std::false_type; using ExternalIoStatementState::ExternalIoStatementState; using ListDirectedStatementState::GetNextDataEdit; int EndIoStatement(); @@ -498,6 +565,7 @@ template class ExternalUnformattedIoStatementState : public ExternalIoStatementState { public: + using AvailableOnDevice = std::false_type; using ExternalIoStatementState::ExternalIoStatementState; bool Receive(char *, std::size_t, std::size_t elementBytes = 0); }; @@ -506,6 +574,7 @@ template class ChildIoStatementState : public IoStatementBase, public IoDirectionState { public: + using AvailableOnDevice = std::false_type; ChildIoStatementState( ChildIo &, const char *sourceFile = nullptr, int sourceLine = 0); ChildIo &child() { return child_; } @@ -526,6 +595,7 @@ template class ChildFormattedIoStatementState : public ChildIoStatementState, public FormattedIoStatementState { public: + using AvailableOnDevice = std::false_type; using CharType = CHAR; ChildFormattedIoStatementState(ChildIo &, const CharType *format, std::size_t formatLength, const Descriptor *formatDescriptor = nullptr, @@ -548,6 +618,7 @@ template class ChildListIoStatementState : public ChildIoStatementState, public ListDirectedStatementState { public: + using AvailableOnDevice = std::false_type; using ChildIoStatementState::ChildIoStatementState; using ListDirectedStatementState::GetNextDataEdit; int EndIoStatement(); @@ -556,6 +627,7 @@ class ChildListIoStatementState : public ChildIoStatementState, template class ChildUnformattedIoStatementState : public ChildIoStatementState { public: + using AvailableOnDevice = std::false_type; using ChildIoStatementState::ChildIoStatementState; bool Receive(char *, std::size_t, std::size_t elementBytes = 0); }; @@ -563,6 +635,7 @@ class ChildUnformattedIoStatementState : public ChildIoStatementState { // OPEN class OpenStatementState : public ExternalIoStatementBase { public: + using AvailableOnDevice = std::false_type; OpenStatementState(ExternalFileUnit &unit, bool wasExtant, bool isNewUnit, const char *sourceFile = nullptr, int sourceLine = 0) : ExternalIoStatementBase{unit, sourceFile, sourceLine}, @@ -594,6 +667,7 @@ class OpenStatementState : public ExternalIoStatementBase { class CloseStatementState : public ExternalIoStatementBase { public: + using AvailableOnDevice = std::false_type; CloseStatementState(ExternalFileUnit &unit, const char *sourceFile = nullptr, int sourceLine = 0) : ExternalIoStatementBase{unit, sourceFile, sourceLine} {} @@ -608,6 +682,7 @@ class CloseStatementState : public ExternalIoStatementBase { // and recoverable BACKSPACE(bad unit) class NoUnitIoStatementState : public IoStatementBase { public: + using AvailableOnDevice = std::false_type; IoStatementState &ioStatementState() { return ioStatementState_; } MutableModes &mutableModes() { return connection_.modes; } ConnectionState &GetConnectionState() { return connection_; } @@ -630,6 +705,7 @@ class NoUnitIoStatementState : public IoStatementBase { class NoopStatementState : public NoUnitIoStatementState { public: + using AvailableOnDevice = std::false_type; NoopStatementState( const char *sourceFile = nullptr, int sourceLine = 0, int unitNumber = -1) : NoUnitIoStatementState{*this, sourceFile, sourceLine, unitNumber} {} @@ -674,6 +750,7 @@ extern template class FormatControl< class InquireUnitState : public ExternalIoStatementBase { public: + using AvailableOnDevice = std::false_type; InquireUnitState(ExternalFileUnit &unit, const char *sourceFile = nullptr, int sourceLine = 0); bool Inquire(InquiryKeywordHash, char *, std::size_t); @@ -684,6 +761,7 @@ class InquireUnitState : public ExternalIoStatementBase { class InquireNoUnitState : public NoUnitIoStatementState { public: + using AvailableOnDevice = std::false_type; InquireNoUnitState(const char *sourceFile = nullptr, int sourceLine = 0, int badUnitNumber = -1); bool Inquire(InquiryKeywordHash, char *, std::size_t); @@ -694,6 +772,7 @@ class InquireNoUnitState : public NoUnitIoStatementState { class InquireUnconnectedFileState : public NoUnitIoStatementState { public: + using AvailableOnDevice = std::false_type; InquireUnconnectedFileState(OwningPtr &&path, const char *sourceFile = nullptr, int sourceLine = 0); bool Inquire(InquiryKeywordHash, char *, std::size_t); @@ -708,6 +787,7 @@ class InquireUnconnectedFileState : public NoUnitIoStatementState { class InquireIOLengthState : public NoUnitIoStatementState, public OutputStatementState { public: + using AvailableOnDevice = std::false_type; InquireIOLengthState(const char *sourceFile = nullptr, int sourceLine = 0); std::size_t bytes() const { return bytes_; } bool Emit(const char *, std::size_t bytes, std::size_t elementBytes = 0); @@ -718,6 +798,7 @@ class InquireIOLengthState : public NoUnitIoStatementState, class ExternalMiscIoStatementState : public ExternalIoStatementBase { public: + using AvailableOnDevice = std::false_type; enum Which { Flush, Backspace, Endfile, Rewind, Wait }; ExternalMiscIoStatementState(ExternalFileUnit &unit, Which which, const char *sourceFile = nullptr, int sourceLine = 0) @@ -731,6 +812,7 @@ class ExternalMiscIoStatementState : public ExternalIoStatementBase { class ErroneousIoStatementState : public IoStatementBase { public: + using AvailableOnDevice = std::false_type; explicit ErroneousIoStatementState(Iostat iostat, ExternalFileUnit *unit = nullptr, const char *sourceFile = nullptr, int sourceLine = 0)