From d8f97c067c9488f00bfaa17086c672d1fb7106d9 Mon Sep 17 00:00:00 2001 From: Slava Zakharin Date: Fri, 15 Mar 2024 14:41:47 -0700 Subject: [PATCH] [flang][runtime] Added Fortran::common::reference_wrapper for use on device. This is a simplified implementation of std::reference_wrapper that can be used in the offload builds for the device code. The methods are properly marked with RT_API_ATTRS so that the device compilation succedes. Reviewers: jeanPerier, klausler Reviewed By: jeanPerier Pull Request: https://github.com/llvm/llvm-project/pull/85178 --- .../include/flang/Common/reference-wrapper.h | 114 ++++++++++++++++++ flang/runtime/io-stmt.h | 59 +++++---- 2 files changed, 148 insertions(+), 25 deletions(-) create mode 100644 flang/include/flang/Common/reference-wrapper.h diff --git a/flang/include/flang/Common/reference-wrapper.h b/flang/include/flang/Common/reference-wrapper.h new file mode 100644 index 0000000000000..66f924662d961 --- /dev/null +++ b/flang/include/flang/Common/reference-wrapper.h @@ -0,0 +1,114 @@ +//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// clang-format off +// +// Implementation of std::reference_wrapper borrowed from libcu++ +// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1 +// with modifications. +// +// The original source code is distributed under the Apache License v2.0 +// with LLVM Exceptions. +// +// TODO: using libcu++ is the best option for CUDA, but there is a couple +// of issues: +// * The include paths need to be set up such that all STD header files +// are taken from libcu++. +// * cuda:: namespace need to be forced for all std:: references. +// +// clang-format on + +#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H +#define FORTRAN_COMMON_REFERENCE_WRAPPER_H + +#include "flang/Runtime/api-attrs.h" +#include +#include + +#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \ + (defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__) +#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1 +#endif + +namespace Fortran::common { + +template +using __remove_cvref_t = std::remove_cv_t>; +template +struct __is_same_uncvref + : std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {}; + +#if STD_REFERENCE_WRAPPER_UNSUPPORTED +template class reference_wrapper { +public: + // types + typedef _Tp type; + +private: + type *__f_; + + static RT_API_ATTRS void __fun(_Tp &); + static void __fun(_Tp &&) = delete; + +public: + template ::value, + decltype(__fun(std::declval<_Up>()))>> + constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) { + type &__f = static_cast<_Up &&>(__u); + __f_ = std::addressof(__f); + } + + // access + constexpr RT_API_ATTRS operator type &() const { return *__f_; } + constexpr RT_API_ATTRS type &get() const { return *__f_; } + + // invoke + template + constexpr RT_API_ATTRS typename std::invoke_result_t + operator()(_ArgTypes &&...__args) const { + return std::invoke(get(), std::forward<_ArgTypes>(__args)...); + } +}; + +template reference_wrapper(_Tp &) -> reference_wrapper<_Tp>; + +template +inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) { + return reference_wrapper<_Tp>(__t); +} + +template +inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref( + reference_wrapper<_Tp> __t) { + return __t; +} + +template +inline constexpr RT_API_ATTRS reference_wrapper cref( + const _Tp &__t) { + return reference_wrapper(__t); +} + +template +inline constexpr RT_API_ATTRS reference_wrapper cref( + reference_wrapper<_Tp> __t) { + return __t; +} + +template void ref(const _Tp &&) = delete; +template void cref(const _Tp &&) = delete; +#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED +using std::cref; +using std::ref; +using std::reference_wrapper; +#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED + +} // namespace Fortran::common + +#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h index 0477c32b3b53a..e00d54980aae5 100644 --- a/flang/runtime/io-stmt.h +++ b/flang/runtime/io-stmt.h @@ -17,6 +17,7 @@ #include "internal-unit.h" #include "io-error.h" #include "flang/Common/optional.h" +#include "flang/Common/reference-wrapper.h" #include "flang/Common/visit.h" #include "flang/Runtime/descriptor.h" #include "flang/Runtime/io-api.h" @@ -210,39 +211,47 @@ class IoStatementState { } private: - std::variant, - std::reference_wrapper, - std::reference_wrapper, - std::reference_wrapper< + std::variant, + Fortran::common::reference_wrapper, + Fortran::common::reference_wrapper, + Fortran::common::reference_wrapper< InternalFormattedIoStatementState>, - std::reference_wrapper< + Fortran::common::reference_wrapper< InternalFormattedIoStatementState>, - std::reference_wrapper>, - std::reference_wrapper>, - std::reference_wrapper< + Fortran::common::reference_wrapper< + InternalListIoStatementState>, + Fortran::common::reference_wrapper< + InternalListIoStatementState>, + Fortran::common::reference_wrapper< ExternalFormattedIoStatementState>, - std::reference_wrapper< + Fortran::common::reference_wrapper< ExternalFormattedIoStatementState>, - std::reference_wrapper>, - std::reference_wrapper>, - std::reference_wrapper< + Fortran::common::reference_wrapper< + ExternalListIoStatementState>, + Fortran::common::reference_wrapper< + ExternalListIoStatementState>, + Fortran::common::reference_wrapper< ExternalUnformattedIoStatementState>, - std::reference_wrapper< + Fortran::common::reference_wrapper< ExternalUnformattedIoStatementState>, - std::reference_wrapper>, - std::reference_wrapper>, - std::reference_wrapper>, - std::reference_wrapper>, - std::reference_wrapper< + Fortran::common::reference_wrapper< + ChildFormattedIoStatementState>, + Fortran::common::reference_wrapper< + ChildFormattedIoStatementState>, + Fortran::common::reference_wrapper< + ChildListIoStatementState>, + Fortran::common::reference_wrapper< + ChildListIoStatementState>, + Fortran::common::reference_wrapper< ChildUnformattedIoStatementState>, - std::reference_wrapper< + Fortran::common::reference_wrapper< ChildUnformattedIoStatementState>, - std::reference_wrapper, - std::reference_wrapper, - std::reference_wrapper, - std::reference_wrapper, - std::reference_wrapper, - std::reference_wrapper> + Fortran::common::reference_wrapper, + Fortran::common::reference_wrapper, + Fortran::common::reference_wrapper, + Fortran::common::reference_wrapper, + Fortran::common::reference_wrapper, + Fortran::common::reference_wrapper> u_; };