Skip to content

Commit

Permalink
[flang] CUDA Fortran - part 5/5: statement semantics
Browse files Browse the repository at this point in the history
Canonicalize !$CUF KERNEL DO loop nests, similar to OpenACC/OpenMP
canonicalization.  Check statements and expressions in device contexts
for usage that isn't supported.  Add more tests, and include some
tweaks to standard modules needed to build CUDA Fortran modules.

Depends on https://reviews.llvm.org/D150159,
https://reviews.llvm.org/D150161, https://reviews.llvm.org/D150162, &
https://reviews.llvm.org/D150163.

Differential Revision: https://reviews.llvm.org/D150164
  • Loading branch information
klausler committed Jun 1, 2023
1 parent 460d136 commit f674ddc
Show file tree
Hide file tree
Showing 22 changed files with 848 additions and 46 deletions.
29 changes: 24 additions & 5 deletions flang/include/flang/Evaluate/traverse.h
Expand Up @@ -38,6 +38,7 @@
// expression of an ASSOCIATE (or related) construct entity.

#include "expression.h"
#include "flang/Common/indirection.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
#include <set>
Expand All @@ -53,6 +54,10 @@ template <typename Visitor, typename Result> class Traverse {
Result operator()(const common::Indirection<A, C> &x) const {
return visitor_(x.value());
}
template <typename A>
Result operator()(const common::ForwardOwningPointer<A> &p) const {
return visitor_(p.get());
}
template <typename _> Result operator()(const SymbolRef x) const {
return visitor_(*x);
}
Expand All @@ -76,13 +81,17 @@ template <typename Visitor, typename Result> class Traverse {
return visitor_.Default();
}
}
template <typename... A>
Result operator()(const std::variant<A...> &u) const {
return common::visit(visitor_, u);
template <typename... As>
Result operator()(const std::variant<As...> &u) const {
return common::visit([=](const auto &y) { return visitor_(y); }, u);
}
template <typename A> Result operator()(const std::vector<A> &x) const {
return CombineContents(x);
}
template <typename A, typename B>
Result operator()(const std::pair<A, B> &x) const {
return Combine(x.first, x.second);
}

// Leaves
Result operator()(const BOZLiteralConstant &) const {
Expand Down Expand Up @@ -233,14 +242,24 @@ template <typename Visitor, typename Result> class Traverse {
template <typename T> Result operator()(const Expr<T> &x) const {
return visitor_(x.u);
}
Result operator()(const Assignment &x) const {
return Combine(x.lhs, x.rhs, x.u);
}
Result operator()(const Assignment::Intrinsic &) const {
return visitor_.Default();
}
Result operator()(const GenericExprWrapper &x) const { return visitor_(x.v); }
Result operator()(const GenericAssignmentWrapper &x) const {
return visitor_(x.v);
}

private:
template <typename ITER> Result CombineRange(ITER iter, ITER end) const {
if (iter == end) {
return visitor_.Default();
} else {
Result result{visitor_(*iter++)};
for (; iter != end; ++iter) {
Result result{visitor_(*iter)};
for (++iter; iter != end; ++iter) {
result = visitor_.Combine(std::move(result), visitor_(*iter));
}
return result;
Expand Down
127 changes: 127 additions & 0 deletions flang/include/flang/Parser/tools.h
Expand Up @@ -65,6 +65,18 @@ struct UnwrapperHelper {
return common::visit([](const auto &y) { return Unwrap<A>(y); }, x);
}

template <typename A, std::size_t J = 0, typename... Bs>
static const A *Unwrap(const std::tuple<Bs...> &x) {
if constexpr (J < sizeof...(Bs)) {
if (auto result{Unwrap<A>(std::get<J>(x))}) {
return result;
}
return Unwrap<A, (J + 1)>(x);
} else {
return nullptr;
}
}

template <typename A, typename B>
static const A *Unwrap(const std::optional<B> &o) {
if (o) {
Expand Down Expand Up @@ -122,5 +134,120 @@ template <typename A, typename = int> struct HasTypedExpr : std::false_type {};
template <typename A>
struct HasTypedExpr<A, decltype(static_cast<void>(A::typedExpr), 0)>
: std::true_type {};

// GetSource()

template <bool GET_FIRST> struct GetSourceHelper {

using Result = std::optional<CharBlock>;

template <typename A> static Result GetSource(A *p) {
if (p) {
return GetSource(*p);
} else {
return std::nullopt;
}
}
template <typename A>
static Result GetSource(const common::Indirection<A> &x) {
return GetSource(x.value());
}

template <typename A, bool COPY>
static Result GetSource(const common::Indirection<A, COPY> &x) {
return GetSource(x.value());
}

template <typename... As>
static Result GetSource(const std::variant<As...> &x) {
return common::visit([](const auto &y) { return GetSource(y); }, x);
}

template <std::size_t J = 0, typename... As>
static Result GetSource(const std::tuple<As...> &x) {
if constexpr (J < sizeof...(As)) {
constexpr std::size_t index{GET_FIRST ? J : sizeof...(As) - J - 1};
if (auto result{GetSource(std::get<index>(x))}) {
return result;
}
return GetSource<(J + 1)>(x);
} else {
return {};
}
}

template <typename A> static Result GetSource(const std::optional<A> &o) {
if (o) {
return GetSource(*o);
} else {
return {};
}
}

template <typename A> static Result GetSource(const std::list<A> &x) {
if constexpr (GET_FIRST) {
for (const A &y : x) {
if (auto result{GetSource(y)}) {
return result;
}
}
} else {
for (auto iter{x.rbegin()}; iter != x.rend(); ++iter) {
if (auto result{GetSource(*iter)}) {
return result;
}
}
}
return {};
}

template <typename A> static Result GetSource(const std::vector<A> &x) {
if constexpr (GET_FIRST) {
for (const A &y : x) {
if (auto result{GetSource(y)}) {
return result;
}
}
} else {
for (auto iter{x.rbegin()}; iter != x.rend(); ++iter) {
if (auto result{GetSource(*iter)}) {
return result;
}
}
}
return {};
}

template <typename A> static Result GetSource(A &x) {
if constexpr (HasSource<A>::value) {
return x.source;
} else if constexpr (ConstraintTrait<A>) {
return GetSource(x.thing);
} else if constexpr (WrapperTrait<A>) {
return GetSource(x.v);
} else if constexpr (UnionTrait<A>) {
return GetSource(x.u);
} else if constexpr (TupleTrait<A>) {
return GetSource(x.t);
} else {
return {};
}
}
};

template <typename A> std::optional<CharBlock> GetSource(const A &x) {
return GetSourceHelper<true>::GetSource(x);
}
template <typename A> std::optional<CharBlock> GetSource(A &x) {
return GetSourceHelper<true>::GetSource(const_cast<const A &>(x));
}

template <typename A> std::optional<CharBlock> GetLastSource(const A &x) {
return GetSourceHelper<false>::GetSource(x);
}
template <typename A> std::optional<CharBlock> GetLastSource(A &x) {
return GetSourceHelper<false>::GetSource(const_cast<const A &>(x));
}

} // namespace Fortran::parser
#endif // FORTRAN_PARSER_TOOLS_H_
5 changes: 3 additions & 2 deletions flang/include/flang/Semantics/semantics.h
Expand Up @@ -214,8 +214,9 @@ class SemanticsContext {
// Defines builtinsScope_ from the __Fortran_builtins module
void UseFortranBuiltinsModule();
const Scope *GetBuiltinsScope() const { return builtinsScope_; }

void UsePPCFortranBuiltinTypesModule();
const Scope *GetCUDABuiltinsScope();
const Scope &GetCUDABuiltinsScope();
void UsePPCFortranBuiltinsModule();
Scope *GetPPCBuiltinTypesScope() { return ppcBuiltinTypesScope_; }
const Scope *GetPPCBuiltinsScope() const { return ppcBuiltinsScope_; }
Expand Down Expand Up @@ -281,7 +282,7 @@ class SemanticsContext {
std::set<std::string> tempNames_;
const Scope *builtinsScope_{nullptr}; // module __Fortran_builtins
Scope *ppcBuiltinTypesScope_{nullptr}; // module __Fortran_PPC_types
std::optional<const Scope *> CUDABuiltinsScope_; // module __CUDA_builtins
std::optional<const Scope *> cudaBuiltinsScope_; // module __CUDA_builtins
const Scope *ppcBuiltinsScope_{nullptr}; // module __Fortran_PPC_intrinsics
std::list<parser::Program> modFileParseTrees_;
std::unique_ptr<CommonBlockMap> commonBlockMap_;
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Parser/unparse.cpp
Expand Up @@ -1698,7 +1698,7 @@ class UnparseVisitor {
Put('('), Walk(std::get<std::list<ActualArgSpec>>(x.v.t), ", "), Put(')');
}
void Unparse(const CallStmt &x) { // R1521
if (asFortran_ && x.typedCall.get() && !x.chevrons /*CUDA todo*/) {
if (asFortran_ && x.typedCall.get()) {
Put(' ');
asFortran_->call(out_, *x.typedCall);
Put('\n');
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Semantics/CMakeLists.txt
Expand Up @@ -10,6 +10,7 @@ add_flang_library(FortranSemantics
check-call.cpp
check-case.cpp
check-coarray.cpp
check-cuda.cpp
check-data.cpp
check-deallocate.cpp
check-declarations.cpp
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Semantics/canonicalize-acc.cpp
Expand Up @@ -65,7 +65,7 @@ class CanonicalizationOfAcc {

const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
if (outer->IsDoConcurrent()) {
return; // Tile is not allowed on DO CONURRENT
return; // Tile is not allowed on DO CONCURRENT
}
for (const parser::DoConstruct *loop{&*outer}; loop && tileArgNb > 0;
--tileArgNb) {
Expand Down
31 changes: 24 additions & 7 deletions flang/lib/Semantics/check-allocate.cpp
Expand Up @@ -31,6 +31,8 @@ struct AllocateCheckerInfo {
bool gotTypeSpec{false};
bool gotSource{false};
bool gotMold{false};
bool gotStream{false};
bool gotPinned{false};
};

class AllocationCheckerHelper {
Expand Down Expand Up @@ -179,8 +181,22 @@ static std::optional<AllocateCheckerInfo> CheckAllocateOptions(
parserSourceExpr = &mold.v.value();
info.gotMold = true;
},
[](const parser::AllocOpt::Stream &) { /* CUDA coming */ },
[](const parser::AllocOpt::Pinned &) { /* CUDA coming */ },
[&](const parser::AllocOpt::Stream &stream) { // CUDA
if (info.gotStream) {
context.Say(
"STREAM may not be duplicated in a ALLOCATE statement"_err_en_US);
stopCheckingAllocate = true;
}
info.gotStream = true;
},
[&](const parser::AllocOpt::Pinned &pinned) { // CUDA
if (info.gotPinned) {
context.Say(
"PINNED may not be duplicated in a ALLOCATE statement"_err_en_US);
stopCheckingAllocate = true;
}
info.gotPinned = true;
},
},
allocOpt.u);
}
Expand Down Expand Up @@ -569,12 +585,13 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) {
return false;
}
context.CheckIndexVarRedefine(name_);
const Scope &subpScope{
GetProgramUnitContaining(context.FindScope(name_.source))};
if (allocateObject_.typedExpr && allocateObject_.typedExpr->v) {
if (auto whyNot{
WhyNotDefinable(name_.source, context.FindScope(name_.source),
{DefinabilityFlag::PointerDefinition,
DefinabilityFlag::AcceptAllocatable},
*allocateObject_.typedExpr->v)}) {
if (auto whyNot{WhyNotDefinable(name_.source, subpScope,
{DefinabilityFlag::PointerDefinition,
DefinabilityFlag::AcceptAllocatable},
*allocateObject_.typedExpr->v)}) {
context
.Say(name_.source,
"Name in ALLOCATE statement is not definable"_err_en_US)
Expand Down

0 comments on commit f674ddc

Please sign in to comment.