Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][runtime] Added custom visitor for IoStatementState variants. #85179

Conversation

vzakhari
Copy link
Contributor

The visitor only allows Internal.*IoStatementState variants to be visited.
In case another variant is met a runtime error is produced.
During the device compilation the other variants' classes are not referenced,
which, for example, helps to avoid warnings about host only
methods referenced in device code.

I had problems parameterizing the Fortran::common visitor to limit
the allowed variants, but I can give it another try if creating
a copy looks inappropriate.

Created using spr 1.3.4
@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category labels Mar 14, 2024
@llvmbot
Copy link

llvmbot commented Mar 14, 2024

@llvm/pr-subscribers-flang-runtime

Author: Slava Zakharin (vzakhari)

Changes

The visitor only allows Internal.*IoStatementState variants to be visited.
In case another variant is met a runtime error is produced.
During the device compilation the other variants' classes are not referenced,
which, for example, helps to avoid warnings about host only
methods referenced in device code.

I had problems parameterizing the Fortran::common visitor to limit
the allowed variants, but I can give it another try if creating
a copy looks inappropriate.


Full diff: https://github.com/llvm/llvm-project/pull/85179.diff

2 Files Affected:

  • (modified) flang/runtime/io-stmt.cpp (+20-27)
  • (modified) flang/runtime/io-stmt.h (+56-2)
