-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][runtime] Added custom visitor for IoStatementState variants. #85179
Conversation
Created using spr 1.3.4
@llvm/pr-subscribers-flang-runtime Author: Slava Zakharin (vzakhari) ChangesThe visitor only allows Internal.*IoStatementState variants to be visited. I had problems parameterizing the Fortran::common visitor to limit Full diff: https://github.com/llvm/llvm-project/pull/85179.diff 2 Files Affected:
diff --git a/flang/runtime/io-stmt.cpp b/flang/runtime/io-stmt.cpp
index 075d7b5ae518a4..efefbc5e1a1c08 100644
--- a/flang/runtime/io-stmt.cpp
+++ b/flang/runtime/io-stmt.cpp
@@ -467,69 +467,66 @@ int ExternalFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
}
Fortran::common::optional<DataEdit> IoStatementState::GetNextDataEdit(int n) {
- return common::visit(
- [&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_);
+ return visit([&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_);
}
bool IoStatementState::Emit(
const char *data, std::size_t bytes, std::size_t elementBytes) {
- return common::visit(
+ return visit(
[=](auto &x) { return x.get().Emit(data, bytes, elementBytes); }, u_);
}
bool IoStatementState::Receive(
char *data, std::size_t n, std::size_t elementBytes) {
- return common::visit(
+ return visit(
[=](auto &x) { return x.get().Receive(data, n, elementBytes); }, u_);
}
std::size_t IoStatementState::GetNextInputBytes(const char *&p) {
- return common::visit(
- [&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
+ return visit([&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
}
bool IoStatementState::AdvanceRecord(int n) {
- return common::visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_);
+ return visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_);
}
void IoStatementState::BackspaceRecord() {
- common::visit([](auto &x) { x.get().BackspaceRecord(); }, u_);
+ visit([](auto &x) { x.get().BackspaceRecord(); }, u_);
}
void IoStatementState::HandleRelativePosition(std::int64_t n) {
- common::visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_);
+ visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_);
}
void IoStatementState::HandleAbsolutePosition(std::int64_t n) {
- common::visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
+ visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
}
void IoStatementState::CompleteOperation() {
- common::visit([](auto &x) { x.get().CompleteOperation(); }, u_);
+ visit([](auto &x) { x.get().CompleteOperation(); }, u_);
}
int IoStatementState::EndIoStatement() {
- return common::visit([](auto &x) { return x.get().EndIoStatement(); }, u_);
+ return visit([](auto &x) { return x.get().EndIoStatement(); }, u_);
}
ConnectionState &IoStatementState::GetConnectionState() {
- return common::visit(
+ return visit(
[](auto &x) -> ConnectionState & { return x.get().GetConnectionState(); },
u_);
}
MutableModes &IoStatementState::mutableModes() {
- return common::visit(
+ return visit(
[](auto &x) -> MutableModes & { return x.get().mutableModes(); }, u_);
}
bool IoStatementState::BeginReadingRecord() {
- return common::visit(
- [](auto &x) { return x.get().BeginReadingRecord(); }, u_);
+ return visit([](auto &x) { return x.get().BeginReadingRecord(); }, u_);
}
IoErrorHandler &IoStatementState::GetIoErrorHandler() const {
- return common::visit(
+ return visit(
[](auto &x) -> IoErrorHandler & {
return static_cast<IoErrorHandler &>(x.get());
},
@@ -537,8 +534,7 @@ IoErrorHandler &IoStatementState::GetIoErrorHandler() const {
}
ExternalFileUnit *IoStatementState::GetExternalFileUnit() const {
- return common::visit(
- [](auto &x) { return x.get().GetExternalFileUnit(); }, u_);
+ return visit([](auto &x) { return x.get().GetExternalFileUnit(); }, u_);
}
Fortran::common::optional<char32_t> IoStatementState::GetCurrentChar(
@@ -664,28 +660,25 @@ bool IoStatementState::CheckForEndOfRecord(std::size_t afterReading) {
bool IoStatementState::Inquire(
InquiryKeywordHash inquiry, char *out, std::size_t chars) {
- return common::visit(
+ return visit(
[&](auto &x) { return x.get().Inquire(inquiry, out, chars); }, u_);
}
bool IoStatementState::Inquire(InquiryKeywordHash inquiry, bool &out) {
- return common::visit(
- [&](auto &x) { return x.get().Inquire(inquiry, out); }, u_);
+ return visit([&](auto &x) { return x.get().Inquire(inquiry, out); }, u_);
}
bool IoStatementState::Inquire(
InquiryKeywordHash inquiry, std::int64_t id, bool &out) {
- return common::visit(
- [&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_);
+ return visit([&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_);
}
bool IoStatementState::Inquire(InquiryKeywordHash inquiry, std::int64_t &n) {
- return common::visit(
- [&](auto &x) { return x.get().Inquire(inquiry, n); }, u_);
+ return visit([&](auto &x) { return x.get().Inquire(inquiry, n); }, u_);
}
std::int64_t IoStatementState::InquirePos() {
- return common::visit([&](auto &x) { return x.get().InquirePos(); }, u_);
+ return visit([&](auto &x) { return x.get().InquirePos(); }, u_);
}
void IoStatementState::GotChar(int n) {
diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index e00d54980aae59..7fecf4d9e41754 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -18,7 +18,6 @@
#include "io-error.h"
#include "flang/Common/optional.h"
#include "flang/Common/reference-wrapper.h"
-#include "flang/Common/visit.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/io-api.h"
#include <functional>
@@ -113,7 +112,7 @@ class IoStatementState {
// N.B.: this also works with base classes
template <typename A> A *get_if() const {
- return common::visit(
+ return visit(
[](auto &x) -> A * {
if constexpr (std::is_convertible_v<decltype(x.get()), A &>) {
return &x.get();
@@ -211,6 +210,61 @@ class IoStatementState {
}
private:
+ // Define special visitor for the variants of IoStatementState.
+ // During the device code compilation the visitor only allows
+ // visiting those variants that are supported on the device.
+ // In particular, only the internal IO variants are supported.
+ // TODO: parameterize Fortran::common::log2visit instead of
+ // creating a copy here.
+ template <class T, class... Ts>
+ struct is_any_type : std::bool_constant<(std::is_same_v<T, Ts> || ...)> {};
+
+ template <std::size_t LOW, std::size_t HIGH, typename RESULT,
+ typename VISITOR, typename VARIANT>
+ static inline RT_API_ATTRS RESULT Log2VisitHelper(
+ VISITOR &&visitor, std::size_t which, VARIANT &&u) {
+#if !defined(RT_DEVICE_COMPILATION)
+ constexpr bool isDevice{false};
+#else
+ constexpr bool isDevice{true};
+#endif
+ if constexpr (LOW == HIGH) {
+ if constexpr (!isDevice ||
+ is_any_type<
+ std::variant_alternative_t<LOW, std::decay_t<decltype(u)>>,
+ Fortran::common::reference_wrapper<
+ InternalListIoStatementState<Direction::Output>>,
+ Fortran::common::reference_wrapper<
+ InternalFormattedIoStatementState<Direction::Output>>>::
+ value) {
+ return visitor(std::get<LOW>(std::forward<VARIANT>(u)));
+ } else {
+ Terminator{__FILE__, __LINE__}.Crash(
+ "not implemented yet: IoStatementState variant %d\n",
+ static_cast<int>(LOW));
+ }
+ } else {
+ static constexpr std::size_t mid{(HIGH + LOW) / 2};
+ if (which <= mid) {
+ return Log2VisitHelper<LOW, mid, RESULT>(
+ std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u));
+ } else {
+ return Log2VisitHelper<(mid + 1), HIGH, RESULT>(
+ std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u));
+ }
+ }
+ }
+
+ template <typename VISITOR, typename VARIANT>
+ static inline RT_API_ATTRS auto visit(VISITOR &&visitor, VARIANT &&u)
+ -> decltype(visitor(std::get<0>(std::forward<VARIANT>(u)))) {
+ using Result = decltype(visitor(std::get<0>(std::forward<VARIANT>(u))));
+ static constexpr std::size_t high{
+ std::variant_size_v<std::decay_t<decltype(u)>> - 1};
+ return Log2VisitHelper<0, high, Result>(
+ std::forward<VISITOR>(visitor), u.index(), std::forward<VARIANT>(u));
+ }
+
std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
Fortran::common::reference_wrapper<CloseStatementState>,
Fortran::common::reference_wrapper<NoopStatementState>,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better to have the types that are available on the device declare themselves so in their declarations via a member or (better) inherited trait.
The big variant union in IoStatementState
could omit the host-only options when compiled for the device.
Like this. I will add the type trait for those types that are available on the device.
I am not sure what you are suggesting. I cannot comment out (e.g. with
|
Created using spr 1.3.4
Will use alternative solution. |
The visitor only allows Internal.*IoStatementState variants to be visited.
In case another variant is met a runtime error is produced.
During the device compilation the other variants' classes are not referenced,
which, for example, helps to avoid warnings about host only
methods referenced in device code.
I had problems parameterizing the Fortran::common visitor to limit
the allowed variants, but I can give it another try if creating
a copy looks inappropriate.