diff --git a/frontend/lib/resolution/Resolver.cpp b/frontend/lib/resolution/Resolver.cpp index 5a130a3bbf7f..294ccc0c05a2 100644 --- a/frontend/lib/resolution/Resolver.cpp +++ b/frontend/lib/resolution/Resolver.cpp @@ -2258,6 +2258,10 @@ QualifiedType Resolver::typeForId(const ID& id, bool localGenericToUnknown) { return typeForModuleLevelSymbol(context, id, isCurrentModule); } + if (asttags::isEnum(parentTag) && asttags::isEnumElement(tag)) { + return typeForScopeResolvedEnumElement(parentId, id, /* ambiguous */ false); + } + // If the id is contained within a class/record/union that we are resolving, // get the resolved field. const CompositeType* ct = nullptr; @@ -2904,12 +2908,6 @@ void Resolver::resolveIdentifier(const Identifier* ident, llvm::ArrayRef receiverScopes) { ResolvedExpression& result = byPostorder.byAst(ident); - if (ident->name() == USTR("nil")) { - result.setType(QualifiedType(QualifiedType::CONST_VAR, - NilType::get(context))); - return; - } - // for 'proc f(arg:?)' need to set 'arg' to have type AnyType CHPL_ASSERT(declStack.size() > 0); const Decl* inDecl = declStack.back(); @@ -2986,9 +2984,27 @@ void Resolver::resolveIdentifier(const Identifier* ident, bool computeDefaults = true; bool resolvingCalledIdent = nearestCalledExpression() == ident; + // For calls like + // + // type myType = anotherType(int) + // + // Use the generic version of anotherType to feed as receiver. if (resolvingCalledIdent) { computeDefaults = false; } + + // Other special exceptions like 'r' in: + // + // proc r.init() { ... } + // + if (!genericReceiverOverrideStack.empty()) { + auto& topEntry = genericReceiverOverrideStack.back(); + if ((topEntry.first.isEmpty() || topEntry.first == ident->name()) && + topEntry.second == parsing::parentAst(context, ident)) { + computeDefaults = false; + } + } + if (computeDefaults) { type = computeTypeDefaults(*this, type); } @@ -3101,6 +3117,21 @@ bool Resolver::enter(const TypeQuery* tq) { void Resolver::exit(const TypeQuery* tq) { } +// Treat receiver types specially in terms of generic resolution. That is, +// when resolving the following initializer when r is generic with defaults, +// +// proc r.init() {} +// +// Make sure r's defaults aren't used so that the most general receiver is +// constructed. On the other hand, defaults _should_ be used for more +// complicated expressions: +// +// proc (someTypeFn(r)).init() {} +// +static bool shouldUseGenericTypeForTypeExpr(const NamedDecl* decl) { + return decl->isFormal() && decl->name() == USTR("this"); +} + bool Resolver::enter(const NamedDecl* decl) { if (decl->id().postOrderId() < 0) { @@ -3117,6 +3148,13 @@ bool Resolver::enter(const NamedDecl* decl) { enterScope(decl); + if (shouldUseGenericTypeForTypeExpr(decl)) { + // Empty string indicates that all identifiers should be treated as + // non-defaulted. Using 'decl' means that only the top-level identifiers + // will be resolved this way. + genericReceiverOverrideStack.emplace_back(UniqueString(), decl); + } + // This logic exists to prioritize the field's type expression when // resolving a field's type. If the type expression is concrete, then we // do not need to resolve the init-expression. This is beneficial in cases @@ -3146,6 +3184,10 @@ bool Resolver::enter(const NamedDecl* decl) { } void Resolver::exit(const NamedDecl* decl) { + if (shouldUseGenericTypeForTypeExpr(decl)) { + genericReceiverOverrideStack.pop_back(); + } + // We are resolving a symbol with a different path (e.g., a Function or // a CompositeType declaration). In most cases we do not try to resolve // in this traversal. However, if we are a nested function and the child @@ -3641,6 +3683,15 @@ Resolver::typeForScopeResolvedEnumElement(const EnumType* enumType, } } +QualifiedType +Resolver::typeForScopeResolvedEnumElement(const ID& enumTypeId, + const ID& refersToId, + bool ambiguous) { + auto type = initialTypeForTypeDecl(context, enumTypeId); + CHPL_ASSERT(type && type->isEnumType()); + return typeForScopeResolvedEnumElement(type->toEnumType(), refersToId, + ambiguous); +} QualifiedType Resolver::typeForEnumElement(const EnumType* enumType, UniqueString elementName, @@ -3831,6 +3882,10 @@ void Resolver::exit(const Dot* dot) { } bool Resolver::enter(const New* node) { + if (auto ident = node->typeExpression()->toIdentifier()) { + genericReceiverOverrideStack.emplace_back(ident->name(), node); + } + return true; } @@ -3912,6 +3967,10 @@ static void resolveNewForUnion(Resolver& rv, const New* node, } void Resolver::exit(const New* node) { + if (node->typeExpression()->isIdentifier()) { + genericReceiverOverrideStack.pop_back(); + } + if (scopeResolveOnly) return; diff --git a/frontend/lib/resolution/Resolver.h b/frontend/lib/resolution/Resolver.h index 552ea1d1be7e..d23b3f5b01d1 100644 --- a/frontend/lib/resolution/Resolver.h +++ b/frontend/lib/resolution/Resolver.h @@ -96,6 +96,7 @@ struct Resolver { std::set instantiatedFieldOrFormals; std::set namesWithErrorsEmitted; std::vector callNodeStack; + std::vector> genericReceiverOverrideStack; bool receiverScopesComputed = false; ReceiverScopesVec savedReceiverScopes; Resolver* parentResolver = nullptr; @@ -491,8 +492,12 @@ struct Resolver { // Given the results of looking up an enum element, construct a QualifiedType. types::QualifiedType typeForScopeResolvedEnumElement(const types::EnumType* enumType, - const ID& refersToId, - bool ambiguous); + const ID& refersToId, + bool ambiguous); + types::QualifiedType + typeForScopeResolvedEnumElement(const ID& enumTypeId, + const ID& refersToId, + bool ambiguous); // Given a particular enum type, determine the type of a particular element. types::QualifiedType typeForEnumElement(const types::EnumType* type, UniqueString elemName, diff --git a/frontend/lib/resolution/VarScopeVisitor.cpp b/frontend/lib/resolution/VarScopeVisitor.cpp index 32e4e61b81aa..e0f8a6ea04e7 100644 --- a/frontend/lib/resolution/VarScopeVisitor.cpp +++ b/frontend/lib/resolution/VarScopeVisitor.cpp @@ -387,7 +387,7 @@ void VarScopeVisitor::exitAst(const uast::AstNode* ast) { bool VarScopeVisitor::enter(const NamedDecl* ast, RV& rv) { - if (ast->id().postOrderId() < 0) { + if (ast->id().isSymbolDefiningScope()) { // It's a symbol with a different path, e.g. a Function. // Don't try to resolve it now in this // traversal. Instead, resolve it e.g. when the function is called. @@ -400,7 +400,7 @@ bool VarScopeVisitor::enter(const NamedDecl* ast, RV& rv) { return true; } void VarScopeVisitor::exit(const NamedDecl* ast, RV& rv) { - if (ast->id().postOrderId() < 0) { + if (ast->id().isSymbolDefiningScope()) { // It's a symbol with a different path, e.g. a Function. // Don't try to resolve it now in this // traversal. Instead, resolve it e.g. when the function is called. diff --git a/frontend/lib/resolution/maybe-const.cpp b/frontend/lib/resolution/maybe-const.cpp index 1417886e7435..417bd777a5db 100644 --- a/frontend/lib/resolution/maybe-const.cpp +++ b/frontend/lib/resolution/maybe-const.cpp @@ -96,6 +96,9 @@ struct AdjustMaybeRefs { bool enter(const Call* ast, RV& rv); void exit(const Call* ast, RV& rv); + bool enter(const NamedDecl* ast, RV& rv); + void exit(const NamedDecl* ast, RV& rv); + bool enter(const uast::AstNode* node, RV& rv); void exit(const uast::AstNode* node, RV& rv); }; @@ -337,6 +340,18 @@ bool AdjustMaybeRefs::enter(const Call* ast, RV& rv) { void AdjustMaybeRefs::exit(const Call* ast, RV& rv) { } +bool AdjustMaybeRefs::enter(const uast::NamedDecl* node, RV& rv) { + if (node->id().isSymbolDefiningScope()) { + // It's a symbol with a different path, e.g. a Function. + // Don't try to resolve it now in this + // traversal. Instead, resolve it e.g. when the function is called. + return false; + } + return true; +} +void AdjustMaybeRefs::exit(const uast::NamedDecl* node, RV& rv) { +} + bool AdjustMaybeRefs::enter(const uast::AstNode* node, RV& rv) { return true; } diff --git a/frontend/lib/resolution/resolution-queries.cpp b/frontend/lib/resolution/resolution-queries.cpp index e9b2eadd41d2..dbe49e2bdcf3 100644 --- a/frontend/lib/resolution/resolution-queries.cpp +++ b/frontend/lib/resolution/resolution-queries.cpp @@ -328,6 +328,12 @@ const QualifiedType& typeForBuiltin(Context* context, result = QualifiedType(QualifiedType::TYPE, t); } else if (searchGlobals != globalMap.end()) { result = searchGlobals->second; + } else if (name == USTR("nil")) { + result = QualifiedType(QualifiedType::CONST_VAR, + NilType::get(context)); + } else if (name == USTR("none")) { + result = QualifiedType(QualifiedType::CONST_VAR, + NothingType::get(context)); } else { // Could be a non-type builtin like 'index' result = QualifiedType(); diff --git a/frontend/lib/resolution/scope-queries.cpp b/frontend/lib/resolution/scope-queries.cpp index b01e84baceb2..9ffacdf79582 100644 --- a/frontend/lib/resolution/scope-queries.cpp +++ b/frontend/lib/resolution/scope-queries.cpp @@ -379,6 +379,10 @@ static void populateScopeWithBuiltins(Context* context, Scope* scope) { scope->addBuiltin(pair.first); } + // TODO: maybe we can represent these as 'NilLiteral' and 'NoneLiteral' nodes? + scope->addBuiltin(USTR("nil")); + scope->addBuiltin(USTR("none")); + populateScopeWithBuiltinKeywords(context, scope); } diff --git a/frontend/test/resolution/testMethodCalls.cpp b/frontend/test/resolution/testMethodCalls.cpp index fe2d2d4fcfe0..60f2327eb53e 100644 --- a/frontend/test/resolution/testMethodCalls.cpp +++ b/frontend/test/resolution/testMethodCalls.cpp @@ -536,6 +536,40 @@ static void test9() { assert(guard.numErrors() == 0); } +static void test10() { + // Ensure that secondary methods like 'proc x.myMethod()' are generic + // even if 'x' is generic-with-defaults. + Context ctx; + Context* context = &ctx; + ErrorGuard guard(context); + + std::string program = R"""( + record R { + type T = int; + var field : T; + } + + proc R.myMethod(): T do return this.field; + + var r1: R(int); + var r2: R(bool); + + var x1 = r1.myMethod(); + var x2 = r2.myMethod(); + )"""; + + auto vars = resolveTypesOfVariables(context, program, { "x1", "x2" }); + + auto t1 = vars.at("x1"); + assert(t1.type()); + assert(t1.type()->isIntType()); + assert(t1.type()->toIntType()->isDefaultWidth()); + + auto t2 = vars.at("x2"); + assert(t2.type()); + assert(t2.type()->isBoolType()); +} + int main() { test1(); @@ -547,6 +581,7 @@ int main() { test7(); test8(); test9(); + test10(); return 0; } diff --git a/frontend/test/resolution/testNew.cpp b/frontend/test/resolution/testNew.cpp index ac142f8d4d19..ca7ae7d3c27e 100644 --- a/frontend/test/resolution/testNew.cpp +++ b/frontend/test/resolution/testNew.cpp @@ -582,6 +582,56 @@ static void testGenericRecordUserSecondaryInitDependentField() { assert(f3.param()->toIntParam()->value() == 42); } +static void testNewGenericWithDefaults() { + Context ctx; + Context* context = &ctx; + ErrorGuard guard(context); + + auto vars = resolveTypesOfVariables(context, + R"""( + record r { + type f1 = int; + } + + proc r.init(type f1) { + this.f1 = f1; + } + + var x1 = new r(int); + var x2 = new r(bool); + )""", { "x1", "x2" }); + + + { + auto ct = vars.at("x1").type()->toCompositeType(); + assert(ct); + assert(ct->name() == "r"); + + // It should already be instantiated, no need to use defaults. + auto fields = fieldsForTypeDecl(context, ct, DefaultsPolicy::IGNORE_DEFAULTS); + assert(fields.numFields() == 1); + + auto f1 = fields.fieldType(0); + assert(f1.isType()); + assert(f1.type()->isIntType()); + assert(f1.type()->toIntType()->isDefaultWidth()); + } + + { + auto ct = vars.at("x2").type()->toCompositeType(); + assert(ct); + assert(ct->name() == "r"); + + // It should already be instantiated, no need to use defaults. + auto fields = fieldsForTypeDecl(context, ct, DefaultsPolicy::IGNORE_DEFAULTS); + assert(fields.numFields() == 1); + + auto f1 = fields.fieldType(0); + assert(f1.isType()); + assert(f1.type()->isBoolType()); + } +} + int main() { testEmptyRecordUserInit(); testEmptyRecordCompilerGenInit(); @@ -590,6 +640,7 @@ int main() { testGenericRecordUserInitDependentField(); testRecordNewSegfault(); testGenericRecordUserSecondaryInitDependentField(); + testNewGenericWithDefaults(); return 0; } diff --git a/frontend/test/resolution/testRanges.cpp b/frontend/test/resolution/testRanges.cpp index 9e6ec7be637e..924593f1361a 100644 --- a/frontend/test/resolution/testRanges.cpp +++ b/frontend/test/resolution/testRanges.cpp @@ -197,6 +197,80 @@ static void test8(Context* context) { assert(idxTypeInt->bitwidth() == 64); } +static void test9(Context* context) { + // test the count operator on a bounded range + ErrorGuard guard(context); + context->advanceToNextRevision(false); + setupModuleSearchPaths(context, false, false, {}, {}); + auto qts = resolveTypesOfVariables(context, + R""""( + var y : int; + var lower: int(32); + var x1 = lower..; + var x2 = lower..#10; + )"""", {"x1", "x2"}); + + { + // Check the first range + auto qt = qts.at("x1"); + assert(qt.type() != nullptr); + auto rangeType = qt.type()->toRecordType(); + assert(rangeType != nullptr); + auto idxType = getRangeIndexType(context, rangeType, "low"); + assert(idxType.type() != nullptr); + auto idxTypeInt = idxType.type()->toIntType(); + assert(idxTypeInt->bitwidth() == 32); + } + { + // Check the counted range + auto qt = qts.at("x2"); + assert(qt.type() != nullptr); + auto rangeType = qt.type()->toRecordType(); + assert(rangeType != nullptr); + auto idxType = getRangeIndexType(context, rangeType, "both"); + assert(idxType.type() != nullptr); + auto idxTypeInt = idxType.type()->toIntType(); + assert(idxTypeInt->bitwidth() == 32); + } +} + +static void test10(Context* context) { + // test the count operator on a bounded range + ErrorGuard guard(context); + context->advanceToNextRevision(false); + setupModuleSearchPaths(context, false, false, {}, {}); + auto qts = resolveTypesOfVariables(context, + R""""( + var y : int; + var higher: int(32); + var x1 = ..higher; + var x2 = ..higher#10; + )"""", {"x1", "x2"}); + + { + // Check the first range + auto qt = qts.at("x1"); + assert(qt.type() != nullptr); + auto rangeType = qt.type()->toRecordType(); + assert(rangeType != nullptr); + auto idxType = getRangeIndexType(context, rangeType, "high"); + assert(idxType.type() != nullptr); + auto idxTypeInt = idxType.type()->toIntType(); + assert(idxTypeInt->bitwidth() == 32); + } + { + // Check the counted range + auto qt = qts.at("x2"); + assert(qt.type() != nullptr); + auto rangeType = qt.type()->toRecordType(); + assert(rangeType != nullptr); + auto idxType = getRangeIndexType(context, rangeType, "both"); + assert(idxType.type() != nullptr); + auto idxTypeInt = idxType.type()->toIntType(); + assert(idxTypeInt->bitwidth() == 32); + } +} + int main() { // first test runs without environment and stdlib. test1(); @@ -212,5 +286,7 @@ int main() { test6(ctx); test7(ctx); test8(ctx); + test9(ctx); + test10(ctx); return 0; } diff --git a/frontend/test/resolution/testResolve.cpp b/frontend/test/resolution/testResolve.cpp index 6277dd114e51..fc516bfa873c 100644 --- a/frontend/test/resolution/testResolve.cpp +++ b/frontend/test/resolution/testResolve.cpp @@ -1538,6 +1538,23 @@ static void test24() { } } +static void test25() { + // Test that 'none' has type 'nothing' + Context ctx; + Context* context = &ctx; + ErrorGuard guard(context); + + { + // straightforward case for qualified module + std::string prog = "var x = none;"; + + auto t = resolveTypeOfXInit(context, prog); + assert(t.type()); + assert(t.type()->isNothingType()); + assert(guard.realizeErrors() == 0); + } +} + int main() { test1(); test2(); @@ -1563,6 +1580,7 @@ int main() { test22(); test23(); test24(); + test25(); return 0; } diff --git a/modules/internal/ChapelRange.chpl b/modules/internal/ChapelRange.chpl index 86310da43ff1..6d31de5005fa 100644 --- a/modules/internal/ChapelRange.chpl +++ b/modules/internal/ChapelRange.chpl @@ -2348,12 +2348,18 @@ private proc isBCPindex(type t) param do if isPositiveStride(newStrides, st) then // start from the low index return if hasLowBoundForIter(r) - then newAlignedRange(r.chpl_alignedLowAsIntForIter) + // inlined: newAlignedRange(r.chpl_alignedLowAsIntForIter) + // because Dyno can't helper capturing nested functions. + then new range(i, b, newStrides, lw, hh, st, + r.chpl_alignedLowAsIntForIter, true, true) else if st == 1 then newZeroAlmtRange() else newUnalignedRange(); else // start from the high index return if hasHighBoundForIter(r) - then newAlignedRange(r.chpl_alignedHighAsIntForIter) + // inlined: newAlignedRange(r.chpl_alignedHighAsIntForIter) + // because Dyno can't helper capturing nested functions. + then new range(i, b, newStrides, lw, hh, st, + r.chpl_alignedHighAsIntForIter, true, true) else if st == -1 then newZeroAlmtRange() else newUnalignedRange(); proc newAlignedRange(alignment) do @@ -2723,7 +2729,7 @@ private proc isBCPindex(type t) param do type resultType = r.chpl_integralIdxType; type strType = chpl__rangeStrideType(resultType); - proc absSameType() { + proc absSameType(r, type resultType) { if r.hasNegativeStride() { return (-r.stride):resultType; } else { @@ -2737,14 +2743,14 @@ private proc isBCPindex(type t) param do bounds = boundKind.both, strides = r.strides, _low = r._low, - _high = r._low - absSameType(), + _high = r._low - absSameType(r, resultType), _stride = r.stride, alignmentValue = r._alignment); } else if (r.hasHighBound()) { return new range(idxType = r.idxType, bounds = boundKind.both, strides = r.strides, - _low = r._high + absSameType(), + _low = r._high + absSameType(r, resultType), _high = r._high, _stride = r.stride, alignmentValue = r._alignment);