diff --git a/flang/runtime/io-stmt.cpp b/flang/runtime/io-stmt.cpp
index 075d7b5ae518a4..efefbc5e1a1c08 100644
--- a/flang/runtime/io-stmt.cpp
+++ b/flang/runtime/io-stmt.cpp
@@ -467,69 +467,66 @@ int ExternalFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
 }
 
 Fortran::common::optional<DataEdit> IoStatementState::GetNextDataEdit(int n) {
-  return common::visit(
-      [&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_);
+  return visit([&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_);
 }
 
 bool IoStatementState::Emit(
     const char *data, std::size_t bytes, std::size_t elementBytes) {
-  return common::visit(
+  return visit(
       [=](auto &x) { return x.get().Emit(data, bytes, elementBytes); }, u_);
 }
 
 bool IoStatementState::Receive(
     char *data, std::size_t n, std::size_t elementBytes) {
-  return common::visit(
+  return visit(
       [=](auto &x) { return x.get().Receive(data, n, elementBytes); }, u_);
 }
 
 std::size_t IoStatementState::GetNextInputBytes(const char *&p) {
-  return common::visit(
-      [&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
+  return visit([&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
 }
 
 bool IoStatementState::AdvanceRecord(int n) {
-  return common::visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_);
+  return visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_);
 }
 
 void IoStatementState::BackspaceRecord() {
-  common::visit([](auto &x) { x.get().BackspaceRecord(); }, u_);
+  visit([](auto &x) { x.get().BackspaceRecord(); }, u_);
 }
 
 void IoStatementState::HandleRelativePosition(std::int64_t n) {
-  common::visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_);
+  visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_);
 }
 
 void IoStatementState::HandleAbsolutePosition(std::int64_t n) {
-  common::visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
+  visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
 }
 
 void IoStatementState::CompleteOperation() {
-  common::visit([](auto &x) { x.get().CompleteOperation(); }, u_);
+  visit([](auto &x) { x.get().CompleteOperation(); }, u_);
 }
 
 int IoStatementState::EndIoStatement() {
-  return common::visit([](auto &x) { return x.get().EndIoStatement(); }, u_);
+  return visit([](auto &x) { return x.get().EndIoStatement(); }, u_);
 }
 
 ConnectionState &IoStatementState::GetConnectionState() {
-  return common::visit(
+  return visit(
       [](auto &x) -> ConnectionState & { return x.get().GetConnectionState(); },
       u_);
 }
 
 MutableModes &IoStatementState::mutableModes() {
-  return common::visit(
+  return visit(
       [](auto &x) -> MutableModes & { return x.get().mutableModes(); }, u_);
 }
 
 bool IoStatementState::BeginReadingRecord() {
-  return common::visit(
-      [](auto &x) { return x.get().BeginReadingRecord(); }, u_);
+  return visit([](auto &x) { return x.get().BeginReadingRecord(); }, u_);
 }
 
 IoErrorHandler &IoStatementState::GetIoErrorHandler() const {
-  return common::visit(
+  return visit(
       [](auto &x) -> IoErrorHandler & {
         return static_cast<IoErrorHandler &>(x.get());
       },
@@ -537,8 +534,7 @@ IoErrorHandler &IoStatementState::GetIoErrorHandler() const {
 }
 
 ExternalFileUnit *IoStatementState::GetExternalFileUnit() const {
-  return common::visit(
-      [](auto &x) { return x.get().GetExternalFileUnit(); }, u_);
+  return visit([](auto &x) { return x.get().GetExternalFileUnit(); }, u_);
 }
 
 Fortran::common::optional<char32_t> IoStatementState::GetCurrentChar(
@@ -664,28 +660,25 @@ bool IoStatementState::CheckForEndOfRecord(std::size_t afterReading) {
 
 bool IoStatementState::Inquire(
     InquiryKeywordHash inquiry, char *out, std::size_t chars) {
-  return common::visit(
+  return visit(
       [&](auto &x) { return x.get().Inquire(inquiry, out, chars); }, u_);
 }
 
 bool IoStatementState::Inquire(InquiryKeywordHash inquiry, bool &out) {
-  return common::visit(
-      [&](auto &x) { return x.get().Inquire(inquiry, out); }, u_);
+  return visit([&](auto &x) { return x.get().Inquire(inquiry, out); }, u_);
 }
 
 bool IoStatementState::Inquire(
     InquiryKeywordHash inquiry, std::int64_t id, bool &out) {
-  return common::visit(
-      [&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_);
+  return visit([&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_);
 }
 
 bool IoStatementState::Inquire(InquiryKeywordHash inquiry, std::int64_t &n) {
-  return common::visit(
-      [&](auto &x) { return x.get().Inquire(inquiry, n); }, u_);
+  return visit([&](auto &x) { return x.get().Inquire(inquiry, n); }, u_);
 }
 
 std::int64_t IoStatementState::InquirePos() {
-  return common::visit([&](auto &x) { return x.get().InquirePos(); }, u_);
+  return visit([&](auto &x) { return x.get().InquirePos(); }, u_);
 }
 
 void IoStatementState::GotChar(int n) {
diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index e00d54980aae59..7fecf4d9e41754 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -18,7 +18,6 @@
 #include "io-error.h"
 #include "flang/Common/optional.h"
 #include "flang/Common/reference-wrapper.h"
-#include "flang/Common/visit.h"
 #include "flang/Runtime/descriptor.h"
 #include "flang/Runtime/io-api.h"
 #include <functional>
@@ -113,7 +112,7 @@ class IoStatementState {
 
   // N.B.: this also works with base classes
   template <typename A> A *get_if() const {
-    return common::visit(
+    return visit(
         [](auto &x) -> A * {
           if constexpr (std::is_convertible_v<decltype(x.get()), A &>) {
             return &x.get();
@@ -211,6 +210,61 @@ class IoStatementState {
   }
 
 private:
+  // Define special visitor for the variants of IoStatementState.
+  // During the device code compilation the visitor only allows
+  // visiting those variants that are supported on the device.
+  // In particular, only the internal IO variants are supported.
+  // TODO: parameterize Fortran::common::log2visit instead of
+  //       creating a copy here.
+  template <class T, class... Ts>
+  struct is_any_type : std::bool_constant<(std::is_same_v<T, Ts> || ...)> {};
+
+  template <std::size_t LOW, std::size_t HIGH, typename RESULT,
+      typename VISITOR, typename VARIANT>
+  static inline RT_API_ATTRS RESULT Log2VisitHelper(
+      VISITOR &&visitor, std::size_t which, VARIANT &&u) {
+#if !defined(RT_DEVICE_COMPILATION)
+    constexpr bool isDevice{false};
+#else
+    constexpr bool isDevice{true};
+#endif
+    if constexpr (LOW == HIGH) {
+      if constexpr (!isDevice ||
+          is_any_type<
+              std::variant_alternative_t<LOW, std::decay_t<decltype(u)>>,
+              Fortran::common::reference_wrapper<
+                  InternalListIoStatementState<Direction::Output>>,
+              Fortran::common::reference_wrapper<
+                  InternalFormattedIoStatementState<Direction::Output>>>::
+              value) {
+        return visitor(std::get<LOW>(std::forward<VARIANT>(u)));
+      } else {
+        Terminator{__FILE__, __LINE__}.Crash(
+            "not implemented yet: IoStatementState variant %d\n",
+            static_cast<int>(LOW));
+      }
+    } else {
+      static constexpr std::size_t mid{(HIGH + LOW) / 2};
+      if (which <= mid) {
+        return Log2VisitHelper<LOW, mid, RESULT>(
+            std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u));
+      } else {
+        return Log2VisitHelper<(mid + 1), HIGH, RESULT>(
+            std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u));
+      }
+    }
+  }
+
+  template <typename VISITOR, typename VARIANT>
+  static inline RT_API_ATTRS auto visit(VISITOR &&visitor, VARIANT &&u)
+      -> decltype(visitor(std::get<0>(std::forward<VARIANT>(u)))) {
+    using Result = decltype(visitor(std::get<0>(std::forward<VARIANT>(u))));
+    static constexpr std::size_t high{
+        std::variant_size_v<std::decay_t<decltype(u)>> - 1};
+    return Log2VisitHelper<0, high, Result>(
+        std::forward<VISITOR>(visitor), u.index(), std::forward<VARIANT>(u));
+  }
+
   std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
       Fortran::common::reference_wrapper<CloseStatementState>,
       Fortran::common::reference_wrapper<NoopStatementState>,

@vzakhari vzakhari requested a review from klausler March 14, 2024 05:33
Copy link
Contributor

@klausler klausler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to have the types that are available on the device declare themselves so in their declarations via a member or (better) inherited trait.

The big variant union in IoStatementState could omit the host-only options when compiled for the device.

@vzakhari
Copy link
Contributor Author

It would be better to have the types that are available on the device declare themselves so in their declarations via a member or (better) inherited trait.

Like this. I will add the type trait for those types that are available on the device.

The big variant union in IoStatementState could omit the host-only options when compiled for the device.

I am not sure what you are suggesting. I cannot comment out (e.g. with RT_DEVICE_OUTPUT check) the members of std::variant<...> u_ union, because this will require commenting out all class declarations that try to explicitly access/emplace an instance of the commented out class into the union member. For example,

class NoUnitIoStatementState : public IoStatementBase {
...
  template <typename A>
  NoUnitIoStatementState(A &stmt, const char *sourceFile = nullptr,
      int sourceLine = 0, int badUnitNumber = -1)
      : IoStatementBase{sourceFile, sourceLine}, ioStatementState_{stmt},
        badUnitNumber_{badUnitNumber} {}
...
};

class NoopStatementState : public NoUnitIoStatementState {
public:
  NoopStatementState(
      const char *sourceFile = nullptr, int sourceLine = 0, int unitNumber = -1)
      : NoUnitIoStatementState{*this, sourceFile, sourceLine, unitNumber} {}
...
};

ioStatementState_{stmt} is invalid if the union does not have NoUnitIoStatementState variant. There are more examples like this, and there is a lot of code that will need to be commented out for the device compilation if I modify the union declaration. Maybe I misunderstood your comment, though.

@vzakhari vzakhari requested a review from klausler March 15, 2024 19:20
Created using spr 1.3.4
Created using spr 1.3.4
Created using spr 1.3.4
@vzakhari
Copy link
Contributor Author

Will use alternative solution.

@vzakhari vzakhari closed this Mar 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:runtime flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants