Skip to content

Commit

Permalink
[flang] Enforce accessibility requirement on type-bound generic opera…
Browse files Browse the repository at this point in the history
…tors, &c.

Type-bound generics like operator(+) and assignment(=) need to not be
PRIVATE if they are used outside the module in which they are declared.

Differential Revision: https://reviews.llvm.org/D139123
  • Loading branch information
klausler committed Dec 3, 2022
1 parent 00272ad commit d7a1351
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 92 deletions.
1 change: 1 addition & 0 deletions flang/include/flang/Semantics/scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class Scope {
const Scope *GetDerivedTypeParent() const;
const Scope &GetDerivedTypeBase() const;
inline std::optional<SourceName> GetName() const;
// Returns true if this scope contains, or is, another scope.
bool Contains(const Scope &) const;
/// Make a scope nested in this one
Scope &MakeScope(Kind kind, Symbol *symbol = nullptr);
Expand Down
10 changes: 8 additions & 2 deletions flang/include/flang/Semantics/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ bool IsIntrinsicConcat(
bool IsGenericDefinedOp(const Symbol &);
bool IsDefinedOperator(SourceName);
std::string MakeOpName(SourceName);

// Returns true if maybeAncestor exists and is a proper ancestor of a
// descendent scope (or symbol owner). Will be false, unlike Scope::Contains(),
// if maybeAncestor *is* the descendent.
bool DoesScopeContain(const Scope *maybeAncestor, const Scope &maybeDescendent);
bool DoesScopeContain(const Scope *, const Symbol &);

bool IsUseAssociated(const Symbol &, const Scope &);
bool IsHostAssociated(const Symbol &, const Scope &);
bool IsHostAssociatedIntoSubprogram(const Symbol &, const Scope &);
Expand Down Expand Up @@ -182,8 +187,9 @@ bool HasCoarray(const parser::Expr &);
bool IsAssumedType(const Symbol &);
bool IsPolymorphic(const Symbol &);
bool IsPolymorphicAllocatable(const Symbol &);
// Return an error if component symbol is not accessible from scope (7.5.4.8(2))
std::optional<parser::MessageFormattedText> CheckAccessibleComponent(

// Return an error if a symbol is not accessible from a scope
std::optional<parser::MessageFormattedText> CheckAccessibleSymbol(
const semantics::Scope &, const Symbol &);

// Analysis of image control statements
Expand Down
171 changes: 105 additions & 66 deletions flang/lib/Semantics/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ class ArgumentAnalyzer {

// Find and return a user-defined operator or report an error.
// The provided message is used if there is no such operator.
MaybeExpr TryDefinedOp(const char *, parser::MessageFixedText,
const Symbol **definedOpSymbolPtr = nullptr, bool isUserOp = false);
// If a definedOpSymbolPtr is provided, the caller must check
// for its accessibility.
MaybeExpr TryDefinedOp(
const char *, parser::MessageFixedText, bool isUserOp = false);
template <typename E>
MaybeExpr TryDefinedOp(E opr, parser::MessageFixedText msg) {
return TryDefinedOp(
Expand All @@ -175,7 +177,7 @@ class ArgumentAnalyzer {
MaybeExpr AnalyzeExprOrWholeAssumedSizeArray(const parser::Expr &);
bool AreConformable() const;
const Symbol *FindBoundOp(parser::CharBlock, int passIndex,
const Symbol *&definedOp, bool isSubroutine);
const Symbol *&generic, bool isSubroutine);
void AddAssignmentConversion(
const DynamicType &lhsType, const DynamicType &rhsType);
bool OkLogicalIntegerAssignment(TypeCategory lhs, TypeCategory rhs);
Expand Down Expand Up @@ -1778,10 +1780,9 @@ MaybeExpr ExpressionAnalyzer::Analyze(
}
}
if (symbol) {
if (const auto *currScope{context_.globalScope().FindScope(source)}) {
if (auto msg{CheckAccessibleComponent(*currScope, *symbol)}) {
Say(source, *msg);
}
const semantics::Scope &innermost{context_.FindScope(expr.source)};
if (auto msg{CheckAccessibleSymbol(innermost, *symbol)}) {
Say(expr.source, std::move(*msg));
}
if (checkConflicts) {
auto componentIter{
Expand Down Expand Up @@ -1809,7 +1810,6 @@ MaybeExpr ExpressionAnalyzer::Analyze(
}
unavailable.insert(symbol->name());
if (value) {
const auto &innermost{context_.FindScope(expr.source)};
if (symbol->has<semantics::ProcEntityDetails>()) {
CHECK(IsPointer(*symbol));
} else if (symbol->has<semantics::ObjectEntityDetails>()) {
Expand Down Expand Up @@ -2869,7 +2869,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::DefinedUnary &x) {
ArgumentAnalyzer analyzer{*this, name.source};
analyzer.Analyze(std::get<1>(x.t));
return analyzer.TryDefinedOp(name.source.ToString().c_str(),
"No operator %s defined for %s"_err_en_US, nullptr, true);
"No operator %s defined for %s"_err_en_US, true);
}

// Binary (dyadic) operations
Expand Down Expand Up @@ -3053,7 +3053,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::DefinedBinary &x) {
analyzer.Analyze(std::get<1>(x.t));
analyzer.Analyze(std::get<2>(x.t));
return analyzer.TryDefinedOp(name.source.ToString().c_str(),
"No operator %s defined for %s and %s"_err_en_US, nullptr, true);
"No operator %s defined for %s and %s"_err_en_US, true);
}

// Returns true if a parsed function reference should be converted
Expand Down Expand Up @@ -3635,63 +3635,100 @@ bool ArgumentAnalyzer::CheckForNullPointer(const char *where) {
return true;
}

MaybeExpr ArgumentAnalyzer::TryDefinedOp(const char *opr,
parser::MessageFixedText error, const Symbol **definedOpSymbolPtr,
bool isUserOp) {
MaybeExpr ArgumentAnalyzer::TryDefinedOp(
const char *opr, parser::MessageFixedText error, bool isUserOp) {
if (AnyUntypedOrMissingOperand()) {
context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1));
return std::nullopt;
}
const Symbol *localDefinedOpSymbolPtr{nullptr};
if (!definedOpSymbolPtr) {
definedOpSymbolPtr = &localDefinedOpSymbolPtr;
}
MaybeExpr result;
bool anyPossibilities{false};
std::optional<parser::MessageFormattedText> inaccessible;
std::vector<const Symbol *> hit;
std::string oprNameString{
isUserOp ? std::string{opr} : "operator("s + opr + ')'};
parser::CharBlock oprName{oprNameString};
{
auto restorer{context_.GetContextualMessages().DiscardMessages()};
std::string oprNameString{
isUserOp ? std::string{opr} : "operator("s + opr + ')'};
parser::CharBlock oprName{oprNameString};
const auto &scope{context_.context().FindScope(source_)};
if (Symbol *symbol{scope.FindSymbol(oprName)}) {
*definedOpSymbolPtr = symbol;
anyPossibilities = true;
parser::Name name{symbol->name(), symbol};
if (auto result{context_.AnalyzeDefinedOp(name, GetActuals())}) {
return result;
result = context_.AnalyzeDefinedOp(name, GetActuals());
if (result) {
inaccessible = CheckAccessibleSymbol(scope, *symbol);
if (inaccessible) {
result.reset();
} else {
hit.push_back(symbol);
}
}
}
for (std::size_t passIndex{0}; passIndex < actuals_.size(); ++passIndex) {
if (const Symbol *
symbol{FindBoundOp(oprName, passIndex, *definedOpSymbolPtr, false)}) {
if (MaybeExpr result{TryBoundOp(*symbol, passIndex)}) {
return result;
const Symbol *generic{nullptr};
if (const Symbol *binding{
FindBoundOp(oprName, passIndex, generic, false)}) {
anyPossibilities = true;
if (MaybeExpr thisResult{TryBoundOp(*binding, passIndex)}) {
if (auto thisInaccessible{
CheckAccessibleSymbol(scope, DEREF(generic))}) {
inaccessible = thisInaccessible;
} else {
result = std::move(thisResult);
hit.push_back(binding);
}
}
}
}
}
if (*definedOpSymbolPtr) {
SayNoMatch(ToUpperCase((*definedOpSymbolPtr)->name().ToString()));
} else if (actuals_.size() == 1 || AreConformable()) {
if (CheckForNullPointer()) {
context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1));
if (result) {
if (hit.size() > 1) {
if (auto *msg{context_.Say(
"%zd matching accessible generic interfaces for %s were found"_err_en_US,
hit.size(), ToUpperCase(opr))}) {
for (const Symbol *symbol : hit) {
AttachDeclaration(*msg, *symbol);
}
}
}
} else {
} else if (inaccessible) {
context_.Say(source_, std::move(*inaccessible));
} else if (anyPossibilities) {
SayNoMatch(ToUpperCase(oprNameString), false);
} else if (actuals_.size() == 2 && !AreConformable()) {
context_.Say(
"Operands of %s are not conformable; have rank %d and rank %d"_err_en_US,
ToUpperCase(opr), actuals_[0]->Rank(), actuals_[1]->Rank());
} else if (CheckForNullPointer()) {
context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1));
}
return std::nullopt;
return result;
}

MaybeExpr ArgumentAnalyzer::TryDefinedOp(
std::vector<const char *> oprs, parser::MessageFixedText error) {
const Symbol *definedOpSymbolPtr{nullptr};
for (std::size_t i{1}; i < oprs.size(); ++i) {
if (oprs.size() == 1) {
return TryDefinedOp(oprs[0], error);
}
MaybeExpr result;
std::vector<const char *> hit;
{
auto restorer{context_.GetContextualMessages().DiscardMessages()};
if (auto result{TryDefinedOp(oprs[i], error, &definedOpSymbolPtr)}) {
return result;
for (std::size_t i{0}; i < oprs.size(); ++i) {
if (MaybeExpr thisResult{TryDefinedOp(oprs[i], error)}) {
result = std::move(thisResult);
hit.push_back(oprs[i]);
}
}
}
return TryDefinedOp(oprs[0], error, &definedOpSymbolPtr);
if (hit.empty()) { // for the error
result = TryDefinedOp(oprs[0], error);
} else if (hit.size() > 1) {
context_.Say(
"Matching accessible definitions were found with %zd variant spellings of the generic operator ('%s', '%s')"_err_en_US,
hit.size(), ToUpperCase(hit[0]), ToUpperCase(hit[1]));
}
return result;
}

MaybeExpr ArgumentAnalyzer::TryBoundOp(const Symbol &symbol, int passIndex) {
Expand Down Expand Up @@ -3768,31 +3805,34 @@ bool ArgumentAnalyzer::OkLogicalIntegerAssignment(
}

std::optional<ProcedureRef> ArgumentAnalyzer::GetDefinedAssignmentProc() {
auto restorer{context_.GetContextualMessages().DiscardMessages()};
const Symbol *proc{nullptr};
int passedObjectIndex{-1};
std::string oprNameString{"assignment(=)"};
parser::CharBlock oprName{oprNameString};
const Symbol *proc{nullptr};
const auto &scope{context_.context().FindScope(source_)};
if (const Symbol *symbol{scope.FindSymbol(oprName)}) {
ExpressionAnalyzer::AdjustActuals noAdjustment;
auto pair{context_.ResolveGeneric(*symbol, actuals_, noAdjustment, true)};
if (pair.first) {
proc = pair.first;
} else {
context_.EmitGenericResolutionError(*symbol, pair.second, true);
}
}
int passedObjectIndex{-1};
const Symbol *definedOpSymbol{nullptr};
for (std::size_t i{0}; i < actuals_.size(); ++i) {
if (const Symbol *
specific{FindBoundOp(oprName, i, definedOpSymbol, true)}) {
if (const Symbol *
resolution{GetBindingResolution(GetType(i), *specific)}) {
proc = resolution;
// If multiple resolutions were possible, they will have been already
// diagnosed.
{
auto restorer{context_.GetContextualMessages().DiscardMessages()};
if (const Symbol *symbol{scope.FindSymbol(oprName)}) {
ExpressionAnalyzer::AdjustActuals noAdjustment;
auto pair{context_.ResolveGeneric(*symbol, actuals_, noAdjustment, true)};
if (pair.first) {
proc = pair.first;
} else {
proc = specific;
passedObjectIndex = i;
context_.EmitGenericResolutionError(*symbol, pair.second, true);
}
}
for (std::size_t i{0}; i < actuals_.size(); ++i) {
const Symbol *generic{nullptr};
if (const Symbol *specific{FindBoundOp(oprName, i, generic, true)}) {
if (const Symbol *resolution{
GetBindingResolution(GetType(i), *specific)}) {
proc = resolution;
} else {
proc = specific;
passedObjectIndex = i;
}
}
}
}
Expand Down Expand Up @@ -3871,24 +3911,23 @@ bool ArgumentAnalyzer::AreConformable() const {

// Look for a type-bound operator in the type of arg number passIndex.
const Symbol *ArgumentAnalyzer::FindBoundOp(parser::CharBlock oprName,
int passIndex, const Symbol *&definedOp, bool isSubroutine) {
int passIndex, const Symbol *&generic, bool isSubroutine) {
const auto *type{GetDerivedTypeSpec(GetType(passIndex))};
if (!type || !type->scope()) {
return nullptr;
}
const Symbol *symbol{type->scope()->FindComponent(oprName)};
if (!symbol) {
generic = type->scope()->FindComponent(oprName);
if (!generic) {
return nullptr;
}
definedOp = symbol;
ExpressionAnalyzer::AdjustActuals adjustment{
[&](const Symbol &proc, ActualArguments &) {
return passIndex == GetPassIndex(proc);
}};
auto pair{
context_.ResolveGeneric(*symbol, actuals_, adjustment, isSubroutine)};
context_.ResolveGeneric(*generic, actuals_, adjustment, isSubroutine)};
if (!pair.first) {
context_.EmitGenericResolutionError(*symbol, pair.second, isSubroutine);
context_.EmitGenericResolutionError(*generic, pair.second, isSubroutine);
}
return pair.first;
}
Expand Down
5 changes: 2 additions & 3 deletions flang/lib/Semantics/resolve-names.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5989,7 +5989,7 @@ bool DeclarationVisitor::OkToAddComponent(
} else if (extends) {
msg = "Type cannot be extended as it has a component named"
" '%s'"_err_en_US;
} else if (CheckAccessibleComponent(currScope(), *prev)) {
} else if (CheckAccessibleSymbol(currScope(), *prev)) {
// inaccessible component -- redeclaration is ok
msg = "Component '%s' is inaccessibly declared in or as a "
"parent of this derived type"_warn_en_US;
Expand Down Expand Up @@ -6864,8 +6864,7 @@ const parser::Name *DeclarationVisitor::FindComponent(
derived->Instantiate(currScope()); // in case of forward referenced type
if (const Scope * scope{derived->scope()}) {
if (Resolve(component, scope->FindComponent(component.source))) {
if (auto msg{
CheckAccessibleComponent(currScope(), *component.symbol)}) {
if (auto msg{CheckAccessibleSymbol(currScope(), *component.symbol)}) {
context().Say(component.source, *msg);
}
return &component;
Expand Down
5 changes: 2 additions & 3 deletions flang/lib/Semantics/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -991,9 +991,8 @@ bool IsPolymorphicAllocatable(const Symbol &symbol) {
return IsAllocatable(symbol) && IsPolymorphic(symbol);
}

std::optional<parser::MessageFormattedText> CheckAccessibleComponent(
std::optional<parser::MessageFormattedText> CheckAccessibleSymbol(
const Scope &scope, const Symbol &symbol) {
CHECK(symbol.owner().IsDerivedType()); // symbol must be a component
if (symbol.attrs().test(Attr::PRIVATE)) {
if (FindModuleFileContaining(scope)) {
// Don't enforce component accessibility checks in module files;
Expand All @@ -1003,7 +1002,7 @@ std::optional<parser::MessageFormattedText> CheckAccessibleComponent(
moduleScope{FindModuleContaining(symbol.owner())}) {
if (!moduleScope->Contains(scope)) {
return parser::MessageFormattedText{
"PRIVATE component '%s' is only accessible within module '%s'"_err_en_US,
"PRIVATE name '%s' is only accessible within module '%s'"_err_en_US,
symbol.name(), moduleScope->GetName().value()};
}
}
Expand Down
Loading

0 comments on commit d7a1351

Please sign in to comment.