12 changes: 7 additions & 5 deletions flang/lib/Lower/OpenMP/DataSharingProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef FORTRAN_LOWER_DATASHARINGPROCESSOR_H
#define FORTRAN_LOWER_DATASHARINGPROCESSOR_H

#include "Clauses.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/OpenMP.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
Expand Down Expand Up @@ -52,7 +53,7 @@ class DataSharingProcessor {
llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInParentRegions;
Fortran::lower::AbstractConverter &converter;
fir::FirOpBuilder &firOpBuilder;
const Fortran::parser::OmpClauseList &opClauseList;
omp::List<omp::Clause> clauses;
Fortran::lower::pft::Evaluation &eval;
bool useDelayedPrivatization;
Fortran::lower::SymMap *symTable;
Expand All @@ -61,7 +62,7 @@ class DataSharingProcessor {
bool needBarrier();
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
void collectOmpObjectListSymbol(
const Fortran::parser::OmpObjectList &ompObjectList,
const omp::ObjectList &objects,
llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet);
void collectSymbolsForPrivatization();
void insertBarrier();
Expand All @@ -81,14 +82,15 @@ class DataSharingProcessor {

public:
DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
const Fortran::parser::OmpClauseList &opClauseList,
Fortran::lower::pft::Evaluation &eval,
bool useDelayedPrivatization = false,
Fortran::lower::SymMap *symTable = nullptr)
: hasLastPrivateOp(false), converter(converter),
firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
eval(eval), useDelayedPrivatization(useDelayedPrivatization),
symTable(symTable) {}
firOpBuilder(converter.getFirOpBuilder()),
clauses(omp::makeList(opClauseList, semaCtx)), eval(eval),
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable) {}

// Privatisation is split into two steps.
// Step1 performs cloning of all privatisation clauses and copying for
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) {
std::optional<DataSharingProcessor> tempDsp;
if (privatize) {
if (!info.dsp) {
tempDsp.emplace(info.converter, *info.clauses, info.eval);
tempDsp.emplace(info.converter, info.semaCtx, *info.clauses, info.eval);
tempDsp->processStep1();
}
}
Expand Down Expand Up @@ -627,7 +627,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
}

bool privatize = !outerCombined;
DataSharingProcessor dsp(converter, clauseList, eval,
DataSharingProcessor dsp(converter, semaCtx, clauseList, eval,
/*useDelayedPrivatization=*/true, &symTable);

if (privatize)
Expand Down Expand Up @@ -1575,7 +1575,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &loopOpClauseList,
mlir::Location loc) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
DataSharingProcessor dsp(converter, loopOpClauseList, eval);
DataSharingProcessor dsp(converter, semaCtx, loopOpClauseList, eval);
dsp.processStep1();

Fortran::lower::StatementContext stmtCtx;
Expand Down Expand Up @@ -1634,7 +1634,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList *endClauseList,
mlir::Location loc) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
DataSharingProcessor dsp(converter, beginClauseList, eval);
DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval);
dsp.processStep1();

Fortran::lower::StatementContext stmtCtx;
Expand Down
45 changes: 25 additions & 20 deletions flang/runtime/io-stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,75 +467,78 @@ int ExternalFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
}

