diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 312283b05..aa0900140 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -97,6 +97,12 @@ struct ApplyTypeFunction : Substitution TypePackId clean(TypePackId tp) override; }; +struct GenericTypeDefinitions +{ + std::vector genericTypes; + std::vector genericPacks; +}; + // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -146,7 +152,7 @@ struct TypeChecker ExprResult checkExpr(const ScopePtr& scope, const AstExprBinary& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprError& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType); @@ -336,8 +342,8 @@ struct TypeChecker const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. - std::pair, std::vector> createGenericTypes(const ScopePtr& scope, std::optional levelOpt, - const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, + const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index d6e177142..fd2c2afa7 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -181,6 +181,18 @@ const T* get(const SingletonTypeVar* stv) return nullptr; } +struct GenericTypeDefinition +{ + TypeId ty; + std::optional defaultValue; +}; + +struct GenericTypePackDefinition +{ + TypePackId tp; + std::optional defaultValue; +}; + struct FunctionArgument { Name name; @@ -358,8 +370,8 @@ struct ClassTypeVar struct TypeFun { // These should all be generic - std::vector typeParams; - std::vector typePackParams; + std::vector typeParams; + std::vector typePackParams; /** The underlying type. * @@ -369,13 +381,13 @@ struct TypeFun TypeId type; TypeFun() = default; - TypeFun(std::vector typeParams, TypeId type) + TypeFun(std::vector typeParams, TypeId type) : typeParams(std::move(typeParams)) , type(type) { } - TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) : typeParams(std::move(typeParams)) , typePackParams(std::move(typePackParams)) , type(type) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 67ebd0755..7a801f970 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,11 +13,10 @@ #include LUAU_FASTFLAG(LuauUseCommittingTxnLog) -LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); +LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -291,51 +290,23 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ expectedType = follow(*it); } - if (FFlag::LuauAutocompletePreferToCallFunctions) + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty)) { - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty)) - { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) - { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - } - } - - return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - } - else - { - if (canUnify(ty, expectedType)) - return TypeCorrectKind::Correct; - - // We also want to suggest functions that return compatible result - const FunctionTypeVar* ftv = get(ty); - - if (!ftv) - return TypeCorrectKind::None; - auto [retHead, retTail] = flatten(ftv->retType); - if (!retHead.empty()) - return canUnify(retHead.front(), expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + return TypeCorrectKind::CorrectFunctionResult; // We might only have a variadic tail pack, check if the element is compatible if (retTail) { - if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(vtp->ty, expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return TypeCorrectKind::CorrectFunctionResult; } - - return TypeCorrectKind::None; } + + return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } enum class PropIndexType @@ -435,13 +406,28 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { - if (get(indexIt->second.type) || get(indexIt->second.type)) - autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); - else if (auto indexFunction = get(indexIt->second.type)) + if (FFlag::LuauMissingFollowACMetatables) { - std::optional indexFunctionResult = first(indexFunction->retType); - if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + TypeId followed = follow(indexIt->second.type); + if (get(followed) || get(followed)) + autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen); + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + } + } + else + { + if (get(indexIt->second.type) || get(indexIt->second.type)) + autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); + else if (auto indexFunction = get(indexIt->second.type)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + } } } } @@ -1224,7 +1210,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul if (auto it = module.astTypes.find(node->asExpr())) autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result); } - else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) + else if (autocompleteIfElseExpression(node, ancestry, position, result)) return; else if (node->is()) return; @@ -1261,8 +1247,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - if (FFlag::LuauIfElseExpressionAnalysisSupport) - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index ce832c6b3..88069f1f5 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -190,24 +190,24 @@ struct ErrorConverter { name += "<"; bool first = true; - for (TypeId t : e.typeFun.typeParams) + for (auto param : e.typeFun.typeParams) { if (first) first = false; else name += ", "; - name += toString(t); + name += toString(param.ty); } - for (TypePackId t : e.typeFun.typePackParams) + for (auto param : e.typeFun.typePackParams) { if (first) first = false; else name += ", "; - name += toString(t); + name += toString(param.tp); } name += ">"; @@ -544,13 +544,13 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC for (size_t i = 0; i < typeFun.typeParams.size(); ++i) { - if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) + if (typeFun.typeParams[i].ty != rhs.typeFun.typeParams[i].ty) return false; } for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) { - if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + if (typeFun.typePackParams[i].tp != rhs.typeFun.typePackParams[i].tp) return false; } diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 5bc76ade5..19c2ddabd 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -96,24 +96,24 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "<"; bool first = true; - for (TypeId t : error.typeFun.typeParams) + for (auto param : error.typeFun.typeParams) { if (first) first = false; else stream << ", "; - stream << toString(t); + stream << toString(param.ty); } - for (TypePackId t : error.typeFun.typePackParams) + for (auto param : error.typeFun.typePackParams) { if (first) first = false; else stream << ", "; - stream << toString(t); + stream << toString(param.tp); } stream << ">"; diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 23491a5a1..8dd597e17 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -5,6 +5,8 @@ #include "Luau/StringUtils.h" #include "Luau/Common.h" +LUAU_FASTFLAG(LuauTypeAliasDefaults) + namespace Luau { @@ -337,6 +339,42 @@ struct AstJsonEncoder : public AstVisitor writeRaw("}"); } + void write(const AstGenericType& genericType) + { + if (FFlag::LuauTypeAliasDefaults) + { + writeRaw("{"); + bool c = pushComma(); + write("name", genericType.name); + if (genericType.defaultValue) + write("type", genericType.defaultValue); + popComma(c); + writeRaw("}"); + } + else + { + write(genericType.name); + } + } + + void write(const AstGenericTypePack& genericTypePack) + { + if (FFlag::LuauTypeAliasDefaults) + { + writeRaw("{"); + bool c = pushComma(); + write("name", genericTypePack.name); + if (genericTypePack.defaultValue) + write("type", genericTypePack.defaultValue); + popComma(c); + writeRaw("}"); + } + else + { + write(genericTypePack.name); + } + } + void write(AstExprTable::Item::Kind kind) { switch (kind) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index cff85897c..9f352f4b3 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAG(LuauTypeAliasDefaults) namespace Luau { @@ -447,11 +448,28 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { TypeFun result; - for (TypeId ty : typeFun.typeParams) - result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - for (TypePackId tp : typeFun.typePackParams) - result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, cloneState)); + for (auto param : typeFun.typeParams) + { + TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typeParams.push_back({ty, defaultValue}); + } + + for (auto param : typeFun.typePackParams) + { + TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typePackParams.push_back({tp, defaultValue}); + } result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 889dd6dc5..4b898d3a6 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) +LUAU_FASTFLAG(LuauTypeAliasDefaults) /* * Prefix generic typenames with gen- @@ -209,6 +210,14 @@ struct StringifierState result.name += s; } + + void emit(const char* s) + { + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + return; + + result.name += s; + } }; struct TypeVarStringifier @@ -280,13 +289,28 @@ struct TypeVarStringifier else first = false; - if (!singleTp) - state.emit("("); + if (FFlag::LuauTypeAliasDefaults) + { + bool wrap = !singleTp && get(follow(tp)); - stringify(tp); + if (wrap) + state.emit("("); - if (!singleTp) - state.emit(")"); + stringify(tp); + + if (wrap) + state.emit(")"); + } + else + { + if (!singleTp) + state.emit("("); + + stringify(tp); + + if (!singleTp) + state.emit(")"); + } } if (types.size() || typePacks.size()) @@ -1086,7 +1110,7 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } -std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +std::string toStringNamedFunction_DEPRECATED(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) { std::string s = prefix; @@ -1175,6 +1199,77 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV return s; } +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +{ + if (!FFlag::LuauTypeAliasDefaults) + return toStringNamedFunction_DEPRECATED(prefix, ftv, opts); + + ToStringResult result; + StringifierState state(opts, result, opts.nameMap); + TypeVarStringifier tvs{state}; + + state.emit(prefix); + + if (!opts.hideNamedFunctionTypeParameters) + tvs.stringify(ftv.generics, ftv.genericPacks); + + state.emit("("); + + auto argPackIter = begin(ftv.argTypes); + auto argNameIter = ftv.argNames.begin(); + + bool first = true; + while (argPackIter != end(ftv.argTypes)) + { + if (!first) + state.emit(", "); + first = false; + + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) + { + state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": "); + ++argNameIter; + } + else + { + state.emit("_: "); + } + + tvs.stringify(*argPackIter); + ++argPackIter; + } + + if (argPackIter.tail()) + { + if (!first) + state.emit(", "); + + state.emit("...: "); + + if (auto vtp = get(*argPackIter.tail())) + tvs.stringify(vtp->ty); + else + tvs.stringify(*argPackIter.tail()); + } + + state.emit("): "); + + size_t retSize = size(ftv.retType); + bool hasTail = !finite(ftv.retType); + bool wrap = get(follow(ftv.retType)) && (hasTail ? retSize != 0 : retSize != 1); + + if (wrap) + state.emit("("); + + tvs.stringify(ftv.retType); + + if (wrap) + state.emit(")"); + + return result.name; +} + std::string dump(TypeId ty) { ToStringOptions opts; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 8e13ea5be..f59086834 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_FASTFLAG(LuauTypeAliasDefaults) + namespace { bool isIdentifierStartChar(char c) @@ -793,14 +795,47 @@ struct Printer for (auto o : a->generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + { + writer.advance(o.location.begin); + writer.identifier(o.name.value); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*o.defaultValue); + } + } + else + { + writer.identifier(o.name.value); + } } for (auto o : a->genericPacks) { comma(); - writer.identifier(o.value); - writer.symbol("..."); + + if (FFlag::LuauTypeAliasDefaults) + { + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypePackAnnotation(*o.defaultValue, false); + } + } + else + { + writer.identifier(o.name.value); + writer.symbol("..."); + } } writer.symbol(">"); @@ -846,12 +881,20 @@ struct Printer for (const auto& o : func.generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); } for (const auto& o : func.genericPacks) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); writer.symbol("..."); } writer.symbol(">"); @@ -979,12 +1022,20 @@ struct Printer for (const auto& o : a->generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); } for (const auto& o : a->genericPacks) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); writer.symbol("..."); } writer.symbol(">"); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 9e61c7924..2ec020937 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -212,24 +212,24 @@ class TypeRehydrationVisitor if (hasSeen(&ftv)) return allocator->alloc(Location(), std::nullopt, AstName("")); - AstArray generics; + AstArray generics; generics.size = ftv.generics.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); + generics.data = static_cast(allocator->allocate(sizeof(AstGenericType) * generics.size)); size_t numGenerics = 0; for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { if (auto gtv = get(*it)) - generics.data[numGenerics++] = AstName(gtv->name.c_str()); + generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } - AstArray genericPacks; + AstArray genericPacks; genericPacks.size = ftv.genericPacks.size(); - genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { if (auto gtv = get(*it)) - genericPacks.data[numGenericPacks++] = AstName(gtv->name.c_str()); + genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } AstArray argTypes; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 1689a5c3d..bedcc0227 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -15,8 +15,8 @@ #include "Luau/TypeVar.h" #include "Luau/TimeTrace.h" -#include #include +#include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) @@ -24,25 +24,30 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) +LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) +LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) +LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) +LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) +LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau @@ -279,6 +284,14 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); } + if (FFlag::LuauPerModuleUnificationCache) + { + // Clear unifier cache since it's keyed off internal types that get deallocated + // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. + unifierState.cachedUnify.clear(); + unifierState.skipCacheForType.clear(); + } + return std::move(currentModule); } @@ -1213,18 +1226,18 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ScopePtr aliasScope = childScope(scope, typealias.location); aliasScope->level = scope->level.incr(); - for (TypeId ty : binding->typeParams) + for (auto param : binding->typeParams) { - auto generic = get(ty); + auto generic = get(param.ty); LUAU_ASSERT(generic); - aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; + aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, param.ty}; } - for (TypePackId tp : binding->typePackParams) + for (auto param : binding->typePackParams) { - auto generic = get(tp); + auto generic = get(param.tp); LUAU_ASSERT(generic); - aliasScope->privateTypePackBindings[generic->name] = tp; + aliasScope->privateTypePackBindings[generic->name] = param.tp; } TypeId ty = resolveType(aliasScope, *typealias.type); @@ -1233,9 +1246,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // If the table is already named and we want to rename the type function, we have to bind new alias to a copy if (ttv->name) { + bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), + binding->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), + binding->typePackParams.begin(), binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + // Copy can be skipped if this is an identical alias - if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || - ttv->instantiatedTypePackParams != binding->typePackParams) + if (ttv->name != name || !sameTys || !sameTps) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1243,8 +1264,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; - clone.instantiatedTypeParams = binding->typeParams; - clone.instantiatedTypePackParams = binding->typePackParams; + + for (auto param : binding->typeParams) + clone.instantiatedTypeParams.push_back(param.ty); + + for (auto param : binding->typePackParams) + clone.instantiatedTypePackParams.push_back(param.tp); ty = addType(std::move(clone)); } @@ -1252,8 +1277,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias else { ttv->name = name; - ttv->instantiatedTypeParams = binding->typeParams; - ttv->instantiatedTypePackParams = binding->typePackParams; + + ttv->instantiatedTypeParams.clear(); + for (auto param : binding->typeParams) + ttv->instantiatedTypeParams.push_back(param.ty); + + ttv->instantiatedTypePackParams.clear(); + for (auto param : binding->typePackParams) + ttv->instantiatedTypePackParams.push_back(param.tp); } } else if (auto mtv = getMutable(follow(ty))) @@ -1367,9 +1398,21 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); - TypeId fnType = addType(FunctionTypeVar{funScope->level, generics, genericPacks, argPack, retPack}); + TypeId fnType = addType(FunctionTypeVar{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); FunctionTypeVar* ftv = getMutable(fnType); ftv->argNames.reserve(global.paramNames.size); @@ -1394,7 +1437,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& ExprResult result; if (auto a = expr.as()) - result = checkExpr(scope, *a->expr); + result = checkExpr(scope, *a->expr, FFlag::LuauGroupExpectedType ? expectedType : std::nullopt); else if (expr.is()) result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) @@ -1438,21 +1481,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - { - if (FFlag::LuauIfElseExpressionAnalysisSupport) - { - result = checkExpr(scope, *a); - } - else - { - // Note: When the fast flag is disabled we can't skip the handling of AstExprIfElse - // because we would generate an ICE. We also can't use the default value - // of result, because it will lead to a compiler crash. - // Note: LuauIfElseExpressionBaseSupport can be used to disable parser support - // for if-else expressions which will mean this node type is never created. - result = {anyType}; - } - } + result = checkExpr(scope, *a, FFlag::LuauIfElseExpectedType2 ? expectedType : std::nullopt); else ice("Unhandled AstExpr?"); @@ -1895,7 +1924,7 @@ TypeId TypeChecker::checkExprTable( } } - TableState state = (expr.items.size == 0 || isNonstrictMode()) ? TableState::Unsealed : TableState::Sealed; + TableState state = (expr.items.size == 0 || isNonstrictMode() || FFlag::LuauUnsealedTableLiteral) ? TableState::Unsealed : TableState::Sealed; TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; table.definitionModuleName = currentModuleName; return addType(table); @@ -2549,23 +2578,34 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { ExprResult result = checkExpr(scope, *expr.condition); ScopePtr trueScope = childScope(scope, expr.trueExpr->location); reportErrors(resolve(result.predicates, trueScope, true)); - ExprResult trueType = checkExpr(trueScope, *expr.trueExpr); + ExprResult trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); // Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope. resolve(result.predicates, falseScope, false); - ExprResult falseType = checkExpr(falseScope, *expr.falseExpr); + ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); - unify(falseType.type, trueType.type, expr.location); + if (FFlag::LuauIfElseBranchTypeUnion) + { + if (falseType.type == trueType.type) + return {trueType.type}; + + std::vector types = reduceUnion({trueType.type, falseType.type}); + return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; + } + else + { + unify(falseType.type, trueType.type, expr.location); - // TODO: normalize(UnionTypeVar{{trueType, falseType}}) - // For now both trueType and falseType must be the same type. - return {trueType.type}; + // TODO: normalize(UnionTypeVar{{trueType, falseType}}) + // For now both trueType and falseType must be the same type. + return {trueType.type}; + } } TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) @@ -3032,7 +3072,20 @@ std::pair TypeChecker::checkFunctionSignature( defn.varargLocation = expr.vararg ? std::make_optional(expr.varargLocation) : std::nullopt; defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); - TypeId funTy = addType(FunctionTypeVar(funScope->level, generics, genericPacks, argPack, retPack, std::move(defn), bool(expr.self))); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + + TypeId funTy = + addType(FunctionTypeVar(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self))); FunctionTypeVar* ftv = getMutable(funTy); @@ -4848,11 +4901,38 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) return tf->type; - if (!lit->hasParameterList && !tf->typePackParams.empty()) + bool hasDefaultTypes = false; + bool hasDefaultPacks = false; + bool parameterCountErrorReported = false; + + if (FFlag::LuauTypeAliasDefaults) { - reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); + hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + + if (!lit->hasParameterList) + { + if ((!tf->typeParams.empty() && !hasDefaultTypes) || (!tf->typePackParams.empty() && !hasDefaultPacks)) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + parameterCountErrorReported = true; + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); + } + } + } + else + { + if (!lit->hasParameterList && !tf->typePackParams.empty()) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); + } } std::vector typeParams; @@ -4892,14 +4972,89 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (typePackParams.empty() && !extraTypes.empty()) typePackParams.push_back(addTypePack(extraTypes)); + if (FFlag::LuauTypeAliasDefaults) + { + size_t typesProvided = typeParams.size(); + size_t typesRequired = tf->typeParams.size(); + + size_t packsProvided = typePackParams.size(); + size_t packsRequired = tf->typePackParams.size(); + + bool notEnoughParameters = + (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); + bool hasDefaultParameters = hasDefaultTypes || hasDefaultPacks; + + // Add default type and type pack parameters if that's required and it's possible + if (notEnoughParameters && hasDefaultParameters) + { + // 'applyTypeFunction' is used to substitute default types that reference previous generic types + applyTypeFunction.typeArguments.clear(); + applyTypeFunction.typePackArguments.clear(); + applyTypeFunction.currentModule = currentModule; + applyTypeFunction.level = scope->level; + applyTypeFunction.encounteredForwardedType = false; + + for (size_t i = 0; i < typesProvided; ++i) + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; + + if (typesProvided < typesRequired) + { + for (size_t i = typesProvided; i < typesRequired; ++i) + { + TypeId defaultTy = tf->typeParams[i].defaultValue.value_or(nullptr); + + if (!defaultTy) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTy); + + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryType(scope); + } + + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = *maybeInstantiated; + typeParams.push_back(*maybeInstantiated); + } + } + + for (size_t i = 0; i < packsProvided; ++i) + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = typePackParams[i]; + + if (packsProvided < packsRequired) + { + for (size_t i = packsProvided; i < packsRequired; ++i) + { + TypePackId defaultTp = tf->typePackParams[i].defaultValue.value_or(nullptr); + + if (!defaultTp) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTp); + + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryTypePack(scope); + } + + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = *maybeInstantiated; + typePackParams.push_back(*maybeInstantiated); + } + } + } + } + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) typePackParams.push_back(addTypePack({})); if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) { - reportError( - TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + if (!parameterCountErrorReported) + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); if (FFlag::LuauErrorRecoveryType) { @@ -4913,11 +5068,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return errorRecoveryType(scope); } - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + if (FFlag::LuauRecursiveTypeParameterRestriction) { + bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal( + typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + // If the generic parameters and the type arguments are the same, we are about to // perform an identity substitution, which we can just short-circuit. - return tf->type; + if (sameTys && sameTps) + return tf->type; } return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); @@ -4948,7 +5112,19 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation TypePackId argTypes = resolveTypePack(funcScope, func->argTypes); TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes); - TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(generics), std::move(genericPacks), argTypes, retTypes}); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + + TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes}); FunctionTypeVar* ftv = getMutable(fnType); @@ -5137,11 +5313,11 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) - applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; + applyTypeFunction.typeArguments[tf.typeParams[i].ty] = typeParams[i]; applyTypeFunction.typePackArguments.clear(); for (size_t i = 0; i < tf.typePackParams.size(); ++i) - applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; + applyTypeFunction.typePackArguments[tf.typePackParams[i].tp] = typePackParams[i]; applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; @@ -5213,17 +5389,23 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, return instantiated; } -std::pair, std::vector> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, - const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, + const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; - std::vector generics; - for (const AstName& generic : genericNames) + std::vector generics; + + for (const AstGenericType& generic : genericNames) { - Name n = generic.value; + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && generic.defaultValue) + defaultValue = resolveType(scope, *generic.defaultValue); + + Name n = generic.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic typevars have the same name. @@ -5246,14 +5428,20 @@ std::pair, std::vector> TypeChecker::createGener g = addType(Unifiable::Generic{level, n}); } - generics.push_back(g); + generics.push_back({g, defaultValue}); scope->privateTypeBindings[n] = TypeFun{{}, g}; } - std::vector genericPacks; - for (const AstName& genericPack : genericPackNames) + std::vector genericPacks; + + for (const AstGenericTypePack& genericPack : genericPackNames) { - Name n = genericPack.value; + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && genericPack.defaultValue) + defaultValue = resolveTypePack(scope, *genericPack.defaultValue); + + Name n = genericPack.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic typevars have the same name. @@ -5276,7 +5464,7 @@ std::pair, std::vector> TypeChecker::createGener g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); } - genericPacks.push_back(g); + genericPacks.push_back({g, defaultValue}); scope->privateTypePackBindings[n] = g; } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4cab79c8a..ac2b25410 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) +LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(DebugLuauFreezeArena) @@ -453,6 +454,9 @@ bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs) static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs) { + if (FFlag::LuauMetatableAreEqualRecursion && areSeen(seen, &lhs, &rhs)) + return true; + return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable); } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 393a84a70..6873c657c 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); +LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) @@ -1170,9 +1171,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { + if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + break; + } + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) + if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) tryUnify_(*subTpv->tail, *superTpv->tail); else if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); @@ -1370,9 +1377,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { + if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + break; + } + const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) + if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) tryUnify_(*subTpv->tail, *superTpv->tail); else if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 5b4bfa033..573850a5a 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -334,6 +334,20 @@ struct AstTypeList using AstArgumentName = std::pair; // TODO: remove and replace when we get a common struct for this pair instead of AstName +struct AstGenericType +{ + AstName name; + Location location; + AstType* defaultValue = nullptr; +}; + +struct AstGenericTypePack +{ + AstName name; + Location location; + AstTypePack* defaultValue = nullptr; +}; + extern int gAstRttiIndex; template @@ -569,15 +583,15 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, - const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, - std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, + AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + AstLocal* self, const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, std::optional argLocation = std::nullopt); void visit(AstVisitor* visitor) override; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstLocal* self; AstArray args; bool hasReturnAnnotation; @@ -942,14 +956,14 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, - AstType* type, bool exported); + AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported); void visit(AstVisitor* visitor) override; AstName name; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -972,14 +986,15 @@ class AstStatDeclareFunction : public AstStat public: LUAU_RTTI(AstStatDeclareFunction) - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes); + AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, + const AstTypeList& retTypes); void visit(AstVisitor* visitor) override; AstName name; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList params; AstArray paramNames; AstTypeList retTypes; @@ -1077,13 +1092,13 @@ class AstTypeFunction : public AstType public: LUAU_RTTI(AstTypeFunction) - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, - const AstArray>& argNames, const AstTypeList& returnTypes); + AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); void visit(AstVisitor* visitor) override; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 87ebc48b5..40ecdcdd5 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -219,7 +219,7 @@ class Parser AstTableIndexer* parseTableIndexerAnnotation(); AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); - AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); AstType* parseTableTypeAnnotation(); @@ -281,7 +281,7 @@ class Parser Name parseIndexName(const char* context, const Position& previous); // `<' namelist `>' - std::pair, AstArray> parseGenericTypeList(); + std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); @@ -418,6 +418,8 @@ class Parser std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; + std::vector scratchGenericTypes; + std::vector scratchGenericTypePacks; std::vector> scratchOptArgName; std::string scratchData; }; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index e709894d9..9b5bc0c71 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -158,9 +158,10 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, - const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, - std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, std::optional argLocation) +AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + AstLocal* self, const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, + std::optional argLocation) : AstExpr(ClassIndex(), location) , generics(generics) , genericPacks(genericPacks) @@ -641,8 +642,8 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported) : AstStat(ClassIndex(), location) , name(name) , generics(generics) @@ -655,7 +656,21 @@ AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name void AstStatTypeAlias::visit(AstVisitor* visitor) { if (visitor->visit(this)) + { + for (const AstGenericType& el : generics) + { + if (el.defaultValue) + el.defaultValue->visit(visitor); + } + + for (const AstGenericTypePack& el : genericPacks) + { + if (el.defaultValue) + el.defaultValue->visit(visitor); + } + type->visit(visitor); + } } AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) @@ -671,8 +686,9 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor) type->visit(visitor); } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, + const AstTypeList& retTypes) : AstStat(ClassIndex(), location) , name(name) , generics(generics) @@ -778,7 +794,7 @@ void AstTypeTable::visit(AstVisitor* visitor) } } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, +AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) : AstType(ClassIndex(), location) , generics(generics) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 72f616497..77787cb1c 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,10 +10,10 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) -LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) +LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) namespace Luau { @@ -394,23 +394,13 @@ AstStat* Parser::parseIf() if (lexer.current().type == Lexeme::ReservedElseif) { - if (FFlag::LuauIfStatementRecursionGuard) - { - unsigned int recursionCounterOld = recursionCounter; - incrementRecursionCounter("elseif"); - elseLocation = lexer.current().location; - elsebody = parseIf(); - end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - recursionCounter = recursionCounterOld; - } - else - { - elseLocation = lexer.current().location; - elsebody = parseIf(); - end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - } + unsigned int recursionCounterOld = recursionCounter; + incrementRecursionCounter("elseif"); + elseLocation = lexer.current().location; + elsebody = parseIf(); + end = elsebody->location; + hasEnd = elsebody->as()->hasEnd; + recursionCounter = recursionCounterOld; } else { @@ -772,7 +762,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ FFlag::LuauParseTypeAliasDefaults); expectAndConsume('=', "type alias"); @@ -788,8 +778,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() Name fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; generics.size = 0; generics.data = nullptr; genericPacks.size = 0; @@ -849,7 +839,7 @@ AstStat* Parser::parseDeclaration(const Location& start) nextLexeme(); Name globalName = parseName("global function name"); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme matchParen = lexer.current(); @@ -991,7 +981,7 @@ std::pair Parser::parseFunctionBody( { Location start = matchFunction.location; - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme matchParen = lexer.current(); expectAndConsume('(', "function"); @@ -1228,8 +1218,8 @@ std::pair Parser::parseReturnTypeAnnotation() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstArray generics{nullptr, 0}; - AstArray genericPacks{nullptr, 0}; + AstArray generics{nullptr, 0}; + AstArray genericPacks{nullptr, 0}; AstArray types = copy(result); AstArray> names = copy(resultNames); @@ -1363,7 +1353,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) Lexeme begin = lexer.current(); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme parameterStart = lexer.current(); @@ -1401,7 +1391,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, +AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation) { @@ -1448,7 +1438,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isUnion = true; } else if (c == '?') @@ -1461,7 +1451,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isIntersection = true; } else @@ -1498,7 +1488,7 @@ AstTypeOrPack Parser::parseTypeOrPackAnnotation() TempVector parts(scratchAnnotation); - auto [type, typePack] = parseSimpleTypeAnnotation(true); + auto [type, typePack] = parseSimpleTypeAnnotation(/* allowPack= */ true); if (typePack) { @@ -1521,7 +1511,7 @@ AstType* Parser::parseTypeAnnotation() Location begin = lexer.current().location; TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); recursionCounter = oldRecursionCount; @@ -2121,7 +2111,7 @@ AstExpr* Parser::parseSimpleExpr() { return parseTableConstructor(); } - else if (FFlag::LuauIfElseExpressionBaseSupport && lexer.current().type == Lexeme::ReservedIf) + else if (lexer.current().type == Lexeme::ReservedIf) { return parseIfElseExpr(); } @@ -2341,10 +2331,10 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou return Name(nameError, location); } -std::pair, AstArray> Parser::parseGenericTypeList() +std::pair, AstArray> Parser::parseGenericTypeList(bool withDefaultValues) { - TempVector names{scratchName}; - TempVector namePacks{scratchPackName}; + TempVector names{scratchGenericTypes}; + TempVector namePacks{scratchGenericTypePacks}; if (lexer.current().type == '<') { @@ -2352,21 +2342,73 @@ std::pair, AstArray> Parser::parseGenericTypeList() nextLexeme(); bool seenPack = false; + bool seenDefault = false; + while (true) { + Location nameLocation = lexer.current().location; AstName name = parseName().name; - if (lexer.current().type == Lexeme::Dot3) + if (lexer.current().type == Lexeme::Dot3 || (FFlag::LuauParseRecoverTypePackEllipsis && seenPack)) { seenPack = true; - nextLexeme(); - namePacks.push_back(name); + + if (FFlag::LuauParseRecoverTypePackEllipsis && lexer.current().type != Lexeme::Dot3) + report(lexer.current().location, "Generic types come before generic type packs"); + else + nextLexeme(); + + if (withDefaultValues && lexer.current().type == '=') + { + seenDefault = true; + nextLexeme(); + + Lexeme packBegin = lexer.current(); + + if (shouldParseTypePackAnnotation(lexer)) + { + auto typePack = parseTypePackAnnotation(); + + namePacks.push_back({name, nameLocation, typePack}); + } + else if (lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (type) + report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); + + namePacks.push_back({name, nameLocation, typePack}); + } + } + else + { + if (seenDefault) + report(lexer.current().location, "Expected default type pack after type pack name"); + + namePacks.push_back({name, nameLocation, nullptr}); + } } else { - if (seenPack) + if (!FFlag::LuauParseRecoverTypePackEllipsis && seenPack) report(lexer.current().location, "Generic types come before generic type packs"); - names.push_back(name); + if (withDefaultValues && lexer.current().type == '=') + { + seenDefault = true; + nextLexeme(); + + AstType* defaultType = parseTypeAnnotation(); + + names.push_back({name, nameLocation, defaultType}); + } + else + { + if (seenDefault) + report(lexer.current().location, "Expected default type after type name"); + + names.push_back({name, nameLocation, nullptr}); + } } if (lexer.current().type == ',') @@ -2378,8 +2420,8 @@ std::pair, AstArray> Parser::parseGenericTypeList() expectMatchAndConsume('>', begin); } - AstArray generics = copy(names); - AstArray genericPacks = copy(namePacks); + AstArray generics = copy(names); + AstArray genericPacks = copy(namePacks); return {generics, genericPacks}; } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 1375f4140..bf04281c6 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -8,6 +8,8 @@ #include "FileUtils.h" +LUAU_FASTFLAG(DebugLuauTimeTracing) + enum class ReportFormat { Default, @@ -105,6 +107,7 @@ static void displayHelp(const char* argv0) printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -213,7 +216,17 @@ int main(int argc, char** argv) format = ReportFormat::Gnu; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; + else if (strcmp(argv[i], "--timetrace") == 0) + FFlag::DebugLuauTimeTracing.value = true; + } + +#if !defined(LUAU_ENABLE_TIME_TRACE) + if (FFlag::DebugLuauTimeTracing) + { + printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; } +#endif Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; @@ -240,5 +253,8 @@ int main(int argc, char** argv) fprintf(stderr, "%s: %s\n", pair.first.c_str(), pair.second.c_str()); } - return (format == ReportFormat::Luacheck) ? 0 : failed; + if (format == ReportFormat::Luacheck) + return 0; + else + return failed ? 1 : 0; } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 35fbf309d..cf958939f 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -19,6 +19,16 @@ #include #endif +LUAU_FASTFLAG(DebugLuauTimeTracing) + +enum class CliMode +{ + Unknown, + Repl, + Compile, + RunSourceFiles +}; + enum class CompileFormat { Text, @@ -485,8 +495,10 @@ static void displayHelp(const char* argv0) printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); printf("\n"); printf("Available options:\n"); + printf(" -h, --help: Display this usage message.\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -503,71 +515,112 @@ int main(int argc, char** argv) if (strncmp(flag->name, "Luau", 4) == 0) flag->value = true; - if (argc == 1) + CliMode mode = CliMode::Unknown; + CompileFormat compileFormat{}; + int profile = 0; + bool coverage = false; + + // Set the mode if the user has explicitly specified one. + int argStart = 1; + if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { - runRepl(); - return 0; + argStart++; + mode = CliMode::Compile; + if (strcmp(argv[1], "--compile") == 0) + { + compileFormat = CompileFormat::Text; + } + else if (strcmp(argv[1], "--compile=binary") == 0) + { + compileFormat = CompileFormat::Binary; + } + else if (strcmp(argv[1], "--compile=text") == 0) + { + compileFormat = CompileFormat::Text; + } + else + { + fprintf(stdout, "Error: Unrecognized value for '--compile' specified.\n"); + return -1; + } } - if (argc >= 2 && strcmp(argv[1], "--help") == 0) + for (int i = argStart; i < argc; i++) { - displayHelp(argv[0]); - return 0; - } + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (strcmp(argv[i], "--profile") == 0) + { + profile = 10000; // default to 10 KHz + } + else if (strncmp(argv[i], "--profile=", 10) == 0) + { + profile = atoi(argv[i] + 10); + } + else if (strcmp(argv[i], "--coverage") == 0) + { + coverage = true; + } + else if (strcmp(argv[i], "--timetrace") == 0) + { + FFlag::DebugLuauTimeTracing.value = true; +#if !defined(LUAU_ENABLE_TIME_TRACE) + printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; +#endif + } + else if (argv[i][0] == '-') + { + fprintf(stdout, "Error: Unrecognized option '%s'.\n\n", argv[i]); + displayHelp(argv[0]); + return 1; + } + } - if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) + const std::vector files = getSourceFiles(argc, argv); + if (mode == CliMode::Unknown) { - CompileFormat format = CompileFormat::Text; - - if (strcmp(argv[1], "--compile=binary") == 0) - format = CompileFormat::Binary; + mode = files.empty() ? CliMode::Repl : CliMode::RunSourceFiles; + } + switch (mode) + { + case CliMode::Compile: + { #ifdef _WIN32 - if (format == CompileFormat::Binary) + if (compileFormat == CompileFormat::Binary) _setmode(_fileno(stdout), _O_BINARY); #endif - std::vector files = getSourceFiles(argc, argv); - int failed = 0; for (const std::string& path : files) - failed += !compileFile(path.c_str(), format); + failed += !compileFile(path.c_str(), compileFormat); - return failed; + return failed ? 1 : 0; } - + case CliMode::Repl: + { + runRepl(); + return 0; + } + case CliMode::RunSourceFiles: { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); setupState(L); - int profile = 0; - bool coverage = false; - - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] != '-') - continue; - - if (strcmp(argv[i], "--profile") == 0) - profile = 10000; // default to 10 KHz - else if (strncmp(argv[i], "--profile=", 10) == 0) - profile = atoi(argv[i] + 10); - else if (strcmp(argv[i], "--coverage") == 0) - coverage = true; - } - if (profile) profilerStart(L, profile); if (coverage) coverageInit(L); - std::vector files = getSourceFiles(argc, argv); - int failed = 0; for (const std::string& path : files) @@ -582,6 +635,10 @@ int main(int argc, char** argv) if (coverage) coverageDump("coverage.out"); - return failed; + return failed ? 1 : 0; + } + case CliMode::Unknown: + default: + LUAU_ASSERT(!"Unhandled cli mode."); } } diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp new file mode 100644 index 000000000..e344eb917 --- /dev/null +++ b/Compiler/src/Builtins.cpp @@ -0,0 +1,197 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Builtins.h" + +#include "Luau/Bytecode.h" +#include "Luau/Compiler.h" + +namespace Luau +{ +namespace Compile +{ + +Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, const DenseHashMap& variables) +{ + if (AstExprLocal* expr = node->as()) + { + const Variable* v = variables.find(expr->local); + + return v && !v->written && v->init ? getBuiltin(v->init, globals, variables) : Builtin(); + } + else if (AstExprIndexName* expr = node->as()) + { + if (AstExprGlobal* object = expr->expr->as()) + { + return getGlobalState(globals, object->name) == Global::Default ? Builtin{object->name, expr->index} : Builtin(); + } + else + { + return Builtin(); + } + } + else if (AstExprGlobal* expr = node->as()) + { + return getGlobalState(globals, expr->name) == Global::Default ? Builtin{AstName(), expr->name} : Builtin(); + } + else + { + return Builtin(); + } +} + +int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) +{ + if (builtin.empty()) + return -1; + + if (builtin.isGlobal("assert")) + return LBF_ASSERT; + + if (builtin.isGlobal("type")) + return LBF_TYPE; + + if (builtin.isGlobal("typeof")) + return LBF_TYPEOF; + + if (builtin.isGlobal("rawset")) + return LBF_RAWSET; + if (builtin.isGlobal("rawget")) + return LBF_RAWGET; + if (builtin.isGlobal("rawequal")) + return LBF_RAWEQUAL; + + if (builtin.isGlobal("unpack")) + return LBF_TABLE_UNPACK; + + if (builtin.object == "math") + { + if (builtin.method == "abs") + return LBF_MATH_ABS; + if (builtin.method == "acos") + return LBF_MATH_ACOS; + if (builtin.method == "asin") + return LBF_MATH_ASIN; + if (builtin.method == "atan2") + return LBF_MATH_ATAN2; + if (builtin.method == "atan") + return LBF_MATH_ATAN; + if (builtin.method == "ceil") + return LBF_MATH_CEIL; + if (builtin.method == "cosh") + return LBF_MATH_COSH; + if (builtin.method == "cos") + return LBF_MATH_COS; + if (builtin.method == "deg") + return LBF_MATH_DEG; + if (builtin.method == "exp") + return LBF_MATH_EXP; + if (builtin.method == "floor") + return LBF_MATH_FLOOR; + if (builtin.method == "fmod") + return LBF_MATH_FMOD; + if (builtin.method == "frexp") + return LBF_MATH_FREXP; + if (builtin.method == "ldexp") + return LBF_MATH_LDEXP; + if (builtin.method == "log10") + return LBF_MATH_LOG10; + if (builtin.method == "log") + return LBF_MATH_LOG; + if (builtin.method == "max") + return LBF_MATH_MAX; + if (builtin.method == "min") + return LBF_MATH_MIN; + if (builtin.method == "modf") + return LBF_MATH_MODF; + if (builtin.method == "pow") + return LBF_MATH_POW; + if (builtin.method == "rad") + return LBF_MATH_RAD; + if (builtin.method == "sinh") + return LBF_MATH_SINH; + if (builtin.method == "sin") + return LBF_MATH_SIN; + if (builtin.method == "sqrt") + return LBF_MATH_SQRT; + if (builtin.method == "tanh") + return LBF_MATH_TANH; + if (builtin.method == "tan") + return LBF_MATH_TAN; + if (builtin.method == "clamp") + return LBF_MATH_CLAMP; + if (builtin.method == "sign") + return LBF_MATH_SIGN; + if (builtin.method == "round") + return LBF_MATH_ROUND; + } + + if (builtin.object == "bit32") + { + if (builtin.method == "arshift") + return LBF_BIT32_ARSHIFT; + if (builtin.method == "band") + return LBF_BIT32_BAND; + if (builtin.method == "bnot") + return LBF_BIT32_BNOT; + if (builtin.method == "bor") + return LBF_BIT32_BOR; + if (builtin.method == "bxor") + return LBF_BIT32_BXOR; + if (builtin.method == "btest") + return LBF_BIT32_BTEST; + if (builtin.method == "extract") + return LBF_BIT32_EXTRACT; + if (builtin.method == "lrotate") + return LBF_BIT32_LROTATE; + if (builtin.method == "lshift") + return LBF_BIT32_LSHIFT; + if (builtin.method == "replace") + return LBF_BIT32_REPLACE; + if (builtin.method == "rrotate") + return LBF_BIT32_RROTATE; + if (builtin.method == "rshift") + return LBF_BIT32_RSHIFT; + if (builtin.method == "countlz") + return LBF_BIT32_COUNTLZ; + if (builtin.method == "countrz") + return LBF_BIT32_COUNTRZ; + } + + if (builtin.object == "string") + { + if (builtin.method == "byte") + return LBF_STRING_BYTE; + if (builtin.method == "char") + return LBF_STRING_CHAR; + if (builtin.method == "len") + return LBF_STRING_LEN; + if (builtin.method == "sub") + return LBF_STRING_SUB; + } + + if (builtin.object == "table") + { + if (builtin.method == "insert") + return LBF_TABLE_INSERT; + if (builtin.method == "unpack") + return LBF_TABLE_UNPACK; + } + + if (options.vectorCtor) + { + if (options.vectorLib) + { + if (builtin.isMethod(options.vectorLib, options.vectorCtor)) + return LBF_VECTOR; + } + else + { + if (builtin.isGlobal(options.vectorCtor)) + return LBF_VECTOR; + } + } + + return -1; +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h new file mode 100644 index 000000000..60df53a18 --- /dev/null +++ b/Compiler/src/Builtins.h @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "ValueTracking.h" + +namespace Luau +{ +struct CompileOptions; +} + +namespace Luau +{ +namespace Compile +{ + +struct Builtin +{ + AstName object; + AstName method; + + bool empty() const + { + return object == AstName() && method == AstName(); + } + + bool isGlobal(const char* name) const + { + return object == AstName() && method == name; + } + + bool isMethod(const char* table, const char* name) const + { + return object == table && method == name; + } +}; + +Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, const DenseHashMap& variables); +int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 0156fcf5a..9758c4a9a 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -6,15 +6,20 @@ #include "Luau/Common.h" #include "Luau/TimeTrace.h" +#include "Builtins.h" +#include "ConstantFolding.h" +#include "TableShape.h" +#include "ValueTracking.h" + #include #include #include -LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) - namespace Luau { +using namespace Luau::Compile; + static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; @@ -62,7 +67,6 @@ static BytecodeBuilder::StringRef sref(AstArray data) struct Compiler { - struct Constant; struct RegScope; Compiler(BytecodeBuilder& bytecode, const CompileOptions& options) @@ -71,8 +75,9 @@ struct Compiler , functions(nullptr) , locals(nullptr) , globals(AstName()) + , variables(nullptr) , constants(nullptr) - , predictedTableSize(nullptr) + , tableShapes(nullptr) { } @@ -96,8 +101,10 @@ struct Compiler local->location, "Out of upvalue registers when trying to allocate %s: exceeded limit %d", local->name.value, kMaxUpvalueCount); // mark local as captured so that closeLocals emits LOP_CLOSEUPVALS accordingly - Local& l = locals[local]; - l.captured = true; + Variable* v = variables.find(local); + + if (v && v->written) + locals[local].captured = true; upvals.push_back(local); @@ -273,8 +280,8 @@ struct Compiler if (options.optimizationLevel >= 1) { - Builtin builtin = getBuiltin(expr->func); - bfid = getBuiltinFunctionId(builtin); + Builtin builtin = getBuiltin(expr->func, globals, variables); + bfid = getBuiltinFunctionId(builtin, options); } if (expr->self) @@ -424,8 +431,10 @@ struct Compiler for (AstLocal* uv : f->upvals) { - Local* ul = locals.find(uv); - LUAU_ASSERT(ul); + Variable* ul = variables.find(uv); + + if (!ul) + return false; if (ul->written) return false; @@ -437,10 +446,11 @@ struct Compiler // this will only deoptimize (outside of fenv changes) if top level code is executed twice with different results. if (uv->functionDepth != 0 || uv->loopDepth != 0) { - if (!ul->func) + AstExprFunction* uf = ul->init ? ul->init->as() : nullptr; + if (!uf) return false; - if (ul->func != func && !shouldShareClosure(ul->func)) + if (uf != func && !shouldShareClosure(uf)) return false; } } @@ -483,10 +493,8 @@ struct Compiler { LUAU_ASSERT(uv->functionDepth < expr->functionDepth); - Local* ul = locals.find(uv); - LUAU_ASSERT(ul); - - bool immutable = !ul->written; + Variable* ul = variables.find(uv); + bool immutable = !ul || !ul->written; if (uv->functionDepth == expr->functionDepth - 1) { @@ -687,7 +695,7 @@ struct Compiler break; case Constant::Type_String: - cid = bytecode.addConstantString(sref(c->valueString)); + cid = bytecode.addConstantString(sref(c->getString())); break; default: @@ -1066,10 +1074,10 @@ struct Compiler // Optimization: if the table is empty, we can compute it directly into the target if (expr->items.size == 0) { - auto [hashSize, arraySize] = predictedTableSize[expr]; + TableShape shape = tableShapes[expr]; - bytecode.emitABC(LOP_NEWTABLE, target, encodeHashSize(hashSize), 0); - bytecode.emitAux(arraySize); + bytecode.emitABC(LOP_NEWTABLE, target, encodeHashSize(shape.hashSize), 0); + bytecode.emitAux(shape.arraySize); return; } @@ -1252,16 +1260,12 @@ struct Compiler bool canImport(AstExprGlobal* expr) { - const Global* global = globals.find(expr->name); - - return options.optimizationLevel >= 1 && (!global || !global->written); + return options.optimizationLevel >= 1 && getGlobalState(globals, expr->name) != Global::Written; } bool canImportChain(AstExprGlobal* expr) { - const Global* global = globals.find(expr->name); - - return options.optimizationLevel >= 1 && (!global || (!global->written && !global->writable)); + return options.optimizationLevel >= 1 && getGlobalState(globals, expr->name) == Global::Default; } void compileExprIndexName(AstExprIndexName* expr, uint8_t target) @@ -1341,7 +1345,7 @@ struct Compiler { uint8_t rt = compileExprAuto(expr->expr, rs); - BytecodeBuilder::StringRef iname = sref(cv->valueString); + BytecodeBuilder::StringRef iname = sref(cv->getString()); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); @@ -1427,7 +1431,7 @@ struct Compiler case Constant::Type_String: { - int32_t cid = bytecode.addConstantString(sref(cv->valueString)); + int32_t cid = bytecode.addConstantString(sref(cv->getString())); if (cid < 0) CompileError::raise(node->location, "Exceeded constant limit; simplify the code to compile"); @@ -1546,7 +1550,7 @@ struct Compiler { compileExpr(expr->expr, target, targetTemp); } - else if (AstExprIfElse* expr = node->as(); FFlag::LuauIfElseExpressionBaseSupport && expr) + else if (AstExprIfElse* expr = node->as()) { compileExprIfElse(expr, target, targetTemp); } @@ -1711,7 +1715,7 @@ struct Compiler { LValue result = {LValue::Kind_IndexName}; result.reg = compileExprAuto(expr->expr, rs); - result.name = sref(cv->valueString); + result.name = sref(cv->getString()); result.location = node->location; return result; @@ -1796,9 +1800,8 @@ struct Compiler return false; Local* l = locals.find(le->local); - LUAU_ASSERT(l); - return l->allocated; + return l && l->allocated; } bool isStatBreak(AstStat* node) @@ -2040,9 +2043,9 @@ struct Compiler for (AstLocal* local : stat->vars) { - Local* l = locals.find(local); + Variable* v = variables.find(local); - if (!l || l->constant.type == Constant::Type_Unknown) + if (!v || !v->constant) return false; } @@ -2082,9 +2085,7 @@ struct Compiler // through) uint8_t varreg = regs + 2; - Local* il = locals.find(stat->var); - - if (il && il->written) + if (Variable* il = variables.find(stat->var); il && il->written) varreg = allocReg(stat, 1); compileExprTemp(stat->from, uint8_t(regs + 2)); @@ -2164,7 +2165,7 @@ struct Compiler { if (stat->values.size == 1 && stat->values.data[0]->is()) { - Builtin builtin = getBuiltin(stat->values.data[0]->as()->func); + Builtin builtin = getBuiltin(stat->values.data[0]->as()->func, globals, variables); if (builtin.isGlobal("ipairs")) // for .. in ipairs(t) { @@ -2179,7 +2180,7 @@ struct Compiler } else if (stat->values.size == 2) { - Builtin builtin = getBuiltin(stat->values.data[0]); + Builtin builtin = getBuiltin(stat->values.data[0], globals, variables); if (builtin.isGlobal("next")) // for .. in next,t { @@ -2594,7 +2595,7 @@ struct Compiler Local* l = locals.find(localStack[i]); LUAU_ASSERT(l); - if (l->captured && l->written) + if (l->captured) return true; } @@ -2613,7 +2614,7 @@ struct Compiler Local* l = locals.find(localStack[i]); LUAU_ASSERT(l); - if (l->captured && l->written) + if (l->captured) { captured = true; captureReg = std::min(captureReg, l->reg); @@ -2728,519 +2729,6 @@ struct Compiler return !node->is() && !node->is(); } - struct AssignmentVisitor : AstVisitor - { - struct Hasher - { - size_t operator()(const std::pair& p) const - { - return std::hash()(p.first) ^ std::hash()(p.second); - } - }; - - DenseHashMap localToTable; - DenseHashSet, Hasher> fields; - - AssignmentVisitor(Compiler* self) - : localToTable(nullptr) - , fields(std::pair()) - , self(self) - { - } - - void assignField(AstExpr* expr, AstName index) - { - if (AstExprLocal* lv = expr->as()) - { - if (AstExprTable** table = localToTable.find(lv->local)) - { - std::pair field = {*table, index}; - - if (!fields.contains(field)) - { - fields.insert(field); - self->predictedTableSize[*table].first += 1; - } - } - } - } - - void assignField(AstExpr* expr, AstExpr* index) - { - AstExprLocal* lv = expr->as(); - AstExprConstantNumber* number = index->as(); - - if (lv && number) - { - if (AstExprTable** table = localToTable.find(lv->local)) - { - unsigned int& arraySize = self->predictedTableSize[*table].second; - - if (number->value == double(arraySize + 1)) - arraySize += 1; - } - } - } - - void assign(AstExpr* var) - { - if (AstExprLocal* lv = var->as()) - { - self->locals[lv->local].written = true; - } - else if (AstExprGlobal* gv = var->as()) - { - self->globals[gv->name].written = true; - } - else if (AstExprIndexName* index = var->as()) - { - assignField(index->expr, index->index); - - var->visit(this); - } - else if (AstExprIndexExpr* index = var->as()) - { - assignField(index->expr, index->index); - - var->visit(this); - } - else - { - // we need to be able to track assignments in all expressions, including crazy ones like t[function() t = nil end] = 5 - var->visit(this); - } - } - - AstExprTable* getTableHint(AstExpr* expr) - { - // unadorned table literal - if (AstExprTable* table = expr->as()) - return table; - - // setmetatable(table literal, ...) - if (AstExprCall* call = expr->as(); call && !call->self && call->args.size == 2) - if (AstExprGlobal* func = call->func->as(); func && func->name == "setmetatable") - if (AstExprTable* table = call->args.data[0]->as()) - return table; - - return nullptr; - } - - bool visit(AstStatLocal* node) override - { - // track local -> table association so that we can update table size prediction in assignField - if (node->vars.size == 1 && node->values.size == 1) - if (AstExprTable* table = getTableHint(node->values.data[0]); table && table->items.size == 0) - localToTable[node->vars.data[0]] = table; - - return true; - } - - bool visit(AstStatAssign* node) override - { - for (size_t i = 0; i < node->vars.size; ++i) - assign(node->vars.data[i]); - - for (size_t i = 0; i < node->values.size; ++i) - node->values.data[i]->visit(this); - - return false; - } - - bool visit(AstStatCompoundAssign* node) override - { - assign(node->var); - node->value->visit(this); - - return false; - } - - bool visit(AstStatFunction* node) override - { - assign(node->name); - node->func->visit(this); - - return false; - } - - Compiler* self; - }; - - struct ConstantVisitor : AstVisitor - { - ConstantVisitor(Compiler* self) - : self(self) - { - } - - void analyzeUnary(Constant& result, AstExprUnary::Op op, const Constant& arg) - { - switch (op) - { - case AstExprUnary::Not: - if (arg.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = !arg.isTruthful(); - } - break; - - case AstExprUnary::Minus: - if (arg.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = -arg.valueNumber; - } - break; - - case AstExprUnary::Len: - if (arg.type == Constant::Type_String) - { - result.type = Constant::Type_Number; - result.valueNumber = double(arg.valueString.size); - } - break; - - default: - LUAU_ASSERT(!"Unexpected unary operation"); - } - } - - bool constantsEqual(const Constant& la, const Constant& ra) - { - LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown); - - switch (la.type) - { - case Constant::Type_Nil: - return ra.type == Constant::Type_Nil; - - case Constant::Type_Boolean: - return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean; - - case Constant::Type_Number: - return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber; - - case Constant::Type_String: - return ra.type == Constant::Type_String && la.valueString.size == ra.valueString.size && - memcmp(la.valueString.data, ra.valueString.data, la.valueString.size) == 0; - - default: - LUAU_ASSERT(!"Unexpected constant type in comparison"); - return false; - } - } - - void analyzeBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra) - { - switch (op) - { - case AstExprBinary::Add: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber + ra.valueNumber; - } - break; - - case AstExprBinary::Sub: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber - ra.valueNumber; - } - break; - - case AstExprBinary::Mul: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber * ra.valueNumber; - } - break; - - case AstExprBinary::Div: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber / ra.valueNumber; - } - break; - - case AstExprBinary::Mod: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber; - } - break; - - case AstExprBinary::Pow: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = pow(la.valueNumber, ra.valueNumber); - } - break; - - case AstExprBinary::Concat: - break; - - case AstExprBinary::CompareNe: - if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = !constantsEqual(la, ra); - } - break; - - case AstExprBinary::CompareEq: - if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = constantsEqual(la, ra); - } - break; - - case AstExprBinary::CompareLt: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber < ra.valueNumber; - } - break; - - case AstExprBinary::CompareLe: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber <= ra.valueNumber; - } - break; - - case AstExprBinary::CompareGt: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber > ra.valueNumber; - } - break; - - case AstExprBinary::CompareGe: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber >= ra.valueNumber; - } - break; - - case AstExprBinary::And: - if (la.type != Constant::Type_Unknown) - { - result = la.isTruthful() ? ra : la; - } - break; - - case AstExprBinary::Or: - if (la.type != Constant::Type_Unknown) - { - result = la.isTruthful() ? la : ra; - } - break; - - default: - LUAU_ASSERT(!"Unexpected binary operation"); - } - } - - Constant analyze(AstExpr* node) - { - Constant result; - result.type = Constant::Type_Unknown; - - if (AstExprGroup* expr = node->as()) - { - result = analyze(expr->expr); - } - else if (node->is()) - { - result.type = Constant::Type_Nil; - } - else if (AstExprConstantBool* expr = node->as()) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = expr->value; - } - else if (AstExprConstantNumber* expr = node->as()) - { - result.type = Constant::Type_Number; - result.valueNumber = expr->value; - } - else if (AstExprConstantString* expr = node->as()) - { - result.type = Constant::Type_String; - result.valueString = expr->value; - } - else if (AstExprLocal* expr = node->as()) - { - const Local* l = self->locals.find(expr->local); - - if (l && l->constant.type != Constant::Type_Unknown) - { - LUAU_ASSERT(!l->written); - result = l->constant; - } - } - else if (node->is()) - { - // nope - } - else if (node->is()) - { - // nope - } - else if (AstExprCall* expr = node->as()) - { - analyze(expr->func); - - for (size_t i = 0; i < expr->args.size; ++i) - analyze(expr->args.data[i]); - } - else if (AstExprIndexName* expr = node->as()) - { - analyze(expr->expr); - } - else if (AstExprIndexExpr* expr = node->as()) - { - analyze(expr->expr); - analyze(expr->index); - } - else if (AstExprFunction* expr = node->as()) - { - // this is necessary to propagate constant information in all child functions - expr->body->visit(this); - } - else if (AstExprTable* expr = node->as()) - { - for (size_t i = 0; i < expr->items.size; ++i) - { - const AstExprTable::Item& item = expr->items.data[i]; - - if (item.key) - analyze(item.key); - - analyze(item.value); - } - } - else if (AstExprUnary* expr = node->as()) - { - Constant arg = analyze(expr->expr); - - analyzeUnary(result, expr->op, arg); - } - else if (AstExprBinary* expr = node->as()) - { - Constant la = analyze(expr->left); - Constant ra = analyze(expr->right); - - analyzeBinary(result, expr->op, la, ra); - } - else if (AstExprTypeAssertion* expr = node->as()) - { - Constant arg = analyze(expr->expr); - - result = arg; - } - else if (AstExprIfElse* expr = node->as(); FFlag::LuauIfElseExpressionBaseSupport && expr) - { - Constant cond = analyze(expr->condition); - Constant trueExpr = analyze(expr->trueExpr); - Constant falseExpr = analyze(expr->falseExpr); - if (cond.type != Constant::Type_Unknown) - { - result = cond.isTruthful() ? trueExpr : falseExpr; - } - } - else - { - LUAU_ASSERT(!"Unknown expression type"); - } - - if (result.type != Constant::Type_Unknown) - self->constants[node] = result; - - return result; - } - - bool visit(AstExpr* node) override - { - // note: we short-circuit the visitor traversal through any expression trees by returning false - // recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression - analyze(node); - - return false; - } - - bool visit(AstStatLocal* node) override - { - // for values that match 1-1 we record the initializing expression for future analysis - for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) - { - Local& l = self->locals[node->vars.data[i]]; - - l.init = node->values.data[i]; - } - - // all values that align wrt indexing are simple - we just match them 1-1 - for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) - { - Constant arg = analyze(node->values.data[i]); - - if (arg.type != Constant::Type_Unknown) - { - Local& l = self->locals[node->vars.data[i]]; - - // note: we rely on AssignmentVisitor to have been run before us - if (!l.written) - l.constant = arg; - } - } - - if (node->vars.size > node->values.size) - { - // if we have trailing variables, then depending on whether the last value is capable of returning multiple values - // (aka call or varargs), we either don't know anything about these vars, or we know they're nil - AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr; - bool multRet = last && (last->is() || last->is()); - - for (size_t i = node->values.size; i < node->vars.size; ++i) - { - if (!multRet) - { - Local& l = self->locals[node->vars.data[i]]; - - // note: we rely on AssignmentVisitor to have been run before us - if (!l.written) - { - l.constant.type = Constant::Type_Nil; - } - } - } - } - else - { - // we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside - // them - for (size_t i = node->vars.size; i < node->values.size; ++i) - analyze(node->values.data[i]); - } - - return false; - } - - Compiler* self; - }; - struct FenvVisitor : AstVisitor { bool& getfenvUsed; @@ -3283,14 +2771,6 @@ struct Compiler return false; } - - bool visit(AstStatLocalFunction* node) override - { - // record local->function association for some optimizations - self->locals[node->name].func = node->func; - - return true; - } }; struct UndefinedLocalVisitor : AstVisitor @@ -3397,70 +2877,12 @@ struct Compiler std::vector upvals; }; - struct Constant - { - enum Type - { - Type_Unknown, - Type_Nil, - Type_Boolean, - Type_Number, - Type_String, - }; - - Type type = Type_Unknown; - - union - { - bool valueBoolean; - double valueNumber; - AstArray valueString = {}; - }; - - bool isTruthful() const - { - LUAU_ASSERT(type != Type_Unknown); - return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); - } - }; - struct Local { uint8_t reg = 0; bool allocated = false; bool captured = false; - bool written = false; - AstExpr* init = nullptr; uint32_t debugpc = 0; - Constant constant; - AstExprFunction* func = nullptr; - }; - - struct Global - { - bool writable = false; - bool written = false; - }; - - struct Builtin - { - AstName object; - AstName method; - - bool empty() const - { - return object == AstName() && method == AstName(); - } - - bool isGlobal(const char* name) const - { - return object == AstName() && method == name; - } - - bool isMethod(const char* table, const char* name) const - { - return object == table && method == name; - } }; struct LoopJump @@ -3482,194 +2904,6 @@ struct Compiler AstExpr* untilCondition; }; - Builtin getBuiltin(AstExpr* node) - { - if (AstExprLocal* expr = node->as()) - { - Local* l = locals.find(expr->local); - - return l && !l->written && l->init ? getBuiltin(l->init) : Builtin(); - } - else if (AstExprIndexName* expr = node->as()) - { - if (AstExprGlobal* object = expr->expr->as()) - { - Global* g = globals.find(object->name); - - return !g || (!g->writable && !g->written) ? Builtin{object->name, expr->index} : Builtin(); - } - else - { - return Builtin(); - } - } - else if (AstExprGlobal* expr = node->as()) - { - Global* g = globals.find(expr->name); - - return !g || !g->written ? Builtin{AstName(), expr->name} : Builtin(); - } - else - { - return Builtin(); - } - } - - int getBuiltinFunctionId(const Builtin& builtin) - { - if (builtin.empty()) - return -1; - - if (builtin.isGlobal("assert")) - return LBF_ASSERT; - - if (builtin.isGlobal("type")) - return LBF_TYPE; - - if (builtin.isGlobal("typeof")) - return LBF_TYPEOF; - - if (builtin.isGlobal("rawset")) - return LBF_RAWSET; - if (builtin.isGlobal("rawget")) - return LBF_RAWGET; - if (builtin.isGlobal("rawequal")) - return LBF_RAWEQUAL; - - if (builtin.isGlobal("unpack")) - return LBF_TABLE_UNPACK; - - if (builtin.object == "math") - { - if (builtin.method == "abs") - return LBF_MATH_ABS; - if (builtin.method == "acos") - return LBF_MATH_ACOS; - if (builtin.method == "asin") - return LBF_MATH_ASIN; - if (builtin.method == "atan2") - return LBF_MATH_ATAN2; - if (builtin.method == "atan") - return LBF_MATH_ATAN; - if (builtin.method == "ceil") - return LBF_MATH_CEIL; - if (builtin.method == "cosh") - return LBF_MATH_COSH; - if (builtin.method == "cos") - return LBF_MATH_COS; - if (builtin.method == "deg") - return LBF_MATH_DEG; - if (builtin.method == "exp") - return LBF_MATH_EXP; - if (builtin.method == "floor") - return LBF_MATH_FLOOR; - if (builtin.method == "fmod") - return LBF_MATH_FMOD; - if (builtin.method == "frexp") - return LBF_MATH_FREXP; - if (builtin.method == "ldexp") - return LBF_MATH_LDEXP; - if (builtin.method == "log10") - return LBF_MATH_LOG10; - if (builtin.method == "log") - return LBF_MATH_LOG; - if (builtin.method == "max") - return LBF_MATH_MAX; - if (builtin.method == "min") - return LBF_MATH_MIN; - if (builtin.method == "modf") - return LBF_MATH_MODF; - if (builtin.method == "pow") - return LBF_MATH_POW; - if (builtin.method == "rad") - return LBF_MATH_RAD; - if (builtin.method == "sinh") - return LBF_MATH_SINH; - if (builtin.method == "sin") - return LBF_MATH_SIN; - if (builtin.method == "sqrt") - return LBF_MATH_SQRT; - if (builtin.method == "tanh") - return LBF_MATH_TANH; - if (builtin.method == "tan") - return LBF_MATH_TAN; - if (builtin.method == "clamp") - return LBF_MATH_CLAMP; - if (builtin.method == "sign") - return LBF_MATH_SIGN; - if (builtin.method == "round") - return LBF_MATH_ROUND; - } - - if (builtin.object == "bit32") - { - if (builtin.method == "arshift") - return LBF_BIT32_ARSHIFT; - if (builtin.method == "band") - return LBF_BIT32_BAND; - if (builtin.method == "bnot") - return LBF_BIT32_BNOT; - if (builtin.method == "bor") - return LBF_BIT32_BOR; - if (builtin.method == "bxor") - return LBF_BIT32_BXOR; - if (builtin.method == "btest") - return LBF_BIT32_BTEST; - if (builtin.method == "extract") - return LBF_BIT32_EXTRACT; - if (builtin.method == "lrotate") - return LBF_BIT32_LROTATE; - if (builtin.method == "lshift") - return LBF_BIT32_LSHIFT; - if (builtin.method == "replace") - return LBF_BIT32_REPLACE; - if (builtin.method == "rrotate") - return LBF_BIT32_RROTATE; - if (builtin.method == "rshift") - return LBF_BIT32_RSHIFT; - if (builtin.method == "countlz") - return LBF_BIT32_COUNTLZ; - if (builtin.method == "countrz") - return LBF_BIT32_COUNTRZ; - } - - if (builtin.object == "string") - { - if (builtin.method == "byte") - return LBF_STRING_BYTE; - if (builtin.method == "char") - return LBF_STRING_CHAR; - if (builtin.method == "len") - return LBF_STRING_LEN; - if (builtin.method == "sub") - return LBF_STRING_SUB; - } - - if (builtin.object == "table") - { - if (builtin.method == "insert") - return LBF_TABLE_INSERT; - if (builtin.method == "unpack") - return LBF_TABLE_UNPACK; - } - - if (options.vectorCtor) - { - if (options.vectorLib) - { - if (builtin.isMethod(options.vectorLib, options.vectorCtor)) - return LBF_VECTOR; - } - else - { - if (builtin.isGlobal(options.vectorCtor)) - return LBF_VECTOR; - } - } - - return -1; - } - BytecodeBuilder& bytecode; CompileOptions options; @@ -3677,8 +2911,9 @@ struct Compiler DenseHashMap functions; DenseHashMap locals; DenseHashMap globals; + DenseHashMap variables; DenseHashMap constants; - DenseHashMap> predictedTableSize; + DenseHashMap tableShapes; unsigned int regTop = 0; unsigned int stackSize = 0; @@ -3699,23 +2934,18 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables - if (AstName name = names.get("_G"); name.value) - compiler.globals[name].writable = true; + assignMutable(compiler.globals, names, options.mutableGlobals); - if (options.mutableGlobals) - for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) - if (AstName name = names.get(*ptr); name.value) - compiler.globals[name].writable = true; + // this pass analyzes mutability of locals/globals and associates locals with their initial values + trackValues(compiler.globals, compiler.variables, root); - // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written - Compiler::AssignmentVisitor assignmentVisitor(&compiler); - root->visit(&assignmentVisitor); - - // this visitor traverses the AST to analyze constantness of expressions, filling constants[] and Local::constant/Local::init if (options.optimizationLevel >= 1) { - Compiler::ConstantVisitor constantVisitor(&compiler); - root->visit(&constantVisitor); + // this pass analyzes constantness of expressions + foldConstants(compiler.constants, compiler.variables, root); + + // this pass analyzes table assignments to estimate table shapes for initially empty tables + predictTableShapes(compiler.tableShapes, root); } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found @@ -3734,8 +2964,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName for (AstExprFunction* expr : functions) compiler.compileFunction(expr); - AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), /* self= */ nullptr, - AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); + AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), + /* self= */ nullptr, AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main); bytecode.setMainFunction(mainid); diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp new file mode 100644 index 000000000..60a7c1692 --- /dev/null +++ b/Compiler/src/ConstantFolding.cpp @@ -0,0 +1,394 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ConstantFolding.h" + +#include + +namespace Luau +{ +namespace Compile +{ + +static bool constantsEqual(const Constant& la, const Constant& ra) +{ + LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown); + + switch (la.type) + { + case Constant::Type_Nil: + return ra.type == Constant::Type_Nil; + + case Constant::Type_Boolean: + return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean; + + case Constant::Type_Number: + return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber; + + case Constant::Type_String: + return ra.type == Constant::Type_String && la.stringLength == ra.stringLength && memcmp(la.valueString, ra.valueString, la.stringLength) == 0; + + default: + LUAU_ASSERT(!"Unexpected constant type in comparison"); + return false; + } +} + +static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg) +{ + switch (op) + { + case AstExprUnary::Not: + if (arg.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = !arg.isTruthful(); + } + break; + + case AstExprUnary::Minus: + if (arg.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = -arg.valueNumber; + } + break; + + case AstExprUnary::Len: + if (arg.type == Constant::Type_String) + { + result.type = Constant::Type_Number; + result.valueNumber = double(arg.stringLength); + } + break; + + default: + LUAU_ASSERT(!"Unexpected unary operation"); + } +} + +static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra) +{ + switch (op) + { + case AstExprBinary::Add: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber + ra.valueNumber; + } + break; + + case AstExprBinary::Sub: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber - ra.valueNumber; + } + break; + + case AstExprBinary::Mul: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber * ra.valueNumber; + } + break; + + case AstExprBinary::Div: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber / ra.valueNumber; + } + break; + + case AstExprBinary::Mod: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber; + } + break; + + case AstExprBinary::Pow: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = pow(la.valueNumber, ra.valueNumber); + } + break; + + case AstExprBinary::Concat: + break; + + case AstExprBinary::CompareNe: + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = !constantsEqual(la, ra); + } + break; + + case AstExprBinary::CompareEq: + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = constantsEqual(la, ra); + } + break; + + case AstExprBinary::CompareLt: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber < ra.valueNumber; + } + break; + + case AstExprBinary::CompareLe: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber <= ra.valueNumber; + } + break; + + case AstExprBinary::CompareGt: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber > ra.valueNumber; + } + break; + + case AstExprBinary::CompareGe: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber >= ra.valueNumber; + } + break; + + case AstExprBinary::And: + if (la.type != Constant::Type_Unknown) + { + result = la.isTruthful() ? ra : la; + } + break; + + case AstExprBinary::Or: + if (la.type != Constant::Type_Unknown) + { + result = la.isTruthful() ? la : ra; + } + break; + + default: + LUAU_ASSERT(!"Unexpected binary operation"); + } +} + +struct ConstantVisitor : AstVisitor +{ + DenseHashMap& constants; + DenseHashMap& variables; + + DenseHashMap locals; + + ConstantVisitor(DenseHashMap& constants, DenseHashMap& variables) + : constants(constants) + , variables(variables) + , locals(nullptr) + { + } + + Constant analyze(AstExpr* node) + { + Constant result; + result.type = Constant::Type_Unknown; + + if (AstExprGroup* expr = node->as()) + { + result = analyze(expr->expr); + } + else if (node->is()) + { + result.type = Constant::Type_Nil; + } + else if (AstExprConstantBool* expr = node->as()) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = expr->value; + } + else if (AstExprConstantNumber* expr = node->as()) + { + result.type = Constant::Type_Number; + result.valueNumber = expr->value; + } + else if (AstExprConstantString* expr = node->as()) + { + result.type = Constant::Type_String; + result.valueString = expr->value.data; + result.stringLength = unsigned(expr->value.size); + } + else if (AstExprLocal* expr = node->as()) + { + const Constant* l = locals.find(expr->local); + + if (l) + result = *l; + } + else if (node->is()) + { + // nope + } + else if (node->is()) + { + // nope + } + else if (AstExprCall* expr = node->as()) + { + analyze(expr->func); + + for (size_t i = 0; i < expr->args.size; ++i) + analyze(expr->args.data[i]); + } + else if (AstExprIndexName* expr = node->as()) + { + analyze(expr->expr); + } + else if (AstExprIndexExpr* expr = node->as()) + { + analyze(expr->expr); + analyze(expr->index); + } + else if (AstExprFunction* expr = node->as()) + { + // this is necessary to propagate constant information in all child functions + expr->body->visit(this); + } + else if (AstExprTable* expr = node->as()) + { + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + + if (item.key) + analyze(item.key); + + analyze(item.value); + } + } + else if (AstExprUnary* expr = node->as()) + { + Constant arg = analyze(expr->expr); + + if (arg.type != Constant::Type_Unknown) + foldUnary(result, expr->op, arg); + } + else if (AstExprBinary* expr = node->as()) + { + Constant la = analyze(expr->left); + Constant ra = analyze(expr->right); + + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + foldBinary(result, expr->op, la, ra); + } + else if (AstExprTypeAssertion* expr = node->as()) + { + Constant arg = analyze(expr->expr); + + result = arg; + } + else if (AstExprIfElse* expr = node->as()) + { + Constant cond = analyze(expr->condition); + Constant trueExpr = analyze(expr->trueExpr); + Constant falseExpr = analyze(expr->falseExpr); + + if (cond.type != Constant::Type_Unknown) + result = cond.isTruthful() ? trueExpr : falseExpr; + } + else + { + LUAU_ASSERT(!"Unknown expression type"); + } + + if (result.type != Constant::Type_Unknown) + constants[node] = result; + + return result; + } + + bool visit(AstExpr* node) override + { + // note: we short-circuit the visitor traversal through any expression trees by returning false + // recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression + analyze(node); + + return false; + } + + bool visit(AstStatLocal* node) override + { + // all values that align wrt indexing are simple - we just match them 1-1 + for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) + { + Constant arg = analyze(node->values.data[i]); + + if (arg.type != Constant::Type_Unknown) + { + // note: we rely on trackValues to have been run before us + Variable* v = variables.find(node->vars.data[i]); + LUAU_ASSERT(v); + + if (!v->written) + { + locals[node->vars.data[i]] = arg; + v->constant = true; + } + } + } + + if (node->vars.size > node->values.size) + { + // if we have trailing variables, then depending on whether the last value is capable of returning multiple values + // (aka call or varargs), we either don't know anything about these vars, or we know they're nil + AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr; + bool multRet = last && (last->is() || last->is()); + + if (!multRet) + { + for (size_t i = node->values.size; i < node->vars.size; ++i) + { + // note: we rely on trackValues to have been run before us + Variable* v = variables.find(node->vars.data[i]); + LUAU_ASSERT(v); + + if (!v->written) + { + locals[node->vars.data[i]].type = Constant::Type_Nil; + v->constant = true; + } + } + } + } + else + { + // we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside + // them + for (size_t i = node->vars.size; i < node->values.size; ++i) + analyze(node->values.data[i]); + } + + return false; + } +}; + +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root) +{ + ConstantVisitor visitor{constants, variables}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h new file mode 100644 index 000000000..c0e63539b --- /dev/null +++ b/Compiler/src/ConstantFolding.h @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "ValueTracking.h" + +namespace Luau +{ +namespace Compile +{ + +struct Constant +{ + enum Type + { + Type_Unknown, + Type_Nil, + Type_Boolean, + Type_Number, + Type_String, + }; + + Type type = Type_Unknown; + unsigned int stringLength = 0; + + union + { + bool valueBoolean; + double valueNumber; + char* valueString = nullptr; // length stored in stringLength + }; + + bool isTruthful() const + { + LUAU_ASSERT(type != Type_Unknown); + return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); + } + + AstArray getString() const + { + LUAU_ASSERT(type == Type_String); + return {valueString, stringLength}; + } +}; + +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/TableShape.cpp b/Compiler/src/TableShape.cpp new file mode 100644 index 000000000..7d99f2228 --- /dev/null +++ b/Compiler/src/TableShape.cpp @@ -0,0 +1,129 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "TableShape.h" + +namespace Luau +{ +namespace Compile +{ + +static AstExprTable* getTableHint(AstExpr* expr) +{ + // unadorned table literal + if (AstExprTable* table = expr->as()) + return table; + + // setmetatable(table literal, ...) + if (AstExprCall* call = expr->as(); call && !call->self && call->args.size == 2) + if (AstExprGlobal* func = call->func->as(); func && func->name == "setmetatable") + if (AstExprTable* table = call->args.data[0]->as()) + return table; + + return nullptr; +} + +struct ShapeVisitor : AstVisitor +{ + struct Hasher + { + size_t operator()(const std::pair& p) const + { + return std::hash()(p.first) ^ std::hash()(p.second); + } + }; + + DenseHashMap& shapes; + + DenseHashMap tables; + DenseHashSet, Hasher> fields; + + ShapeVisitor(DenseHashMap& shapes) + : shapes(shapes) + , tables(nullptr) + , fields(std::pair()) + { + } + + void assignField(AstExpr* expr, AstName index) + { + if (AstExprLocal* lv = expr->as()) + { + if (AstExprTable** table = tables.find(lv->local)) + { + std::pair field = {*table, index}; + + if (!fields.contains(field)) + { + fields.insert(field); + shapes[*table].hashSize += 1; + } + } + } + } + + void assignField(AstExpr* expr, AstExpr* index) + { + AstExprLocal* lv = expr->as(); + AstExprConstantNumber* number = index->as(); + + if (lv && number) + { + if (AstExprTable** table = tables.find(lv->local)) + { + TableShape& shape = shapes[*table]; + + if (number->value == double(shape.arraySize + 1)) + shape.arraySize += 1; + } + } + } + + void assign(AstExpr* var) + { + if (AstExprIndexName* index = var->as()) + { + assignField(index->expr, index->index); + } + else if (AstExprIndexExpr* index = var->as()) + { + assignField(index->expr, index->index); + } + } + + bool visit(AstStatLocal* node) override + { + // track local -> table association so that we can update table size prediction in assignField + if (node->vars.size == 1 && node->values.size == 1) + if (AstExprTable* table = getTableHint(node->values.data[0]); table && table->items.size == 0) + tables[node->vars.data[0]] = table; + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatFunction* node) override + { + assign(node->name); + node->func->visit(this); + + return false; + } +}; + +void predictTableShapes(DenseHashMap& shapes, AstNode* root) +{ + ShapeVisitor visitor{shapes}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/TableShape.h b/Compiler/src/TableShape.h new file mode 100644 index 000000000..f30853a7f --- /dev/null +++ b/Compiler/src/TableShape.h @@ -0,0 +1,21 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +namespace Compile +{ + +struct TableShape +{ + unsigned int arraySize = 0; + unsigned int hashSize = 0; +}; + +void predictTableShapes(DenseHashMap& shapes, AstNode* root); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ValueTracking.cpp b/Compiler/src/ValueTracking.cpp new file mode 100644 index 000000000..0bfaf9b3c --- /dev/null +++ b/Compiler/src/ValueTracking.cpp @@ -0,0 +1,103 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ValueTracking.h" + +#include "Luau/Lexer.h" + +namespace Luau +{ +namespace Compile +{ + +struct ValueVisitor : AstVisitor +{ + DenseHashMap& globals; + DenseHashMap& variables; + + ValueVisitor(DenseHashMap& globals, DenseHashMap& variables) + : globals(globals) + , variables(variables) + { + } + + void assign(AstExpr* var) + { + if (AstExprLocal* lv = var->as()) + { + variables[lv->local].written = true; + } + else if (AstExprGlobal* gv = var->as()) + { + globals[gv->name] = Global::Written; + } + else + { + // we need to be able to track assignments in all expressions, including crazy ones like t[function() t = nil end] = 5 + var->visit(this); + } + } + + bool visit(AstStatLocal* node) override + { + for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) + variables[node->vars.data[i]].init = node->values.data[i]; + + for (size_t i = node->values.size; i < node->vars.size; ++i) + variables[node->vars.data[i]].init = nullptr; + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatCompoundAssign* node) override + { + assign(node->var); + node->value->visit(this); + + return false; + } + + bool visit(AstStatLocalFunction* node) override + { + variables[node->name].init = node->func; + + return true; + } + + bool visit(AstStatFunction* node) override + { + assign(node->name); + node->func->visit(this); + + return false; + } +}; + +void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char** mutableGlobals) +{ + if (AstName name = names.get("_G"); name.value) + globals[name] = Global::Mutable; + + if (mutableGlobals) + for (const char** ptr = mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) + globals[name] = Global::Mutable; +} + +void trackValues(DenseHashMap& globals, DenseHashMap& variables, AstNode* root) +{ + ValueVisitor visitor{globals, variables}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ValueTracking.h b/Compiler/src/ValueTracking.h new file mode 100644 index 000000000..fc74c84aa --- /dev/null +++ b/Compiler/src/ValueTracking.h @@ -0,0 +1,42 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +class AstNameTable; +} + +namespace Luau +{ +namespace Compile +{ + +enum class Global +{ + Default = 0, + Mutable, // builtin that has contents unknown at compile time, blocks GETIMPORT for chains + Written, // written in the code which means we can't reason about the value +}; + +struct Variable +{ + AstExpr* init = nullptr; // initial value of the variable; filled by trackValues + bool written = false; // is the variable ever assigned to? filled by trackValues + bool constant = false; // is the variable's value a compile-time constant? filled by constantFold +}; + +void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char** mutableGlobals); +void trackValues(DenseHashMap& globals, DenseHashMap& variables, AstNode* root); + +inline Global getGlobalState(const DenseHashMap& globals, AstName name) +{ + const Global* it = globals.find(name); + + return it ? *it : Global::Default; +} + +} // namespace Compile +} // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index 5dd486aaa..bafe75948 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -29,7 +29,15 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/BytecodeBuilder.cpp Compiler/src/Compiler.cpp + Compiler/src/Builtins.cpp + Compiler/src/ConstantFolding.cpp + Compiler/src/TableShape.cpp + Compiler/src/ValueTracking.cpp Compiler/src/lcode.cpp + Compiler/src/Builtins.h + Compiler/src/ConstantFolding.h + Compiler/src/TableShape.h + Compiler/src/ValueTracking.h ) # Luau.Analysis Sources diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index eb47971a8..3cce7665d 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,7 +17,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAG(LuauCoroutineClose) /* @@ -545,11 +544,8 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e if (!oldactive) resetbit(L->stackstate, THREAD_ACTIVEBIT); - if (FFlag::LuauCcallRestoreFix) - { - // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. - L->nCcalls = oldnCcalls; - } + // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + L->nCcalls = oldnCcalls; // an error occurred, check if we have a protected error callback if (L->global->cb.debugprotectederror) @@ -564,10 +560,6 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); /* close eventual pending closures */ seterrorobj(L, status, oldtop); - if (!FFlag::LuauCcallRestoreFix) - { - L->nCcalls = oldnCcalls; - } L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 648785697..4178eda40 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,6 +6,8 @@ #include "lmem.h" #include "lgc.h" +LUAU_FASTFLAGVARIABLE(LuauNoDirectUpvalRemoval, false) + Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_new(L, Proto, sizeof(Proto), L->activememcat); @@ -113,14 +115,16 @@ void luaF_freeupval(lua_State* L, UpVal* uv) void luaF_close(lua_State* L, StkId level) { UpVal* uv; - global_State* g = L->global; + global_State* g = L->global; // TODO: remove with FFlagLuauNoDirectUpvalRemoval while (L->openupval != NULL && (uv = gco2uv(L->openupval))->v >= level) { GCObject* o = obj2gco(uv); LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); L->openupval = uv->next; /* remove from `open' list */ - if (isdead(g, o)) + if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o)) + { luaF_freeupval(L, uv); /* free upvalue */ + } else { unlinkupval(uv); diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 0b3054ae4..74a8aa8a1 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauStrPackUBCastFix, false) - /* macro to `unsign' a character */ #define uchar(c) ((unsigned char)(c)) @@ -1406,20 +1404,10 @@ static int str_pack(lua_State* L) } case Kuint: { /* unsigned integers */ - if (FFlag::LuauStrPackUBCastFix) - { - long long n = (long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, (unsigned long long)n, h.islittle, size, 0); - } - else - { - unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, n, h.islittle, size, 0); - } + long long n = (long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, (unsigned long long)n, h.islittle, size, 0); break; } case Kfloat: diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua index 6d40406cd..5d162ab99 100644 --- a/bench/tests/sunspider/3d-cube.lua +++ b/bench/tests/sunspider/3d-cube.lua @@ -111,15 +111,10 @@ end -- multiplies two matrices function MMulti(M1, M2) local M = {{},{},{},{}}; - local i = 1; - local j = 1; - while i <= 4 do - j = 1; - while j <= 4 do - M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j]; j = j + 1 + for i = 1,4 do + for j = 1,4 do + M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j]; end - - i = i + 1 end return M; end @@ -127,28 +122,27 @@ end -- multiplies matrix with vector function VMulti(M, V) local Vect = {}; - local i = 1; - while i <= 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; i = i + 1 end + for i = 1,4 do + Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; + end return Vect; end function VMulti2(M, V) local Vect = {}; - local i = 1; - while i < 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; i = i + 1 end + for i = 1,3 do + Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; + end return Vect; end -- add to matrices function MAdd(M1, M2) local M = {{},{},{},{}}; - local i = 1; - local j = 1; - while i <= 4 do - j = 1; - while j <= 4 do M[i][j] = M1[i][j] + M2[i][j]; j = j + 1 end - - i = i + 1 + for i = 1,4 do + for j = 1,4 do + M[i][j] = M1[i][j] + M2[i][j]; + end end return M; end diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 210db7eea..211e1be1f 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1938,7 +1938,6 @@ return target(b@1 TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); check(R"( local function bar(a: number) return -a end @@ -1954,7 +1953,6 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); check(R"( local function foo() return 1 end @@ -2538,10 +2536,6 @@ TEST_CASE("autocomplete_documentation_symbols") TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { check(R"( local temp = false local even = true; @@ -2614,7 +2608,6 @@ a = if temp then even elseif true then temp else e@9 CHECK(ac.entryMap.count("then") == 0); CHECK(ac.entryMap.count("else") == 0); CHECK(ac.entryMap.count("elseif") == 0); - } } TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") @@ -2681,4 +2674,58 @@ local r4 = t:bar1(@4) CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + check(R"( +type A = () -> T + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_pack_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + check(R"( +type A = () -> T + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_oop_implicit_self") +{ + ScopedFastFlag flag("LuauMissingFollowACMetatables", true); + check(R"( +--!strict +local Class = {} +Class.__index = Class +type Class = typeof(setmetatable({} :: { x: number }, Class)) +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end +function Class.getx(self: Class) + return self.x +end +function test() + local c = Class.new(42) + local n = c:@1 + print(n) +end + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("getx")); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 95811b3f4..8eed953fd 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -603,9 +603,9 @@ RETURN R0 1 )"); } -TEST_CASE("EmptyTableHashSizePredictionOptimization") +TEST_CASE("TableSizePredictionBasic") { - const char* hashSizeSource = R"( + CHECK_EQ("\n" + compileFunction0(R"( local t = {} t.a = 1 t.b = 1 @@ -616,36 +616,8 @@ t.f = 1 t.g = 1 t.h = 1 t.i = 1 -)"; - - const char* hashSizeSource2 = R"( -local t = {} -t.x = 1 -t.x = 2 -t.x = 3 -t.x = 4 -t.x = 5 -t.x = 6 -t.x = 7 -t.x = 8 -t.x = 9 -)"; - - const char* arraySizeSource = R"( -local t = {} -t[1] = 1 -t[2] = 1 -t[3] = 1 -t[4] = 1 -t[5] = 1 -t[6] = 1 -t[7] = 1 -t[8] = 1 -t[9] = 1 -t[10] = 1 -)"; - - CHECK_EQ("\n" + compileFunction0(hashSizeSource), R"( +)"), + R"( NEWTABLE R0 16 0 LOADN R1 1 SETTABLEKS R1 R0 K0 @@ -668,7 +640,19 @@ SETTABLEKS R1 R0 K8 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction0(hashSizeSource2), R"( + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t.x = 1 +t.x = 2 +t.x = 3 +t.x = 4 +t.x = 5 +t.x = 6 +t.x = 7 +t.x = 8 +t.x = 9 +)"), + R"( NEWTABLE R0 1 0 LOADN R1 1 SETTABLEKS R1 R0 K0 @@ -691,7 +675,20 @@ SETTABLEKS R1 R0 K0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction0(arraySizeSource), R"( + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t[1] = 1 +t[2] = 1 +t[3] = 1 +t[4] = 1 +t[5] = 1 +t[6] = 1 +t[7] = 1 +t[8] = 1 +t[9] = 1 +t[10] = 1 +)"), + R"( NEWTABLE R0 0 10 LOADN R1 1 SETTABLEN R1 R0 1 @@ -717,6 +714,27 @@ RETURN R0 0 )"); } +TEST_CASE("TableSizePredictionObject") +{ + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +t.field = 1 +function t:getfield() + return self.field +end +return t +)", + 1), + R"( +NEWTABLE R0 2 0 +LOADN R1 1 +SETTABLEKS R1 R0 K0 +DUPCLOSURE R1 K1 +SETTABLEKS R1 R0 K2 +RETURN R0 1 +)"); +} + TEST_CASE("TableSizePredictionSetMetatable") { CHECK_EQ("\n" + compileFunction0(R"( @@ -1031,9 +1049,6 @@ RETURN R0 1 TEST_CASE("IfElseExpression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - // codegen for a true constant condition CHECK_EQ("\n" + compileFunction0("return if true then 10 else 20"), R"( LOADN R0 10 @@ -3058,7 +3073,7 @@ RETURN R0 0 // table variants (indexed by string, number, variable) CHECK_EQ("\n" + compileFunction0("local a = {} a.foo += 5"), R"( -NEWTABLE R0 1 0 +NEWTABLE R0 0 0 GETTABLEKS R1 R0 K0 ADDK R1 R1 K1 SETTABLEKS R1 R0 K0 @@ -3066,7 +3081,7 @@ RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = {} a[1] += 5"), R"( -NEWTABLE R0 0 1 +NEWTABLE R0 0 0 GETTABLEN R1 R0 1 ADDK R1 R1 K0 SETTABLEN R1 R0 1 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 663b329ee..5222af33e 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -366,15 +366,11 @@ TEST_CASE("PCall") TEST_CASE("Pack") { - ScopedFastFlag sff{"LuauStrPackUBCastFix", true}; - runConformance("tpack.lua"); } TEST_CASE("Vector") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - lua_CompileOptions copts = {}; copts.optimizationLevel = 1; copts.debugLevel = 1; @@ -861,15 +857,11 @@ TEST_CASE("ExceptionObject") TEST_CASE("IfElseExpression") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - runConformance("ifelseexpr.lua"); } TEST_CASE("TagMethodError") { - ScopedFastFlag sff{"LuauCcallRestoreFix", true}; - runConformance("tmerror.lua", [](lua_State* L) { auto* cb = lua_callbacks(L); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index ca4281a0b..c74bfa272 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -191,7 +191,7 @@ ParseResult Fixture::tryParse(const std::string& source, const ParseOptions& par return result; } -ParseResult Fixture::matchParseError(const std::string& source, const std::string& message) +ParseResult Fixture::matchParseError(const std::string& source, const std::string& message, std::optional location) { ParseOptions options; options.allowDeclarationSyntax = true; @@ -203,6 +203,9 @@ ParseResult Fixture::matchParseError(const std::string& source, const std::strin CHECK_EQ(result.errors.front().getMessage(), message); + if (location) + CHECK_EQ(result.errors.front().getLocation(), *location); + return result; } diff --git a/tests/Fixture.h b/tests/Fixture.h index e01632eab..ab852ef6d 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -106,7 +106,7 @@ struct Fixture /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {}); - ParseResult matchParseError(const std::string& source, const std::string& message); + ParseResult matchParseError(const std::string& source, const std::string& message, std::optional location = std::nullopt); // Verify a parse error occurs and the parse error message has the specified prefix ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 5abcb09ac..e91356515 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1255,7 +1255,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_type_group") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") { ScopedFastInt sfis{"LuauRecursionLimit", 10}; - ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; matchParseErrorPrefix( "function f() if true then if true then if true then if true then if true then if true then if true then if true then if true " @@ -1266,7 +1265,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements") { ScopedFastInt sfis{"LuauRecursionLimit", 10}; - ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; matchParseErrorPrefix( "function f() if false then elseif false then elseif false then elseif false then elseif false then elseif false then elseif " @@ -1276,7 +1274,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements" TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError("function f() return if true then 1 elseif true then 2 elseif true then 3 elseif true then 4 elseif true then 5 elseif true then " @@ -1286,7 +1283,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1 TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions2") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError( @@ -1962,6 +1958,37 @@ TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments") matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); } +TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") +{ + matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + AstStat* stat = parse(R"( +type A = {} +type B = {} +type C = {} +type D = {} +type E = {} +type F = (T...) -> U... +type G = (U...) -> T... + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + matchParseError("type Y = {}", "Expected default type after type name", Location{{0, 20}, {0, 21}}); + matchParseError("type Y = {}", "Expected default type pack after type pack name", Location{{0, 29}, {0, 30}}); + matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -2455,10 +2482,19 @@ do end CHECK_EQ(1, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +TEST_CASE_FIXTURE(Fixture, "recover_expected_type_pack") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauParseRecoverTypePackEllipsis{"LuauParseRecoverTypePackEllipsis", true}; + + ParseResult result = tryParse(R"( +type Y = (T...) -> U... + )"); + CHECK_EQ(1, result.errors.size()); +} +TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +{ { AstStat* stat = parse("return if true then 1 else 2"); @@ -2524,9 +2560,4 @@ type C = Packed<(number, X...)> REQUIRE(stat != nullptr); } -TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") -{ - matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); -} - TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 80a258f5b..445ee5329 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -338,6 +338,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") { + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -353,7 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") TypeId tType = requireType("inst"); ToStringResult r = toStringDetailed(tType); - CHECK_EQ("{ @metatable {| __index: { @metatable {| __index: base |}, child } |}, inst }", r.name); + CHECK_EQ("{ @metatable { __index: { @metatable { __index: base }, child } }, inst }", r.name); CHECK_EQ(0, r.nameMap.typeVars.size()); ToStringOptions opts; @@ -500,6 +502,24 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") +{ + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( + local function f(a: number, b: string) end + local function test(...: T...): U... + f(...) + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("test"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("test(...: T...): U...", toStringNamedFunction("test", *ftv)); +} + TEST_CASE("toStringNamedFunction_unit_f") { TypePackVar empty{TypePack{}}; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 47c3883c1..ac5be859b 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -421,8 +421,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion") TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") { - ScopedFastFlag luauIfElseExpressionBaseSupport("LuauIfElseExpressionBaseSupport", true); - std::string code = "local a = if 1 then 2 else 3"; CHECK_EQ(code, transpile(code).code); @@ -641,4 +639,16 @@ TEST_CASE_FIXTURE(Fixture, "transpile_to_string") CHECK_EQ("'hello'", toString(expr)); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_alias_default_type_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + std::string code = R"( +type Packed = (T, U, V...)->(W...) +local a: Packed + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 275782b30..86165814c 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -497,7 +497,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") CHECK(arrayTable->indexer); CHECK(isInArena(array.type, mod.interfaceTypes)); - CHECK_EQ(array.typeParams[0], arrayTable->indexer->indexResultType); + CHECK_EQ(array.typeParams[0].ty, arrayTable->indexer->indexResultType); } TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 503b613f4..d76b920bb 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1031,9 +1031,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_n TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function f(v:string?) return if v then v else tostring(v) @@ -1048,9 +1045,6 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function f(v:string?) return if not v then tostring(v) else v @@ -1065,9 +1059,6 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function returnOne(x) return 1 @@ -1119,6 +1110,25 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_ CHECK_EQ("string", toString(requireTypeAtPosition({5, 30}))); } +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined2") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + type T = { x: { y: number }? } + + local function f(t: T?) + if t and t.x then + local foo = t.x.y + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireTypeAtPosition({5, 32}))); +} + TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") { ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 68dc1b4fa..94cfb6437 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -360,6 +360,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { ScopedFastFlag sffs[] = { {"LuauParseSingletonTypes", true}, + {"LuauUnsealedTableLiteral", true}, }; CheckResult result = check(R"( @@ -369,7 +370,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + CHECK_EQ(R"(Table type '{ ["\n"]: number }' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", toString(result.errors[0])); } @@ -423,4 +424,27 @@ caused by: toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + {"LuauIfElseExpectedType2", true}, + {"LuauIfElseBranchTypeUnion", true}, + }; + + CheckResult result = check(R"( +type Cat = { tag: 'cat', catfood: string } +type Dog = { tag: 'dog', dogfood: string } +type Animal = Cat | Dog + +local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag = 'dog', dogfood = 'other' } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 80f40407e..27cda1463 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -65,7 +65,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table") TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { - CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); + CheckResult result = check("function mkt() return {prop=999} end local t = mkt() t.foo = 'bar'"); LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; @@ -77,7 +77,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") CHECK_EQ(s, "{| prop: number |}"); CHECK_EQ(error->prop, "foo"); CHECK_EQ(error->context, CannotExtendTable::Property); - CHECK_EQ(err.location, (Location{Position{0, 24}, Position{0, 29}})); + CHECK_EQ(err.location, (Location{Position{0, 59}, Position{0, 64}})); } TEST_CASE_FIXTURE(Fixture, "dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table") @@ -1155,7 +1155,8 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_mu TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail") { CheckResult result = check(R"( - local t = {x = 1} + function mkt() return {x = 1} end + local t = mkt() function t.m() end )"); @@ -1165,13 +1166,38 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must_fail") { CheckResult result = check(R"( - local t = {x = 1} + function mkt() return {x = 1} end + local t = mkt() function t:m() end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + + CheckResult result = check(R"( + local t = {x = 1} + function t.m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_unsealed_table_is_ok") +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + + CheckResult result = check(R"( + local t = {x = 1} + function t:m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + // This unit test could be flaky if the fix has regressed. TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_crashing") { @@ -1439,8 +1465,13 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") CHECK_EQ("{| |}", toString(mp->subType)); } -TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") +TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_indexer") { + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; + CheckResult result = check(R"( type StringToStringMap = { [string]: string } local rt: StringToStringMap = { ["foo"] = 1 } @@ -1448,6 +1479,25 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") LUAU_REQUIRE_ERROR_COUNT(1, result); + ToStringOptions o{/* exhaustive= */ true}; + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); + // Should t now have an indexer? + // It would if the assignment to rt was correctly typed. + CHECK_EQ("{ [string]: string, foo: number }", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_sealed_tables_with_props_into_table_with_indexer") +{ + CheckResult result = check(R"( + type StringToStringMap = { [string]: string } + function mkrt() return { ["foo"] = 1 } end + local rt: StringToStringMap = mkrt() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); @@ -1467,7 +1517,10 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end @@ -1480,7 +1533,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); - CHECK_EQ("{| a: number |}", toString(tm->givenType, o)); + CHECK_EQ("{ a: number }", toString(tm->givenType, o)); } TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") @@ -1536,8 +1589,11 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") { CheckResult result = check(R"( - local vec3 = {{x = 1, y = 2, z = 3}} - local vec1 = {{x = 1}} + function mkvec3() return {x = 1, y = 2, z = 3} end + function mkvec1() return {x = 1} end + + local vec3 = {mkvec3()} + local vec1 = {mkvec1()} vec1 = vec3 )"); @@ -1620,7 +1676,8 @@ TEST_CASE_FIXTURE(Fixture, "reasonable_error_when_adding_a_nonexistent_property_ { CheckResult result = check(R"( --!strict - local A = {"value"} + function mkA() return {"value"} end + local A = mkA() A.B = "Hello" )"); @@ -1668,7 +1725,8 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") --!strict local function f() - local t = { x = 1 } + local function mkt() return { x = 1 } end + local t = mkt() function t.a() end function t.b() end @@ -1995,7 +2053,10 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); @@ -2010,7 +2071,7 @@ local c2: typeof(a2) = b2 LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' caused by: - Type '{| x: number, y: string |}' could not be converted into '{| x: number, y: number |}' + Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); @@ -2018,7 +2079,7 @@ caused by: { CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); } @@ -2026,7 +2087,7 @@ caused by: { CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); } @@ -2059,6 +2120,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, }; CheckResult result = check(R"( @@ -2077,7 +2139,7 @@ local y: number = tmp.p.y LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' caused by: - Property 'p' is not compatible. Table type '{| x: number, y: number |}' not compatible with type 'Super' because the former has extra field 'y')"); + Property 'p' is not compatible. Table type '{ x: number, y: number }' not compatible with type 'Super' because the former has extra field 'y')"); } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") @@ -2103,7 +2165,10 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") { - ScopedFastFlag luauFixRecursiveMetatableCall{"LuauFixRecursiveMetatableCall", true}; + ScopedFastFlag sff[]{ + {"LuauFixRecursiveMetatableCall", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local b @@ -2112,7 +2177,7 @@ b() )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable {| __call: t1 |}, { } })"); + CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index f70f3b1c8..7a056af52 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4525,7 +4525,9 @@ f(function(x) print(x) end) } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") -{ +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end return sum(2, 3, function(a, b) return a + b end) @@ -4549,7 +4551,7 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{| c: number, s: number |}", toString(requireType("r"))); + REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") @@ -4689,6 +4691,18 @@ a = setmetatable(a, { __call = function(x) end }) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "infer_through_group_expr") +{ + ScopedFastFlag luauGroupExpectedType{"LuauGroupExpectedType", true}; + + CheckResult result = check(R"( +local function f(a: (number, number) -> number) return a(1, 3) end +f(((function(a, b) return a + b end))) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "refine_and_or") { CheckResult result = check(R"( @@ -4743,46 +4757,75 @@ local c: X TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { - CheckResult result = check(R"(local a = if true then "true" else "false")"); - LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); - } + CheckResult result = check(R"(local a = if true then "true" else "false")"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + // Test expression containing elseif + CheckResult result = check(R"( +local a = if false then "a" elseif false then "b" else "c" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") +{ + ScopedFastFlag sff3{"LuauIfElseBranchTypeUnion", true}; { - // Test expression containing elseif - CheckResult result = check(R"( -local a = if false then "a" elseif false then "b" else "c" - )"); + CheckResult result = check(R"(local a: number? = if true then 42 else nil)"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); + CHECK_EQ(toString(requireType("a"), {true}), "number?"); } } -TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_1") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true}; - { - CheckResult result = check(R"(local a = if true then "true" else 42)"); - // We currently require both true/false expressions to unify to the same type. However, we do intend to lift - // this restriction in the future. - LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); - } + CheckResult result = check(R"( +type X = {number | string} +local a: X = if true then {"1", 2, 3} else {4, 5, 6} +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a"), {true}), "{number | string}"); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_2") +{ + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + ScopedFastFlag luauIfElseBranchTypeUnion{ "LuauIfElseBranchTypeUnion", true }; + + CheckResult result = check(R"( +local a: number? = if true then 1 else nil +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_3") +{ + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + + CheckResult result = check(R"( +local function times(n: any, f: () -> T) + local result: {T} = {} + local res = f() + table.insert(result, if true then res else n) + return result +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "type_error_addition") @@ -5039,4 +5082,51 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_oop") +{ + CheckResult result = check(R"( + --!strict +local Class = {} +Class.__index = Class + +type Class = typeof(setmetatable({} :: { x: number }, Class)) + +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end + +function Class.getx(self: Class) + return self.x +end + +function test() + local c = Class.new(42) + local n = c:getx() + local nn = c.x + + print(string.format("%d %d", n, nn)) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") +{ + ScopedFastFlag luauMetatableAreEqualRecursion{"LuauMetatableAreEqualRecursion", true}; + + CheckResult result = check(R"( +local function getIt() + local y + y = setmetatable({}, y) + return y +end +local a = getIt() +local b = getIt() +local c = a or b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 5d37b032a..d4878d149 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -621,4 +621,328 @@ type Other = Packed CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but only 1 is specified"); } +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_explicit") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U } + +local a: Y = { a = 2, b = 3 } +local b: Y = { a = 2, b = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + + result = check(R"( +type Y = { a: T } + +local a: Y = { a = 2 } +local b: Y<> = { a = "s" } +local c: Y = { a = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + CHECK_EQ(toString(requireType("c")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_self") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U } + +local a: Y = { a = 2, b = 3 } +local b: Y = { a = "h", b = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + + result = check(R"( +type Y string> = { a: T, b: U } + +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y string>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_chained") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U, c: V } + +local a: Y +local b: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_explicit") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> () } +local a: Y<> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_ty") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: (U...) -> T } + +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_tp") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> U... } +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string)>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_chained_tp") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> U..., b: (T...) -> V... } +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string), (number, string)>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_mixed_self") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T, U, V...) -> W... } +local a: Y +local b: Y +local c: Y +local d: Y ()> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + CHECK_EQ(toString(requireType("c")), "Y"); + CHECK_EQ(toString(requireType("d")), "Y ()>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T } +local a: Y = { a = 2 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); + + result = check(R"( +type Y = { a: (T...) -> () } +local a: Y<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); + + result = check(R"( +type Y = { a: (T) -> U... } +local a: Y<...number> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Y' expects at least 1 type argument, but none are specified"); + + result = check(R"( +type Packed = (T) -> T +local a: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); + + result = check(R"( +type Y = { a: T } +local a: Y + )"); + + LUAU_REQUIRE_ERRORS(result); + + result = check(R"( +type Y = { a: T } +local a: Y<...number> + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_export") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + fileResolver.source["Module/Types"] = R"( +export type A = { a: T, b: U } +export type B = { a: T, b: U } +export type C string> = { a: T, b: U } +export type D = { a: T, b: U, c: V } +export type E = { a: (T...) -> () } +export type F = { a: T, b: (U...) -> T } +export type G = { b: (U...) -> T... } +export type H = { b: (T...) -> T... } +return {} + )"; + + CheckResult resultTypes = frontend.check("Module/Types"); + LUAU_REQUIRE_NO_ERRORS(resultTypes); + + fileResolver.source["Module/Users"] = R"( +local Types = require(script.Parent.Types) + +local a: Types.A +local b: Types.B +local c: Types.C +local d: Types.D +local e: Types.E<> +local eVoid: Types.E<()> +local f: Types.F +local g: Types.G<...number> +local h: Types.H<> + )"; + + CheckResult resultUsers = frontend.check("Module/Users"); + LUAU_REQUIRE_NO_ERRORS(resultUsers); + + CHECK_EQ(toString(requireType("Module/Users", "a")), "A"); + CHECK_EQ(toString(requireType("Module/Users", "b")), "B"); + CHECK_EQ(toString(requireType("Module/Users", "c")), "C string>"); + CHECK_EQ(toString(requireType("Module/Users", "d")), "D"); + CHECK_EQ(toString(requireType("Module/Users", "e")), "E"); + CHECK_EQ(toString(requireType("Module/Users", "eVoid")), "E<>"); + CHECK_EQ(toString(requireType("Module/Users", "f")), "F"); + CHECK_EQ(toString(requireType("Module/Users", "g")), "G<...number, ()>"); + CHECK_EQ(toString(requireType("Module/Users", "h")), "H<>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_skip_brackets") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = (T...) -> number +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "(...string) -> number"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_confusing_types") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type A = (T, V...) -> (U, W...) +type B = A +type C = A + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("B"), {true}), "(string, ...any) -> (number, ...any)"); + CHECK_EQ(toString(*lookupType("C"), {true}), "(string, boolean) -> (number, boolean)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_recursive_type") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type F ()> = (K) -> V +type R = { m: F } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("R"), {true}), "t1 where t1 = {| m: (t1) -> (t1) -> () |}"); +} + +TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") +{ + ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true}; + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +local a: () -> (number, ...string) +local b: () -> (number, ...boolean) +a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> (number, ...boolean)' could not be converted into '() -> (number, ...string)' +caused by: + Type 'boolean' could not be converted into 'string')"); +} + TEST_SUITE_END();