diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index 29065c9f01ecf..ca14c144af2d6 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -1227,18 +1227,24 @@ bool CheckForCoindexedObject(parser::ContextualMessages &, const std::optional &, const std::string &procName, const std::string &argName); -/// Check if any of the symbols part of the expression has a cuda data -/// attribute. -inline bool HasCUDAAttrs(const Expr &expr) { +// Get the number of distinct symbols with CUDA attribute in the expression. +template inline int GetNbOfCUDASymbols(const A &expr) { + semantics::UnorderedSymbolSet symbols; for (const Symbol &sym : CollectSymbols(expr)) { if (const auto *details = sym.GetUltimate().detailsIf()) { if (details->cudaDataAttr()) { - return true; + symbols.insert(sym); } } } - return false; + return symbols.size(); +} + +// Check if any of the symbols part of the expression has a CUDA data +// attribute. +template inline bool HasCUDAAttrs(const A &expr) { + return GetNbOfCUDASymbols(expr) > 0; } /// Check if the expression is a mix of host and device variables that require diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp index c0c6ff4c1a2ba..39bfc47a8eb1e 100644 --- a/flang/lib/Semantics/check-cuda.cpp +++ b/flang/lib/Semantics/check-cuda.cpp @@ -9,12 +9,14 @@ #include "check-cuda.h" #include "flang/Common/template.h" #include "flang/Evaluate/fold.h" +#include "flang/Evaluate/tools.h" #include "flang/Evaluate/traverse.h" #include "flang/Parser/parse-tree-visitor.h" #include "flang/Parser/parse-tree.h" #include "flang/Parser/tools.h" #include "flang/Semantics/expression.h" #include "flang/Semantics/symbol.h" +#include "flang/Semantics/tools.h" // Once labeled DO constructs have been canonicalized and their parse subtrees // transformed into parser::DoConstructs, scan the parser::Blocks of the program @@ -413,4 +415,18 @@ void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) { } } +void CUDAChecker::Enter(const parser::AssignmentStmt &x) { + const evaluate::Assignment *assign{semantics::GetAssignment(x)}; + int nbLhs{evaluate::GetNbOfCUDASymbols(assign->lhs)}; + int nbRhs{evaluate::GetNbOfCUDASymbols(assign->rhs)}; + auto lhsLoc{std::get(x.t).GetSource()}; + + // device to host transfer with more than one device object on the rhs is not + // legal. + if (nbLhs == 0 && nbRhs > 1) { + context_.Say(lhsLoc, + "More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US); + } +} + } // namespace Fortran::semantics diff --git a/flang/lib/Semantics/check-cuda.h b/flang/lib/Semantics/check-cuda.h index d863795f16a7c..aa0cb46360bef 100644 --- a/flang/lib/Semantics/check-cuda.h +++ b/flang/lib/Semantics/check-cuda.h @@ -17,6 +17,7 @@ struct Program; class Messages; struct Name; class CharBlock; +struct AssignmentStmt; struct ExecutionPartConstruct; struct ExecutableConstruct; struct ActionStmt; @@ -38,6 +39,7 @@ class CUDAChecker : public virtual BaseChecker { void Enter(const parser::FunctionSubprogram &); void Enter(const parser::SeparateModuleSubprogram &); void Enter(const parser::CUFKernelDoConstruct &); + void Enter(const parser::AssignmentStmt &); private: SemanticsContext &context_; diff --git a/flang/test/Semantics/cuf11.cuf b/flang/test/Semantics/cuf11.cuf new file mode 100644 index 0000000000000..96108e2b24556 --- /dev/null +++ b/flang/test/Semantics/cuf11.cuf @@ -0,0 +1,12 @@ +! RUN: %python %S/test_errors.py %s %flang_fc1 + +subroutine sub1() + real, device :: adev(10), bdev(10) + real :: ahost(10) + +!ERROR: More than one reference to a CUDA object on the right hand side of the assigment + ahost = adev + bdev + + ahost = adev + adev + +end subroutine