Fortran::common::optional<DataEdit> IoStatementState::GetNextDataEdit(int n) {
return visitIo(
return common::visit(
[&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_);
}

bool IoStatementState::Emit(
const char *data, std::size_t bytes, std::size_t elementBytes) {
return visitIo(
return common::visit(
[=](auto &x) { return x.get().Emit(data, bytes, elementBytes); }, u_);
}

bool IoStatementState::Receive(
char *data, std::size_t n, std::size_t elementBytes) {
return visitIo(
return common::visit(
[=](auto &x) { return x.get().Receive(data, n, elementBytes); }, u_);
}

std::size_t IoStatementState::GetNextInputBytes(const char *&p) {
return visitIo([&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
return common::visit(
[&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
}

bool IoStatementState::AdvanceRecord(int n) {
return visitIo([=](auto &x) { return x.get().AdvanceRecord(n); }, u_);
return common::visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_);
}

void IoStatementState::BackspaceRecord() {
visitIo([](auto &x) { x.get().BackspaceRecord(); }, u_);
common::visit([](auto &x) { x.get().BackspaceRecord(); }, u_);
}

void IoStatementState::HandleRelativePosition(std::int64_t n) {
visitIo([=](auto &x) { x.get().HandleRelativePosition(n); }, u_);
common::visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_);
}

void IoStatementState::HandleAbsolutePosition(std::int64_t n) {
visitIo([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
common::visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
}

void IoStatementState::CompleteOperation() {
visitIo([](auto &x) { x.get().CompleteOperation(); }, u_);
common::visit([](auto &x) { x.get().CompleteOperation(); }, u_);
}

int IoStatementState::EndIoStatement() {
return visitIo([](auto &x) { return x.get().EndIoStatement(); }, u_);
return common::visit([](auto &x) { return x.get().EndIoStatement(); }, u_);
}

ConnectionState &IoStatementState::GetConnectionState() {
return visitIo(
return common::visit(
[](auto &x) -> ConnectionState & { return x.get().GetConnectionState(); },
u_);
}

MutableModes &IoStatementState::mutableModes() {
return visitIo(
return common::visit(
[](auto &x) -> MutableModes & { return x.get().mutableModes(); }, u_);
}

bool IoStatementState::BeginReadingRecord() {
return visitIo([](auto &x) { return x.get().BeginReadingRecord(); }, u_);
return common::visit(
[](auto &x) { return x.get().BeginReadingRecord(); }, u_);
}

IoErrorHandler &IoStatementState::GetIoErrorHandler() const {
return visitIo(
return common::visit(
[](auto &x) -> IoErrorHandler & {
return static_cast<IoErrorHandler &>(x.get());
},
u_);
}

ExternalFileUnit *IoStatementState::GetExternalFileUnit() const {
return visitIo([](auto &x) { return x.get().GetExternalFileUnit(); }, u_);
return common::visit(
[](auto &x) { return x.get().GetExternalFileUnit(); }, u_);
}

Fortran::common::optional<char32_t> IoStatementState::GetCurrentChar(
Expand Down Expand Up @@ -661,26 +664,28 @@ bool IoStatementState::CheckForEndOfRecord(std::size_t afterReading) {

bool IoStatementState::Inquire(
InquiryKeywordHash inquiry, char *out, std::size_t chars) {
return visitIo(
return common::visit(
[&](auto &x) { return x.get().Inquire(inquiry, out, chars); }, u_);
}

bool IoStatementState::Inquire(InquiryKeywordHash inquiry, bool &out) {
return visitIo([&](auto &x) { return x.get().Inquire(inquiry, out); }, u_);
return common::visit(
[&](auto &x) { return x.get().Inquire(inquiry, out); }, u_);
}

bool IoStatementState::Inquire(
InquiryKeywordHash inquiry, std::int64_t id, bool &out) {
return visitIo(
return common::visit(
[&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_);
}

bool IoStatementState::Inquire(InquiryKeywordHash inquiry, std::int64_t &n) {
return visitIo([&](auto &x) { return x.get().Inquire(inquiry, n); }, u_);
return common::visit(
[&](auto &x) { return x.get().Inquire(inquiry, n); }, u_);
}

std::int64_t IoStatementState::InquirePos() {
return visitIo([&](auto &x) { return x.get().InquirePos(); }, u_);
return common::visit([&](auto &x) { return x.get().InquirePos(); }, u_);
}

void IoStatementState::GotChar(int n) {
Expand Down
92 changes: 5 additions & 87 deletions flang/runtime/io-stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,15 @@
#include "format.h"
#include "internal-unit.h"
#include "io-error.h"
#include "flang/Common/idioms.h"
#include "flang/Common/optional.h"
#include "flang/Common/reference-wrapper.h"
#include "flang/Common/visit.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/io-api.h"
#include <functional>
#include <type_traits>
#include <variant>

// I/O statement state classes that may be instantiated during execution
// on an offload device have this trait:
CLASS_TRAIT(AvailableOnDevice)

namespace Fortran::runtime::io {

class ExternalFileUnit;
Expand Down Expand Up @@ -56,23 +52,15 @@ template <Direction> class ChildListIoStatementState;
template <Direction> class ChildUnformattedIoStatementState;

struct InputStatementState {};
struct OutputStatementState {
using AvailableOnDevice = std::true_type;
};
struct OutputStatementState {};
template <Direction D>
using IoDirectionState = std::conditional_t<D == Direction::Input,
InputStatementState, OutputStatementState>;

// Common state for all kinds of formatted I/O
template <Direction D> class FormattedIoStatementState {};
template <> class FormattedIoStatementState<Direction::Output> {
public:
using AvailableOnDevice = std::true_type;
};

template <> class FormattedIoStatementState<Direction::Input> {
public:
using AvailableOnDevice = std::true_type;
std::size_t GetEditDescriptorChars() const;
void GotChar(int);

Expand Down Expand Up @@ -125,19 +113,10 @@ class IoStatementState {

// N.B.: this also works with base classes
template <typename A> A *get_if() const {
[[maybe_unused]] std::size_t index{u_.index()};
return Fortran::common::visit(
[=](auto &x) -> A * {
return common::visit(
[](auto &x) -> A * {
if constexpr (std::is_convertible_v<decltype(x.get()), A &>) {
#if defined(RT_DEVICE_COMPILATION)
if constexpr (!AvailableOnDevice<std::decay_t<A>>) {
terminateOnDevice(__FILE__, __LINE__, index);
} else {
#endif
return &x.get();
#if defined(RT_DEVICE_COMPILATION)
}
#endif
return &x.get();
}
return nullptr;
},
Expand Down Expand Up @@ -232,40 +211,6 @@ class IoStatementState {
}

private:
#if RT_DEVICE_COMPILATION
static RT_API_ATTRS void terminateOnDevice(
const char *sourceFile, int sourceLine, std::size_t index) {
// %zd is not supported by device printf.
Terminator{sourceFile, sourceLine}.Crash(
"Unexpected IO statement variant (index %d) during device execution",
static_cast<int>(index));
}
#endif

// Define special visitor for the variants of IoStatementState.
// During the device code compilation the visitor only allows
// visiting those variants that have AvailableOnDevice trait
// are supported on the device.
template <typename VISITOR, typename VARIANT>
static inline RT_API_ATTRS auto visitIo(VISITOR &&visitor, VARIANT &&u)
-> decltype(visitor(std::get<0>(std::forward<VARIANT>(u)))) {
using Result = decltype(visitor(std::get<0>(std::forward<VARIANT>(u))));
[[maybe_unused]] std::size_t index{u.index()};
return Fortran::common::visit(
[&](auto &x) -> Result {
#if defined(RT_DEVICE_COMPILATION)
if constexpr (!AvailableOnDevice<std::decay_t<decltype(x.get())>>) {
terminateOnDevice(__FILE__, __LINE__, index);
} else {
#endif
return visitor(x);
#if defined(RT_DEVICE_COMPILATION)
}
#endif
},
std::forward<VARIANT>(u));
}

std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
Fortran::common::reference_wrapper<CloseStatementState>,
Fortran::common::reference_wrapper<NoopStatementState>,
Expand Down Expand Up @@ -351,7 +296,6 @@ template <>
class ListDirectedStatementState<Direction::Output>
: public FormattedIoStatementState<Direction::Output> {
public:
using AvailableOnDevice = std::true_type;
bool EmitLeadingSpaceOrAdvance(
IoStatementState &, std::size_t = 1, bool isCharacter = false);
Fortran::common::optional<DataEdit> GetNextDataEdit(
Expand All @@ -370,7 +314,6 @@ template <>
class ListDirectedStatementState<Direction::Input>
: public FormattedIoStatementState<Direction::Input> {
public:
using AvailableOnDevice = std::false_type;
bool inNamelistSequence() const { return inNamelistSequence_; }
int EndIoStatement();

Expand Down Expand Up @@ -408,8 +351,6 @@ template <Direction DIR>
class InternalIoStatementState : public IoStatementBase,
public IoDirectionState<DIR> {
public:
using AvailableOnDevice = std::conditional_t<DIR == Direction::Output,
std::true_type, std::false_type>;
using Buffer =
std::conditional_t<DIR == Direction::Input, const char *, char *>;
InternalIoStatementState(Buffer, std::size_t,
Expand Down Expand Up @@ -438,8 +379,6 @@ class InternalFormattedIoStatementState
: public InternalIoStatementState<DIR>,
public FormattedIoStatementState<DIR> {
public:
using AvailableOnDevice = std::conditional_t<DIR == Direction::Output,
std::true_type, std::false_type>;
using CharType = CHAR;
using typename InternalIoStatementState<DIR>::Buffer;
InternalFormattedIoStatementState(Buffer internal, std::size_t internalLength,
Expand Down Expand Up @@ -468,8 +407,6 @@ template <Direction DIR>
class InternalListIoStatementState : public InternalIoStatementState<DIR>,
public ListDirectedStatementState<DIR> {
public:
using AvailableOnDevice = std::conditional_t<DIR == Direction::Output,
std::true_type, std::false_type>;
using typename InternalIoStatementState<DIR>::Buffer;
InternalListIoStatementState(Buffer internal, std::size_t internalLength,
const char *sourceFile = nullptr, int sourceLine = 0);
Expand All @@ -487,7 +424,6 @@ class InternalListIoStatementState : public InternalIoStatementState<DIR>,

class ExternalIoStatementBase : public IoStatementBase {
public:
using AvailableOnDevice = std::false_type;
ExternalIoStatementBase(
ExternalFileUnit &, const char *sourceFile = nullptr, int sourceLine = 0);
ExternalFileUnit &unit() { return unit_; }
Expand All @@ -508,7 +444,6 @@ template <Direction DIR>
class ExternalIoStatementState : public ExternalIoStatementBase,
public IoDirectionState<DIR> {
public:
using AvailableOnDevice = std::false_type;
ExternalIoStatementState(
ExternalFileUnit &, const char *sourceFile = nullptr, int sourceLine = 0);
MutableModes &mutableModes() { return mutableModes_; }
Expand All @@ -535,7 +470,6 @@ class ExternalFormattedIoStatementState
: public ExternalIoStatementState<DIR>,
public FormattedIoStatementState<DIR> {
public:
using AvailableOnDevice = std::false_type;
using CharType = CHAR;
ExternalFormattedIoStatementState(ExternalFileUnit &, const CharType *format,
std::size_t formatLength, const Descriptor *formatDescriptor = nullptr,
Expand All @@ -555,7 +489,6 @@ template <Direction DIR>
class ExternalListIoStatementState : public ExternalIoStatementState<DIR>,
public ListDirectedStatementState<DIR> {
public:
using AvailableOnDevice = std::false_type;
using ExternalIoStatementState<DIR>::ExternalIoStatementState;
using ListDirectedStatementState<DIR>::GetNextDataEdit;
int EndIoStatement();
Expand All @@ -565,7 +498,6 @@ template <Direction DIR>
class ExternalUnformattedIoStatementState
: public ExternalIoStatementState<DIR> {
public:
using AvailableOnDevice = std::false_type;
using ExternalIoStatementState<DIR>::ExternalIoStatementState;
bool Receive(char *, std::size_t, std::size_t elementBytes = 0);
};
Expand All @@ -574,7 +506,6 @@ template <Direction DIR>
class ChildIoStatementState : public IoStatementBase,
public IoDirectionState<DIR> {
public:
using AvailableOnDevice = std::false_type;
ChildIoStatementState(
ChildIo &, const char *sourceFile = nullptr, int sourceLine = 0);
ChildIo &child() { return child_; }
Expand All @@ -595,7 +526,6 @@ template <Direction DIR, typename CHAR>
class ChildFormattedIoStatementState : public ChildIoStatementState<DIR>,
public FormattedIoStatementState<DIR> {
public:
using AvailableOnDevice = std::false_type;
using CharType = CHAR;
ChildFormattedIoStatementState(ChildIo &, const CharType *format,
std::size_t formatLength, const Descriptor *formatDescriptor = nullptr,
Expand All @@ -618,7 +548,6 @@ template <Direction DIR>
class ChildListIoStatementState : public ChildIoStatementState<DIR>,
public ListDirectedStatementState<DIR> {
public:
using AvailableOnDevice = std::false_type;
using ChildIoStatementState<DIR>::ChildIoStatementState;
using ListDirectedStatementState<DIR>::GetNextDataEdit;
int EndIoStatement();
Expand All @@ -627,15 +556,13 @@ class ChildListIoStatementState : public ChildIoStatementState<DIR>,
template <Direction DIR>
class ChildUnformattedIoStatementState : public ChildIoStatementState<DIR> {
public:
using AvailableOnDevice = std::false_type;
using ChildIoStatementState<DIR>::ChildIoStatementState;
bool Receive(char *, std::size_t, std::size_t elementBytes = 0);
};

// OPEN
class OpenStatementState : public ExternalIoStatementBase {
public:
using AvailableOnDevice = std::false_type;
OpenStatementState(ExternalFileUnit &unit, bool wasExtant, bool isNewUnit,
const char *sourceFile = nullptr, int sourceLine = 0)
: ExternalIoStatementBase{unit, sourceFile, sourceLine},
Expand Down Expand Up @@ -667,7 +594,6 @@ class OpenStatementState : public ExternalIoStatementBase {

class CloseStatementState : public ExternalIoStatementBase {
public:
using AvailableOnDevice = std::false_type;
CloseStatementState(ExternalFileUnit &unit, const char *sourceFile = nullptr,
int sourceLine = 0)
: ExternalIoStatementBase{unit, sourceFile, sourceLine} {}
Expand All @@ -682,7 +608,6 @@ class CloseStatementState : public ExternalIoStatementBase {
// and recoverable BACKSPACE(bad unit)
class NoUnitIoStatementState : public IoStatementBase {
public:
using AvailableOnDevice = std::false_type;
IoStatementState &ioStatementState() { return ioStatementState_; }
MutableModes &mutableModes() { return connection_.modes; }
ConnectionState &GetConnectionState() { return connection_; }
Expand All @@ -705,7 +630,6 @@ class NoUnitIoStatementState : public IoStatementBase {

class NoopStatementState : public NoUnitIoStatementState {
public:
using AvailableOnDevice = std::false_type;
NoopStatementState(
const char *sourceFile = nullptr, int sourceLine = 0, int unitNumber = -1)
: NoUnitIoStatementState{*this, sourceFile, sourceLine, unitNumber} {}
Expand Down Expand Up @@ -750,7 +674,6 @@ extern template class FormatControl<

class InquireUnitState : public ExternalIoStatementBase {
public:
using AvailableOnDevice = std::false_type;
InquireUnitState(ExternalFileUnit &unit, const char *sourceFile = nullptr,
int sourceLine = 0);
bool Inquire(InquiryKeywordHash, char *, std::size_t);
Expand All @@ -761,7 +684,6 @@ class InquireUnitState : public ExternalIoStatementBase {

class InquireNoUnitState : public NoUnitIoStatementState {
public:
using AvailableOnDevice = std::false_type;
InquireNoUnitState(const char *sourceFile = nullptr, int sourceLine = 0,
int badUnitNumber = -1);
bool Inquire(InquiryKeywordHash, char *, std::size_t);
Expand All @@ -772,7 +694,6 @@ class InquireNoUnitState : public NoUnitIoStatementState {

class InquireUnconnectedFileState : public NoUnitIoStatementState {
public:
using AvailableOnDevice = std::false_type;
InquireUnconnectedFileState(OwningPtr<char> &&path,
const char *sourceFile = nullptr, int sourceLine = 0);
bool Inquire(InquiryKeywordHash, char *, std::size_t);
Expand All @@ -787,7 +708,6 @@ class InquireUnconnectedFileState : public NoUnitIoStatementState {
class InquireIOLengthState : public NoUnitIoStatementState,
public OutputStatementState {
public:
using AvailableOnDevice = std::false_type;
InquireIOLengthState(const char *sourceFile = nullptr, int sourceLine = 0);
std::size_t bytes() const { return bytes_; }
bool Emit(const char *, std::size_t bytes, std::size_t elementBytes = 0);
Expand All @@ -798,7 +718,6 @@ class InquireIOLengthState : public NoUnitIoStatementState,

class ExternalMiscIoStatementState : public ExternalIoStatementBase {
public:
using AvailableOnDevice = std::false_type;
enum Which { Flush, Backspace, Endfile, Rewind, Wait };
ExternalMiscIoStatementState(ExternalFileUnit &unit, Which which,
const char *sourceFile = nullptr, int sourceLine = 0)
Expand All @@ -812,7 +731,6 @@ class ExternalMiscIoStatementState : public ExternalIoStatementBase {

class ErroneousIoStatementState : public IoStatementBase {
public:
using AvailableOnDevice = std::false_type;
explicit ErroneousIoStatementState(Iostat iostat,
ExternalFileUnit *unit = nullptr, const char *sourceFile = nullptr,
int sourceLine = 0)
Expand Down
1 change: 1 addition & 0 deletions llvm/utils/gn/secondary/clang/lib/Headers/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ copy("Headers") {
"ppc_wrappers/xmmintrin.h",
"prfchiintrin.h",
"prfchwintrin.h",
"ptrauth.h",
"ptwriteintrin.h",
"raointintrin.h",
"rdpruintrin.h",
Expand Down