Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[flang][runtime] Added Fortran::common::reference_wrapper for use on …
…device. This is a simplified implementation of std::reference_wrapper that can be used in the offload builds for the device code. The methods are properly marked with RT_API_ATTRS so that the device compilation succedes. Reviewers: jeanPerier, klausler Reviewed By: jeanPerier Pull Request: #85178
- Loading branch information
Showing
2 changed files
with
148 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// clang-format off | ||
// | ||
// Implementation of std::reference_wrapper borrowed from libcu++ | ||
// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1 | ||
// with modifications. | ||
// | ||
// The original source code is distributed under the Apache License v2.0 | ||
// with LLVM Exceptions. | ||
// | ||
// TODO: using libcu++ is the best option for CUDA, but there is a couple | ||
// of issues: | ||
// * The include paths need to be set up such that all STD header files | ||
// are taken from libcu++. | ||
// * cuda:: namespace need to be forced for all std:: references. | ||
// | ||
// clang-format on | ||
|
||
#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H | ||
#define FORTRAN_COMMON_REFERENCE_WRAPPER_H | ||
|
||
#include "flang/Runtime/api-attrs.h" | ||
#include <functional> | ||
#include <type_traits> | ||
|
||
#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \ | ||
(defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__) | ||
#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1 | ||
#endif | ||
|
||
namespace Fortran::common { | ||
|
||
template <class _Tp> | ||
using __remove_cvref_t = std::remove_cv_t<std::remove_reference_t<_Tp>>; | ||
template <class _Tp, class _Up> | ||
struct __is_same_uncvref | ||
: std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {}; | ||
|
||
#if STD_REFERENCE_WRAPPER_UNSUPPORTED | ||
template <class _Tp> class reference_wrapper { | ||
public: | ||
// types | ||
typedef _Tp type; | ||
|
||
private: | ||
type *__f_; | ||
|
||
static RT_API_ATTRS void __fun(_Tp &); | ||
static void __fun(_Tp &&) = delete; | ||
|
||
public: | ||
template <class _Up, | ||
class = | ||
std::enable_if_t<!__is_same_uncvref<_Up, reference_wrapper>::value, | ||
decltype(__fun(std::declval<_Up>()))>> | ||
constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) { | ||
type &__f = static_cast<_Up &&>(__u); | ||
__f_ = std::addressof(__f); | ||
} | ||
|
||
// access | ||
constexpr RT_API_ATTRS operator type &() const { return *__f_; } | ||
constexpr RT_API_ATTRS type &get() const { return *__f_; } | ||
|
||
// invoke | ||
template <class... _ArgTypes> | ||
constexpr RT_API_ATTRS typename std::invoke_result_t<type &, _ArgTypes...> | ||
operator()(_ArgTypes &&...__args) const { | ||
return std::invoke(get(), std::forward<_ArgTypes>(__args)...); | ||
} | ||
}; | ||
|
||
template <class _Tp> reference_wrapper(_Tp &) -> reference_wrapper<_Tp>; | ||
|
||
template <class _Tp> | ||
inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) { | ||
return reference_wrapper<_Tp>(__t); | ||
} | ||
|
||
template <class _Tp> | ||
inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref( | ||
reference_wrapper<_Tp> __t) { | ||
return __t; | ||
} | ||
|
||
template <class _Tp> | ||
inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref( | ||
const _Tp &__t) { | ||
return reference_wrapper<const _Tp>(__t); | ||
} | ||
|
||
template <class _Tp> | ||
inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref( | ||
reference_wrapper<_Tp> __t) { | ||
return __t; | ||
} | ||
|
||
template <class _Tp> void ref(const _Tp &&) = delete; | ||
template <class _Tp> void cref(const _Tp &&) = delete; | ||
#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED | ||
using std::cref; | ||
using std::ref; | ||
using std::reference_wrapper; | ||
#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED | ||
|
||
} // namespace Fortran::common | ||
|
||
#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters