Skip to content

Commit

Permalink
[flang][runtime] Added Fortran::common::reference_wrapper for use on …
Browse files Browse the repository at this point in the history
…device.

This is a simplified implementation of std::reference_wrapper that can be used
in the offload builds for the device code. The methods are properly
marked with RT_API_ATTRS so that the device compilation succedes.

Reviewers: jeanPerier, klausler

Reviewed By: jeanPerier

Pull Request: #85178
  • Loading branch information
vzakhari committed Mar 15, 2024
1 parent 6e1959d commit d8f97c0
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 25 deletions.
114 changes: 114 additions & 0 deletions flang/include/flang/Common/reference-wrapper.h
@@ -0,0 +1,114 @@
//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// clang-format off
//
// Implementation of std::reference_wrapper borrowed from libcu++
// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1
// with modifications.
//
// The original source code is distributed under the Apache License v2.0
// with LLVM Exceptions.
//
// TODO: using libcu++ is the best option for CUDA, but there is a couple
// of issues:
// * The include paths need to be set up such that all STD header files
// are taken from libcu++.
// * cuda:: namespace need to be forced for all std:: references.
//
// clang-format on

#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H
#define FORTRAN_COMMON_REFERENCE_WRAPPER_H

#include "flang/Runtime/api-attrs.h"
#include <functional>
#include <type_traits>

#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \
(defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__)
#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1
#endif

namespace Fortran::common {

template <class _Tp>
using __remove_cvref_t = std::remove_cv_t<std::remove_reference_t<_Tp>>;
template <class _Tp, class _Up>
struct __is_same_uncvref
: std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {};

#if STD_REFERENCE_WRAPPER_UNSUPPORTED
template <class _Tp> class reference_wrapper {
public:
// types
typedef _Tp type;

private:
type *__f_;

static RT_API_ATTRS void __fun(_Tp &);
static void __fun(_Tp &&) = delete;

public:
template <class _Up,
class =
std::enable_if_t<!__is_same_uncvref<_Up, reference_wrapper>::value,
decltype(__fun(std::declval<_Up>()))>>
constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) {
type &__f = static_cast<_Up &&>(__u);
__f_ = std::addressof(__f);
}

// access
constexpr RT_API_ATTRS operator type &() const { return *__f_; }
constexpr RT_API_ATTRS type &get() const { return *__f_; }

// invoke
template <class... _ArgTypes>
constexpr RT_API_ATTRS typename std::invoke_result_t<type &, _ArgTypes...>
operator()(_ArgTypes &&...__args) const {
return std::invoke(get(), std::forward<_ArgTypes>(__args)...);
}
};

template <class _Tp> reference_wrapper(_Tp &) -> reference_wrapper<_Tp>;

template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) {
return reference_wrapper<_Tp>(__t);
}

template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(
reference_wrapper<_Tp> __t) {
return __t;
}

template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
const _Tp &__t) {
return reference_wrapper<const _Tp>(__t);
}

template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
reference_wrapper<_Tp> __t) {
return __t;
}

template <class _Tp> void ref(const _Tp &&) = delete;
template <class _Tp> void cref(const _Tp &&) = delete;
#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED
using std::cref;
using std::ref;
using std::reference_wrapper;
#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED

} // namespace Fortran::common

#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H
59 changes: 34 additions & 25 deletions flang/runtime/io-stmt.h
Expand Up @@ -17,6 +17,7 @@
#include "internal-unit.h"
#include "io-error.h"
#include "flang/Common/optional.h"
#include "flang/Common/reference-wrapper.h"
#include "flang/Common/visit.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/io-api.h"
Expand Down Expand Up @@ -210,39 +211,47 @@ class IoStatementState {
}

private:
std::variant<std::reference_wrapper<OpenStatementState>,
std::reference_wrapper<CloseStatementState>,
std::reference_wrapper<NoopStatementState>,
std::reference_wrapper<
std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
Fortran::common::reference_wrapper<CloseStatementState>,
Fortran::common::reference_wrapper<NoopStatementState>,
Fortran::common::reference_wrapper<
InternalFormattedIoStatementState<Direction::Output>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
InternalFormattedIoStatementState<Direction::Input>>,
std::reference_wrapper<InternalListIoStatementState<Direction::Output>>,
std::reference_wrapper<InternalListIoStatementState<Direction::Input>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
InternalListIoStatementState<Direction::Output>>,
Fortran::common::reference_wrapper<
InternalListIoStatementState<Direction::Input>>,
Fortran::common::reference_wrapper<
ExternalFormattedIoStatementState<Direction::Output>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
ExternalFormattedIoStatementState<Direction::Input>>,
std::reference_wrapper<ExternalListIoStatementState<Direction::Output>>,
std::reference_wrapper<ExternalListIoStatementState<Direction::Input>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
ExternalListIoStatementState<Direction::Output>>,
Fortran::common::reference_wrapper<
ExternalListIoStatementState<Direction::Input>>,
Fortran::common::reference_wrapper<
ExternalUnformattedIoStatementState<Direction::Output>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
ExternalUnformattedIoStatementState<Direction::Input>>,
std::reference_wrapper<ChildFormattedIoStatementState<Direction::Output>>,
std::reference_wrapper<ChildFormattedIoStatementState<Direction::Input>>,
std::reference_wrapper<ChildListIoStatementState<Direction::Output>>,
std::reference_wrapper<ChildListIoStatementState<Direction::Input>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
ChildFormattedIoStatementState<Direction::Output>>,
Fortran::common::reference_wrapper<
ChildFormattedIoStatementState<Direction::Input>>,
Fortran::common::reference_wrapper<
ChildListIoStatementState<Direction::Output>>,
Fortran::common::reference_wrapper<
ChildListIoStatementState<Direction::Input>>,
Fortran::common::reference_wrapper<
ChildUnformattedIoStatementState<Direction::Output>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
ChildUnformattedIoStatementState<Direction::Input>>,
std::reference_wrapper<InquireUnitState>,
std::reference_wrapper<InquireNoUnitState>,
std::reference_wrapper<InquireUnconnectedFileState>,
std::reference_wrapper<InquireIOLengthState>,
std::reference_wrapper<ExternalMiscIoStatementState>,
std::reference_wrapper<ErroneousIoStatementState>>
Fortran::common::reference_wrapper<InquireUnitState>,
Fortran::common::reference_wrapper<InquireNoUnitState>,
Fortran::common::reference_wrapper<InquireUnconnectedFileState>,
Fortran::common::reference_wrapper<InquireIOLengthState>,
Fortran::common::reference_wrapper<ExternalMiscIoStatementState>,
Fortran::common::reference_wrapper<ErroneousIoStatementState>>
u_;
};

Expand Down

0 comments on commit d8f97c0

Please sign in to comment.