From 08c66ef2e17f685e2a2b3195909ac28816feb19a Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 4 Nov 2021 19:07:18 -0700 Subject: [PATCH 01/14] Sync to upstream/release/502 Changes: - Support for time tracing for analysis/compiler (not currently exposed through CLI) - Support for type pack arguments in type aliases (#83) - Basic support for require(path) in luau-analyze - Add a lint warning for table.move with 0 index as part of TableOperation lint - Remove last STL dependency from Luau.VM - Minor VS2022 performance tuning Co-authored-by: Rodactor --- .gitignore | 7 + Analysis/include/Luau/BuiltinDefinitions.h | 3 +- Analysis/include/Luau/Error.h | 1 + Analysis/include/Luau/FileResolver.h | 58 +- Analysis/include/Luau/Module.h | 10 +- Analysis/include/Luau/ModuleResolver.h | 6 - Analysis/include/Luau/RequireTracer.h | 5 +- Analysis/include/Luau/Scope.h | 67 ++ Analysis/include/Luau/Substitution.h | 14 +- Analysis/include/Luau/TypeInfer.h | 56 +- Analysis/include/Luau/TypePack.h | 3 +- Analysis/include/Luau/TypeVar.h | 22 +- Analysis/include/Luau/Unifier.h | 18 +- Analysis/src/AstQuery.cpp | 17 +- Analysis/src/Autocomplete.cpp | 49 +- Analysis/src/BuiltinDefinitions.cpp | 95 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 21 - Analysis/src/Error.cpp | 108 ++- Analysis/src/Frontend.cpp | 51 +- Analysis/src/IostreamHelpers.cpp | 18 +- Analysis/src/JsonEncoder.cpp | 32 +- Analysis/src/Linter.cpp | 19 +- Analysis/src/Module.cpp | 44 +- Analysis/src/RequireTracer.cpp | 215 ++++- Analysis/src/Scope.cpp | 123 +++ Analysis/src/Substitution.cpp | 34 +- Analysis/src/ToString.cpp | 131 ++- Analysis/src/Transpiler.cpp | 38 +- Analysis/src/TypeAttach.cpp | 124 ++- Analysis/src/TypeInfer.cpp | 567 ++++++------ Analysis/src/TypePack.cpp | 13 + Analysis/src/TypeUtils.cpp | 18 +- Analysis/src/TypeVar.cpp | 36 +- Analysis/src/Unifier.cpp | 500 +++++++++-- Ast/include/Luau/Ast.h | 33 +- Ast/include/Luau/DenseHash.h | 5 +- Ast/include/Luau/Parser.h | 8 +- Ast/include/Luau/TimeTrace.h | 223 +++++ Ast/src/Ast.cpp | 37 +- Ast/src/Parser.cpp | 145 ++- Ast/src/TimeTrace.cpp | 248 ++++++ CLI/Analyze.cpp | 18 +- Compiler/src/Compiler.cpp | 10 + Sources.cmake | 5 + VM/src/ldo.cpp | 10 +- VM/src/lgc.cpp | 191 +++- VM/src/ltablib.cpp | 22 - VM/src/lvmexecute.cpp | 12 +- VM/src/lvmload.cpp | 34 +- bench/tests/deltablue.lua | 934 -------------------- tests/Autocomplete.test.cpp | 853 +++++++++--------- tests/Fixture.cpp | 49 + tests/Fixture.h | 2 + tests/Frontend.test.cpp | 31 +- tests/Linter.test.cpp | 9 +- tests/Module.test.cpp | 1 + tests/NonstrictMode.test.cpp | 1 + tests/Parser.test.cpp | 15 + tests/RequireTracer.test.cpp | 68 +- tests/ToString.test.cpp | 3 +- tests/TypeInfer.aliases.test.cpp | 557 ++++++++++++ tests/TypeInfer.provisional.test.cpp | 36 +- tests/TypeInfer.refinements.test.cpp | 72 +- tests/TypeInfer.tables.test.cpp | 230 +++-- tests/TypeInfer.test.cpp | 563 +----------- tests/TypeInfer.tryUnify.test.cpp | 1 + tests/TypeInfer.typePacks.cpp | 366 ++++++++ tests/TypeInfer.unionTypes.test.cpp | 16 +- tests/TypeVar.test.cpp | 1 + tools/tracegraph.py | 95 ++ 70 files changed, 4475 insertions(+), 2952 deletions(-) create mode 100644 .gitignore create mode 100644 Analysis/include/Luau/Scope.h create mode 100644 Analysis/src/Scope.cpp create mode 100644 Ast/include/Luau/TimeTrace.h create mode 100644 Ast/src/TimeTrace.cpp delete mode 100644 bench/tests/deltablue.lua create mode 100644 tests/TypeInfer.aliases.test.cpp create mode 100644 tools/tracegraph.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..0b2422ced --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +^build/ +^coverage/ +^fuzz/luau.pb.* +^crash-* +^default.prof* +^fuzz-* +^luau$ diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 8f17fff65..57a1907a5 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -1,7 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "TypeInfer.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" namespace Luau { diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 946bc9288..ac6f13e96 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -120,6 +120,7 @@ struct IncorrectGenericParameterCount Name name; TypeFun typeFun; size_t actualParameters; + size_t actualPackParameters; bool operator==(const IncorrectGenericParameterCount& rhs) const; }; diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 71f9464b8..a05ec5e91 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -25,51 +25,39 @@ struct SourceCode Type type; }; +struct ModuleInfo +{ + ModuleName name; + bool optional = false; +}; + struct FileResolver { virtual ~FileResolver() {} - /** Fetch the source code associated with the provided ModuleName. - * - * FIXME: This requires a string copy! - * - * @returns The actual Lua code on success. - * @returns std::nullopt if no such file exists. When this occurs, type inference will report an UnknownRequire error. - */ virtual std::optional readSource(const ModuleName& name) = 0; - /** Does the module exist? - * - * Saves a string copy over reading the source and throwing it away. - */ - virtual bool moduleExists(const ModuleName& name) const = 0; - - virtual std::optional fromAstFragment(AstExpr* expr) const = 0; - - /** Given a valid module name and a string of arbitrary data, figure out the concatenation. - */ - virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; - - /** Goes "up" a level in the hierarchy that the ModuleName represents. - * - * For instances, this is analogous to someInstance.Parent; for paths, this is equivalent to removing the last - * element of the path. Other ModuleName representations may have other ways of doing this. - * - * @returns The parent ModuleName, if one exists. - * @returns std::nullopt if there is no parent for this module name. - */ - virtual std::optional getParentModuleName(const ModuleName& name) const = 0; + virtual std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) + { + return std::nullopt; + } - virtual std::optional getHumanReadableModuleName_(const ModuleName& name) const + virtual std::string getHumanReadableModuleName(const ModuleName& name) const { return name; } - virtual std::optional getEnvironmentForModule(const ModuleName& name) const = 0; + virtual std::optional getEnvironmentForModule(const ModuleName& name) const + { + return std::nullopt; + } - /** LanguageService only: - * std::optional fromInstance(Instance* inst) - */ + // DEPRECATED APIS + // These are going to be removed with LuauNewRequireTracer + virtual bool moduleExists(const ModuleName& name) const = 0; + virtual std::optional fromAstFragment(AstExpr* expr) const = 0; + virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; + virtual std::optional getParentModuleName(const ModuleName& name) const = 0; }; struct NullFileResolver : FileResolver @@ -94,10 +82,6 @@ struct NullFileResolver : FileResolver { return std::nullopt; } - std::optional getEnvironmentForModule(const ModuleName& name) const override - { - return std::nullopt; - } }; } // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 413b68f40..d08448351 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -90,10 +90,12 @@ struct Module TypeArena internalTypes; std::vector> scopes; // never empty - std::unordered_map astTypes; - std::unordered_map astExpectedTypes; - std::unordered_map astOriginalCallTypes; - std::unordered_map astOverloadResolvedTypes; + + DenseHashMap astTypes{nullptr}; + DenseHashMap astExpectedTypes{nullptr}; + DenseHashMap astOriginalCallTypes{nullptr}; + DenseHashMap astOverloadResolvedTypes{nullptr}; + std::unordered_map declaredGlobals; ErrorVec errors; Mode mode; diff --git a/Analysis/include/Luau/ModuleResolver.h b/Analysis/include/Luau/ModuleResolver.h index a394a21b4..d892ccd7f 100644 --- a/Analysis/include/Luau/ModuleResolver.h +++ b/Analysis/include/Luau/ModuleResolver.h @@ -15,12 +15,6 @@ struct Module; using ModulePtr = std::shared_ptr; -struct ModuleInfo -{ - ModuleName name; - bool optional = false; -}; - struct ModuleResolver { virtual ~ModuleResolver() {} diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index e9778876c..c25545f57 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -17,12 +17,11 @@ struct AstLocal; struct RequireTraceResult { - DenseHashMap exprs{0}; - DenseHashMap optional{0}; + DenseHashMap exprs{nullptr}; std::vector> requires; }; -RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName); +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName); } // namespace Luau diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h new file mode 100644 index 000000000..45338409f --- /dev/null +++ b/Analysis/include/Luau/Scope.h @@ -0,0 +1,67 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" +#include "Luau/TypeVar.h" + +#include +#include +#include + +namespace Luau +{ + +struct Scope; + +using ScopePtr = std::shared_ptr; + +struct Binding +{ + TypeId typeId; + Location location; + bool deprecated = false; + std::string deprecatedSuggestion; + std::optional documentationSymbol; +}; + +struct Scope +{ + explicit Scope(TypePackId returnType); // root scope + explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. + + const ScopePtr parent; // null for the root + std::unordered_map bindings; + TypePackId returnType; + bool breakOk = false; + std::optional varargPack; + + TypeLevel level; + + std::unordered_map exportedTypeBindings; + std::unordered_map privateTypeBindings; + std::unordered_map typeAliasLocations; + + std::unordered_map> importedTypeBindings; + + std::optional lookup(const Symbol& name); + + std::optional lookupType(const Name& name); + std::optional lookupImportedType(const Name& moduleAlias, const Name& name); + + std::unordered_map privateTypePackBindings; + std::optional lookupPack(const Name& name); + + // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) + std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); + + RefinementMap refinements; + + // For mutually recursive type aliases, it's important that + // they use the same types for the same names. + // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` + // we need that the generic type `T` in both cases is the same, so we use a cache. + std::unordered_map typeAliasTypeParameters; + std::unordered_map typeAliasTypePackParameters; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 6ac868f76..80a14e8fb 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -52,8 +52,6 @@ // `T`, and the type of `f` are in the same SCC, which is why `f` gets // replaced. -LUAU_FASTFLAG(DebugLuauTrackOwningArena) - namespace Luau { @@ -188,20 +186,12 @@ struct Substitution : FindDirty template TypeId addType(const T& tv) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(tv); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(tv); } template TypePackId addTypePack(const T& tp) { - TypePackId allocated = currentModule->internalTypes.typePacks.allocate(tp); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addTypePack(TypePackVar{tp}); } }; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index ec2a1a26f..d701eb248 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -86,7 +86,10 @@ struct ApplyTypeFunction : Substitution { TypeLevel level; bool encounteredForwardedType; - std::unordered_map arguments; + std::unordered_map typeArguments; + std::unordered_map typePackArguments; + bool ignoreChildren(TypeId ty) override; + bool ignoreChildren(TypePackId tp) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; TypeId clean(TypeId ty) override; @@ -328,7 +331,8 @@ struct TypeChecker TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); - TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location); + TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, + const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. std::pair, std::vector> createGenericTypes( @@ -398,54 +402,6 @@ struct TypeChecker int recursionCount = 0; }; -struct Binding -{ - TypeId typeId; - Location location; - bool deprecated = false; - std::string deprecatedSuggestion; - std::optional documentationSymbol; -}; - -struct Scope -{ - explicit Scope(TypePackId returnType); // root scope - explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. - - const ScopePtr parent; // null for the root - std::unordered_map bindings; - TypePackId returnType; - bool breakOk = false; - std::optional varargPack; - - TypeLevel level; - - std::unordered_map exportedTypeBindings; - std::unordered_map privateTypeBindings; - std::unordered_map typeAliasLocations; - - std::unordered_map> importedTypeBindings; - - std::optional lookup(const Symbol& name); - - std::optional lookupType(const Name& name); - std::optional lookupImportedType(const Name& moduleAlias, const Name& name); - - std::unordered_map privateTypePackBindings; - std::optional lookupPack(const Name& name); - - // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) - std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); - - RefinementMap refinements; - - // For mutually recursive type aliases, it's important that - // they use the same types for the same names. - // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` - // we need that the generic type `T` in both cases is the same, so we use a cache. - std::unordered_map typeAliasParameters; -}; - // Unit test hook void setPrintLine(void (*pl)(const std::string& s)); void resetPrintLine(); diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 0d0adce7e..d987d46ca 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -117,7 +117,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); -size_t size(const TypePackId tp); +size_t size(TypePackId tp); +bool finite(TypePackId tp); size_t size(const TypePack& tp); std::optional first(TypePackId tp); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 90a28b20d..d4e4e4913 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -228,6 +228,7 @@ struct TableTypeVar std::map methodDefinitionLocations; std::vector instantiatedTypeParams; + std::vector instantiatedTypePackParams; ModuleName definitionModuleName; std::optional boundTo; @@ -284,8 +285,9 @@ struct ClassTypeVar struct TypeFun { - /// These should all be generic + // These should all be generic std::vector typeParams; + std::vector typePackParams; /** The underlying type. * @@ -293,6 +295,20 @@ struct TypeFun * You must first use TypeChecker::instantiateTypeFun to turn it into a real type. */ TypeId type; + + TypeFun() = default; + TypeFun(std::vector typeParams, TypeId type) + : typeParams(std::move(typeParams)) + , type(type) + { + } + + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + : typeParams(std::move(typeParams)) + , typePackParams(std::move(typePackParams)) + , type(type) + { + } }; // Anything! All static checking is off. @@ -524,8 +540,4 @@ UnionTypeVarIterator end(const UnionTypeVar* utv); using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); -// TEMP: Clip this prototype with FFlag::LuauStringMetatable -std::optional> magicFunctionFormat( - struct TypeChecker& typechecker, const std::shared_ptr& scope, const AstExprCall& expr, ExprResult exprResult); - } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 0ddc3cc0b..522914b2f 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -36,12 +36,17 @@ struct Unifier Variance variance = Covariant; CountMismatch::Context ctx = CountMismatch::Arg; - std::shared_ptr counters; + UnifierCounters* counters; + UnifierCounters countersData; + + std::shared_ptr counters_DEPRECATED; + InternalErrorReporter* iceHandler; Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters = nullptr); + Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED = nullptr, + UnifierCounters* counters = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId superTy, TypeId subTy); @@ -58,11 +63,13 @@ struct Unifier void tryUnifyPrimitives(TypeId superTy, TypeId subTy); void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); + void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void tryUnifyFreeTable(TypeId free, TypeId other); void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection); void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); public: void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); @@ -80,9 +87,9 @@ struct Unifier public: // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); - void occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack); + void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack); void occursCheck(TypePackId needle, TypePackId haystack); - void occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack); + void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack); Unifier makeChildUnifier(); @@ -93,6 +100,9 @@ struct Unifier [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + + DenseHashSet tempSeenTy{nullptr}; + DenseHashSet tempSeenTp{nullptr}; }; } // namespace Luau diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index d3de1754a..0aed34c0a 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -2,6 +2,7 @@ #include "Luau/AstQuery.h" #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/ToString.h" @@ -143,8 +144,8 @@ std::optional findTypeAtPosition(const Module& module, const SourceModul { if (auto expr = findExprAtPosition(sourceModule, pos)) { - if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) - return it->second; + if (auto it = module.astTypes.find(expr)) + return *it; } return std::nullopt; @@ -154,8 +155,8 @@ std::optional findExpectedTypeAtPosition(const Module& module, const Sou { if (auto expr = findExprAtPosition(sourceModule, pos)) { - if (auto it = module.astExpectedTypes.find(expr); it != module.astExpectedTypes.end()) - return it->second; + if (auto it = module.astExpectedTypes.find(expr)) + return *it; } return std::nullopt; @@ -322,9 +323,9 @@ std::optional getDocumentationSymbolAtPosition(const Source TypeId matchingOverload = nullptr; if (parentExpr && parentExpr->is()) { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr); it != module.astOverloadResolvedTypes.end()) + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) { - matchingOverload = it->second; + matchingOverload = *it; } } @@ -345,9 +346,9 @@ std::optional getDocumentationSymbolAtPosition(const Source { if (AstExprIndexName* indexName = targetExpr->as()) { - if (auto it = module.astTypes.find(indexName->expr); it != module.astTypes.end()) + if (auto it = module.astTypes.find(indexName->expr)) { - TypeId parentTy = follow(it->second); + TypeId parentTy = follow(*it); if (const TableTypeVar* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index dce92a0c9..235abf36f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -210,10 +210,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ return TypeCorrectKind::None; auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return TypeCorrectKind::None; - TypeId expectedType = follow(it->second); + TypeId expectedType = follow(*it); if (canUnify(expectedType, ty)) return TypeCorrectKind::Correct; @@ -682,10 +682,10 @@ static std::optional functionIsExpectedAt(const Module& module, AstNode* n return std::nullopt; auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return std::nullopt; - TypeId expectedType = follow(it->second); + TypeId expectedType = follow(*it); if (const FunctionTypeVar* ftv = get(expectedType)) return true; @@ -784,9 +784,9 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (AstExprCall* exprCall = expr->as()) { - if (auto it = module.astTypes.find(exprCall->func); it != module.astTypes.end()) + if (auto it = module.astTypes.find(exprCall->func)) { - if (const FunctionTypeVar* ftv = get(follow(it->second))) + if (const FunctionTypeVar* ftv = get(follow(*it))) { if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos)) inferredType = *ty; @@ -798,8 +798,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (tailPos != 0) break; - if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) - inferredType = it->second; + if (auto it = module.astTypes.find(expr)) + inferredType = *it; } if (inferredType) @@ -815,10 +815,10 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionTypeVar* { auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return nullptr; - TypeId ty = follow(it->second); + TypeId ty = follow(*it); if (const FunctionTypeVar* ftv = get(ty)) return ftv; @@ -1129,9 +1129,8 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul if (node->is()) { - auto it = module.astTypes.find(node->asExpr()); - if (it != module.astTypes.end()) - autocompleteProps(module, typeArena, it->second, PropIndexType::Point, ancestry, result); + if (auto it = module.astTypes.find(node->asExpr())) + autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result); } else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) return; @@ -1203,13 +1202,13 @@ static std::optional getMethodContainingClass(const ModuleP return std::nullopt; } - auto parentIter = module->astTypes.find(parentExpr); - if (parentIter == module->astTypes.end()) + auto parentIt = module->astTypes.find(parentExpr); + if (!parentIt) { return std::nullopt; } - Luau::TypeId parentType = Luau::follow(parentIter->second); + Luau::TypeId parentType = Luau::follow(*parentIt); if (auto parentClass = Luau::get(parentType)) { @@ -1250,8 +1249,8 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - auto iter = module->astTypes.find(candidate->func); - if (iter == module->astTypes.end()) + auto it = module->astTypes.find(candidate->func); + if (!it) { return std::nullopt; } @@ -1267,7 +1266,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; }; - auto followedId = Luau::follow(iter->second); + auto followedId = Luau::follow(*it); if (auto functionType = Luau::get(followedId)) { return performCallback(functionType); @@ -1316,10 +1315,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto indexName = node->as()) { auto it = module->astTypes.find(indexName->expr); - if (it == module->astTypes.end()) + if (!it) return {}; - TypeId ty = follow(it->second); + TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; if (isString(ty)) @@ -1447,9 +1446,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // If item doesn't have a key, maybe the value is actually the key if (key ? key == node : node->is() && value == node) { - if (auto it = module->astExpectedTypes.find(exprTable); it != module->astExpectedTypes.end()) + if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, typeArena, it->second, PropIndexType::Key, finder.ancestry); + auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry); // Remove keys that are already completed for (const auto& item : exprTable->items) @@ -1485,9 +1484,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) { - if (auto it = module->astTypes.find(idxExpr->expr); it != module->astTypes.end()) + if (auto it = module->astTypes.find(idxExpr->expr)) { - return {autocompleteProps(*module, typeArena, follow(it->second), PropIndexType::Point, finder.ancestry), finder.ancestry}; + return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry}; } } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 68ad5ac9f..3b0c21638 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -11,7 +11,7 @@ LUAU_FASTFLAG(LuauParseGenericFunctions) LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauStringMetatable) +LUAU_FASTFLAG(LuauNewRequireTrace) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -218,7 +218,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypePackId anyTypePack = typeChecker.anyTypePack; TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); - TypePackId stringVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{stringType}}); TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{ @@ -255,85 +254,18 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); - if (FFlag::LuauStringMetatable) - { - std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); - LUAU_ASSERT(stringMetatableTy); - const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); - LUAU_ASSERT(stringMetatableTable); + std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); + LUAU_ASSERT(stringMetatableTy); + const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); + LUAU_ASSERT(stringMetatableTable); - auto it = stringMetatableTable->props.find("__index"); - LUAU_ASSERT(it != stringMetatableTable->props.end()); + auto it = stringMetatableTable->props.find("__index"); + LUAU_ASSERT(it != stringMetatableTable->props.end()); - TypeId stringLib = it->second.type; - addGlobalBinding(typeChecker, "string", stringLib, "@luau"); - } + addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) - { - if (!FFlag::LuauStringMetatable) - { - TypeId stringLibTy = getGlobalBinding(typeChecker, "string"); - TableTypeVar* stringLib = getMutable(stringLibTy); - TypeId replArgType = makeUnion( - arena, {stringType, - arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), - makeFunction(arena, std::nullopt, {stringType}, {stringType})}); - TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); - - stringLib->props["gsub"] = makeProperty(gsubFunc, "@luau/global/string.gsub"); - } - } - else + if (!FFlag::LuauParseGenericFunctions || !FFlag::LuauGenericFunctions) { - if (!FFlag::LuauStringMetatable) - { - TypeId stringToStringType = makeFunction(arena, std::nullopt, {stringType}, {stringType}); - - TypeId gmatchFunc = makeFunction(arena, stringType, {stringType}, {arena.addType(FunctionTypeVar{emptyPack, stringVariadicList})}); - - TypeId replArgType = makeUnion( - arena, {stringType, - arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), - makeFunction(arena, std::nullopt, {stringType}, {stringType})}); - TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); - - TypeId formatFn = arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}); - - TableTypeVar::Props stringLib = { - // FIXME string.byte "can" return a pack of numbers, but only if 2nd or 3rd arguments were supplied - {"byte", {makeFunction(arena, stringType, {optionalNumber, optionalNumber}, {optionalNumber})}}, - // FIXME char takes a variadic pack of numbers - {"char", {makeFunction(arena, std::nullopt, {numberType, optionalNumber, optionalNumber, optionalNumber}, {stringType})}}, - {"find", {makeFunction(arena, stringType, {stringType, optionalNumber, optionalBoolean}, {optionalNumber, optionalNumber})}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(arena, stringType, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {makeFunction(arena, stringType, {stringType, optionalNumber}, {optionalString})}}, - {"rep", {makeFunction(arena, stringType, {numberType}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(arena, stringType, {numberType, optionalNumber}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(arena, stringType, {stringType, optionalString}, - {arena.addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, typeChecker.globalScope->level})})}}, - {"pack", {arena.addType(FunctionTypeVar{ - arena.addTypePack(TypePack{{stringType}, anyTypePack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(arena, stringType, {}, {numberType})}}, - {"unpack", {arena.addType(FunctionTypeVar{ - arena.addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - anyTypePack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - addGlobalBinding(typeChecker, "string", - arena.addType(TableTypeVar{stringLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - } - TableTypeVar::Props debugLib{ {"info", {makeIntersection(arena, { @@ -601,9 +533,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) auto tableLib = getMutable(getGlobalBinding(typeChecker, "table")); attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack); - auto stringLib = getMutable(getGlobalBinding(typeChecker, "string")); - attachMagicFunction(stringLib->props["format"].type, magicFunctionFormat); - attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); } @@ -791,11 +720,11 @@ static std::optional> magicFunctionRequire( return std::nullopt; } - AstExpr* require = expr.args.data[0]; - - if (!checkRequirePath(typechecker, require)) + if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; + const AstExpr* require = FFlag::LuauNewRequireTrace ? &expr : expr.args.data[0]; + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 61a63f067..1e91561a6 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -206,27 +206,6 @@ std::string getBuiltinDefinitionSource() graphemes: (string, number?, number?) -> (() -> (number, number)), } - declare string: { - byte: (string, number?, number?) -> ...number, - char: (number, ...number) -> string, - find: (string, string, number?, boolean?) -> (number?, number?), - -- `string.format` has a magic function attached that will provide more type information for literal format strings. - format: (string, A...) -> string, - gmatch: (string, string) -> () -> (...string), - -- gsub is defined in C++ because we don't have syntax for describing a generic table. - len: (string) -> number, - lower: (string) -> string, - match: (string, string, number?) -> string?, - rep: (string, number) -> string, - reverse: (string) -> string, - sub: (string, number, number?) -> string, - upper: (string) -> string, - split: (string, string, string?) -> {string}, - pack: (string, A...) -> string, - packsize: (string) -> number, - unpack: (string, string, number?) -> R..., - } - -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. declare function unpack(tab: {V}, i: number?, j: number?): ...V )"; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 680bcf3f0..92fbffc80 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,9 +7,9 @@ #include -LUAU_FASTFLAG(LuauFasterStringifier) +LUAU_FASTFLAG(LuauTypeAliasPacks) -static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) +static std::string wrongNumberOfArgsString_DEPRECATED(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) { std::string s = "expects " + std::to_string(expectedCount) + " "; @@ -41,6 +41,52 @@ static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCo return s; } +static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) +{ + std::string s; + + if (FFlag::LuauTypeAliasPacks) + { + s = "expects "; + + if (isVariadic) + s += "at least "; + + s += std::to_string(expectedCount) + " "; + } + else + { + s = "expects " + std::to_string(expectedCount) + " "; + } + + if (argPrefix) + s += std::string(argPrefix) + " "; + + s += "argument"; + if (expectedCount != 1) + s += "s"; + + s += ", but "; + + if (actualCount == 0) + { + s += "none"; + } + else + { + if (actualCount < expectedCount) + s += "only "; + + s += std::to_string(actualCount); + } + + s += (actualCount == 1) ? " is" : " are"; + + s += " specified"; + + return s; +} + namespace Luau { @@ -128,7 +174,10 @@ struct ErrorConverter else return "Function only returns " + std::to_string(e.expected) + " values. " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + if (FFlag::LuauTypeAliasPacks) + return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + else + return "Argument count mismatch. Function " + wrongNumberOfArgsString_DEPRECATED(e.expected, e.actual); } LUAU_ASSERT(!"Unknown context"); @@ -160,13 +209,16 @@ struct ErrorConverter std::string operator()(const Luau::UnknownRequire& e) const { - return "Unknown require: " + e.modulePath; + if (e.modulePath.empty()) + return "Unknown require: unsupported path"; + else + return "Unknown require: " + e.modulePath; } std::string operator()(const Luau::IncorrectGenericParameterCount& e) const { std::string name = e.name; - if (!e.typeFun.typeParams.empty()) + if (!e.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !e.typeFun.typePackParams.empty())) { name += "<"; bool first = true; @@ -179,10 +231,37 @@ struct ErrorConverter name += toString(t); } + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId t : e.typeFun.typePackParams) + { + if (first) + first = false; + else + name += ", "; + + name += toString(t); + } + } + name += ">"; } - return "Generic type '" + name + "' " + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); + if (FFlag::LuauTypeAliasPacks) + { + if (e.typeFun.typeParams.size() != e.actualParameters) + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); + + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); + } + else + { + return "Generic type '" + name + "' " + + wrongNumberOfArgsString_DEPRECATED(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); + } } std::string operator()(const Luau::SyntaxError& e) const @@ -471,9 +550,26 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size()) return false; + if (FFlag::LuauTypeAliasPacks) + { + if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) + return false; + } + for (size_t i = 0; i < typeFun.typeParams.size(); ++i) + { if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) return false; + } + + if (FFlag::LuauTypeAliasPacks) + { + for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) + { + if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + return false; + } + } return true; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 4d385ec19..b2529840b 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,9 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/Common.h" #include "Luau/Config.h" #include "Luau/FileResolver.h" +#include "Luau/Scope.h" #include "Luau/StringUtils.h" +#include "Luau/TimeTrace.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" #include "Luau/Common.h" @@ -19,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false) LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) +LUAU_FASTFLAG(LuauNewRequireTrace) namespace Luau { @@ -69,6 +73,8 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) { + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); + Luau::Allocator allocator; Luau::AstNameTable names(allocator); @@ -350,6 +356,9 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) CheckResult Frontend::check(const ModuleName& name) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + CheckResult checkResult; auto it = sourceNodes.find(name); @@ -479,6 +488,9 @@ CheckResult Frontend::check(const ModuleName& name) bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root) { + LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); + // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search enum Mark { @@ -597,6 +609,9 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + CheckResult checkResult; auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name); @@ -608,6 +623,8 @@ LintResult Frontend::lint(const ModuleName& name, std::optional Frontend::lintFragment(std::string_view source, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lintFragment", "Frontend"); + const Config& config = configResolver->getConfig(""); SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions); @@ -627,6 +644,9 @@ std::pair Frontend::lintFragment(std::string_view sour CheckResult Frontend::check(const SourceModule& module) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + const Config& config = configResolver->getConfig(module.name); Mode mode = module.mode.value_or(config.mode); @@ -648,6 +668,9 @@ CheckResult Frontend::check(const SourceModule& module) LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + const Config& config = configResolver->getConfig(module.name); LintOptions options = enabledLintWarnings.value_or(config.enabledLint); @@ -746,6 +769,9 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) { + LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + auto it = sourceNodes.find(name); if (it != sourceNodes.end() && !it->second.dirty) { @@ -815,6 +841,9 @@ std::pair Frontend::getSourceNode(CheckResult& check */ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions) { + LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + SourceModule sourceModule; double timestamp = getTimestamp(); @@ -864,20 +893,11 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module const auto& exprs = it->second.exprs; - const ModuleName* relativeName = exprs.find(&pathExpr); - if (!relativeName || relativeName->empty()) + const ModuleInfo* info = exprs.find(&pathExpr); + if (!info || (!FFlag::LuauNewRequireTrace && info->name.empty())) return std::nullopt; - if (FFlag::LuauTraceRequireLookupChild) - { - const bool* optional = it->second.optional.find(&pathExpr); - - return {{*relativeName, optional ? *optional : false}}; - } - else - { - return {{*relativeName, false}}; - } + return *info; } const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const @@ -891,12 +911,15 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - return frontend->fileResolver->moduleExists(moduleName); + if (FFlag::LuauNewRequireTrace) + return frontend->sourceNodes.count(moduleName) != 0; + else + return frontend->fileResolver->moduleExists(moduleName); } std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const { - return frontend->fileResolver->getHumanReadableModuleName_(moduleName).value_or(moduleName); + return frontend->fileResolver->getHumanReadableModuleName(moduleName); } ScopePtr Frontend::addEnvironment(const std::string& environmentName) diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 84e9b77f7..3b2671213 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -2,6 +2,8 @@ #include "Luau/IostreamHelpers.h" #include "Luau/ToString.h" +LUAU_FASTFLAG(LuauTypeAliasPacks) + namespace Luau { @@ -92,7 +94,7 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "IncorrectGenericParameterCount { name = " << error.name; - if (!error.typeFun.typeParams.empty()) + if (!error.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !error.typeFun.typePackParams.empty())) { stream << "<"; bool first = true; @@ -105,6 +107,20 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo stream << toString(t); } + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId t : error.typeFun.typePackParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(t); + } + } + stream << ">"; } diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index a1018297e..064accba5 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -3,6 +3,9 @@ #include "Luau/Ast.h" #include "Luau/StringUtils.h" +#include "Luau/Common.h" + +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -612,6 +615,12 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstStatTypeAlias", [&]() { PROP(name); PROP(generics); + + if (FFlag::LuauTypeAliasPacks) + { + PROP(genericPacks); + } + PROP(type); PROP(exported); }); @@ -664,13 +673,21 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(struct AstTypeOrPack node) + { + if (node.type) + write(node.type); + else + write(node.typePack); + } + void write(class AstTypeReference* node) { writeNode(node, "AstTypeReference", [&]() { if (node->hasPrefix) PROP(prefix); PROP(name); - PROP(generics); + PROP(parameters); }); } @@ -734,6 +751,13 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(class AstTypePackExplicit* node) + { + writeNode(node, "AstTypePackExplicit", [&]() { + PROP(typeList); + }); + } + void write(class AstTypePackVariadic* node) { writeNode(node, "AstTypePackVariadic", [&]() { @@ -1018,6 +1042,12 @@ struct AstJsonEncoder : public AstVisitor return false; } + bool visit(class AstTypePackExplicit* node) override + { + write(node); + return false; + } + bool visit(class AstTypePackVariadic* node) override { write(node); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index f97f6a4ad..bff947a56 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -3,6 +3,7 @@ #include "Luau/AstQuery.h" #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/StringUtils.h" #include "Luau/Common.h" @@ -12,6 +13,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false) +LUAU_FASTFLAGVARIABLE(LuauLinterTableMoveZero, false) namespace Luau { @@ -85,10 +87,10 @@ struct LintContext return std::nullopt; auto it = module->astTypes.find(expr); - if (it == module->astTypes.end()) + if (!it) return std::nullopt; - return it->second; + return *it; } }; @@ -2144,6 +2146,19 @@ class LintTableOperations : AstVisitor "wrap it in parentheses to silence"); } + if (FFlag::LuauLinterTableMoveZero && func->index == "move" && node->args.size >= 4) + { + // table.move(t, 0, _, _) + if (isConstant(args[1], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + + // table.move(t, _, _, 0) + else if (isConstant(args[3], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + } + return true; } diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index f1d975fec..df6be767b 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -13,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -188,7 +190,7 @@ struct TypePackCloner template void defaultClone(const T& t) { - TypePackId cloned = dest.typePacks.allocate(t); + TypePackId cloned = dest.addTypePack(TypePackVar{t}); seenTypePacks[typePackId] = cloned; } @@ -197,7 +199,7 @@ struct TypePackCloner if (encounteredFreeType) *encounteredFreeType = true; - seenTypePacks[typePackId] = dest.typePacks.allocate(TypePackVar{Unifiable::Error{}}); + seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}}); } void operator()(const Unifiable::Generic& t) @@ -219,13 +221,13 @@ struct TypePackCloner void operator()(const VariadicTypePack& t) { - TypePackId cloned = dest.typePacks.allocate(VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}); + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}}); seenTypePacks[typePackId] = cloned; } void operator()(const TypePack& t) { - TypePackId cloned = dest.typePacks.allocate(TypePack{}); + TypePackId cloned = dest.addTypePack(TypePack{}); TypePack* destTp = getMutable(cloned); LUAU_ASSERT(destTp != nullptr); seenTypePacks[typePackId] = cloned; @@ -241,7 +243,7 @@ struct TypePackCloner template void TypeCloner::defaultClone(const T& t) { - TypeId cloned = dest.typeVars.allocate(t); + TypeId cloned = dest.addType(t); seenTypes[typeId] = cloned; } @@ -250,7 +252,7 @@ void TypeCloner::operator()(const Unifiable::Free& t) if (encounteredFreeType) *encounteredFreeType = true; - seenTypes[typeId] = dest.typeVars.allocate(ErrorTypeVar{}); + seenTypes[typeId] = dest.addType(ErrorTypeVar{}); } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -275,7 +277,7 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) void TypeCloner::operator()(const FunctionTypeVar& t) { - TypeId result = dest.typeVars.allocate(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); + TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); FunctionTypeVar* ftv = getMutable(result); LUAU_ASSERT(ftv != nullptr); @@ -297,7 +299,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) void TypeCloner::operator()(const TableTypeVar& t) { - TypeId result = dest.typeVars.allocate(TableTypeVar{}); + TypeId result = dest.addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(result); LUAU_ASSERT(ttv != nullptr); @@ -323,7 +325,13 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); for (TypeId& arg : ttv->instantiatedTypeParams) - arg = (clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType)); + arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + } if (ttv->state == TableState::Free) { @@ -343,7 +351,7 @@ void TypeCloner::operator()(const TableTypeVar& t) void TypeCloner::operator()(const MetatableTypeVar& t) { - TypeId result = dest.typeVars.allocate(MetatableTypeVar{}); + TypeId result = dest.addType(MetatableTypeVar{}); MetatableTypeVar* mtv = getMutable(result); seenTypes[typeId] = result; @@ -353,7 +361,7 @@ void TypeCloner::operator()(const MetatableTypeVar& t) void TypeCloner::operator()(const ClassTypeVar& t) { - TypeId result = dest.typeVars.allocate(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); + TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); ClassTypeVar* ctv = getMutable(result); seenTypes[typeId] = result; @@ -378,7 +386,7 @@ void TypeCloner::operator()(const AnyTypeVar& t) void TypeCloner::operator()(const UnionTypeVar& t) { - TypeId result = dest.typeVars.allocate(UnionTypeVar{}); + TypeId result = dest.addType(UnionTypeVar{}); seenTypes[typeId] = result; UnionTypeVar* option = getMutable(result); @@ -390,7 +398,7 @@ void TypeCloner::operator()(const UnionTypeVar& t) void TypeCloner::operator()(const IntersectionTypeVar& t) { - TypeId result = dest.typeVars.allocate(IntersectionTypeVar{}); + TypeId result = dest.addType(IntersectionTypeVar{}); seenTypes[typeId] = result; IntersectionTypeVar* option = getMutable(result); @@ -451,8 +459,14 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) { TypeFun result; - for (TypeId param : typeFun.typeParams) - result.typeParams.push_back(clone(param, dest, seenTypes, seenTypePacks, encounteredFreeType)); + for (TypeId ty : typeFun.typeParams) + result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId tp : typeFun.typePackParams) + result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, encounteredFreeType)); + } result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType); diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 5b3997e2c..ad4d5ef43 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -5,6 +5,7 @@ #include "Luau/Module.h" LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) +LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace, false) namespace Luau { @@ -12,17 +13,18 @@ namespace Luau namespace { -struct RequireTracer : AstVisitor +struct RequireTracerOld : AstVisitor { - explicit RequireTracer(FileResolver* fileResolver, ModuleName currentModuleName) + explicit RequireTracerOld(FileResolver* fileResolver, const ModuleName& currentModuleName) : fileResolver(fileResolver) - , currentModuleName(std::move(currentModuleName)) + , currentModuleName(currentModuleName) { + LUAU_ASSERT(!FFlag::LuauNewRequireTrace); } FileResolver* const fileResolver; ModuleName currentModuleName; - DenseHashMap locals{0}; + DenseHashMap locals{nullptr}; RequireTraceResult result; std::optional fromAstFragment(AstExpr* expr) @@ -50,9 +52,9 @@ struct RequireTracer : AstVisitor AstExpr* expr = stat->values.data[i]; expr->visit(this); - const ModuleName* name = result.exprs.find(expr); - if (name) - locals[local] = *name; + const ModuleInfo* info = result.exprs.find(expr); + if (info) + locals[local] = info->name; } } @@ -63,7 +65,7 @@ struct RequireTracer : AstVisitor { std::optional name = fromAstFragment(global); if (name) - result.exprs[global] = *name; + result.exprs[global] = {*name}; return false; } @@ -72,7 +74,7 @@ struct RequireTracer : AstVisitor { const ModuleName* name = locals.find(local->local); if (name) - result.exprs[local] = *name; + result.exprs[local] = {*name}; return false; } @@ -81,16 +83,16 @@ struct RequireTracer : AstVisitor { indexName->expr->visit(this); - const ModuleName* name = result.exprs.find(indexName->expr); - if (name) + const ModuleInfo* info = result.exprs.find(indexName->expr); + if (info) { if (indexName->index == "parent" || indexName->index == "Parent") { - if (auto parent = fileResolver->getParentModuleName(*name)) - result.exprs[indexName] = *parent; + if (auto parent = fileResolver->getParentModuleName(info->name)) + result.exprs[indexName] = {*parent}; } else - result.exprs[indexName] = fileResolver->concat(*name, indexName->index.value); + result.exprs[indexName] = {fileResolver->concat(info->name, indexName->index.value)}; } return false; @@ -100,11 +102,11 @@ struct RequireTracer : AstVisitor { indexExpr->expr->visit(this); - const ModuleName* name = result.exprs.find(indexExpr->expr); + const ModuleInfo* info = result.exprs.find(indexExpr->expr); const AstExprConstantString* str = indexExpr->index->as(); - if (name && str) + if (info && str) { - result.exprs[indexExpr] = fileResolver->concat(*name, std::string_view(str->value.data, str->value.size)); + result.exprs[indexExpr] = {fileResolver->concat(info->name, std::string_view(str->value.data, str->value.size))}; } indexExpr->index->visit(this); @@ -129,8 +131,8 @@ struct RequireTracer : AstVisitor AstExprGlobal* globalName = call->func->as(); if (globalName && globalName->name == "require" && call->args.size >= 1) { - if (const ModuleName* moduleName = result.exprs.find(call->args.data[0])) - result.requires.push_back({*moduleName, call->location}); + if (const ModuleInfo* moduleInfo = result.exprs.find(call->args.data[0])) + result.requires.push_back({moduleInfo->name, call->location}); return false; } @@ -143,8 +145,8 @@ struct RequireTracer : AstVisitor if (FFlag::LuauTraceRequireLookupChild && !rootName) { - if (const ModuleName* moduleName = result.exprs.find(indexName->expr)) - rootName = *moduleName; + if (const ModuleInfo* moduleInfo = result.exprs.find(indexName->expr)) + rootName = moduleInfo->name; } if (!rootName) @@ -167,24 +169,183 @@ struct RequireTracer : AstVisitor if (v.end() != std::find(v.begin(), v.end(), '/')) return false; - result.exprs[call] = fileResolver->concat(*rootName, v); + result.exprs[call] = {fileResolver->concat(*rootName, v)}; // 'WaitForChild' can be used on modules that are not awailable at the typecheck time, but will be awailable at runtime // If we fail to find such module, we will not report an UnknownRequire error if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") - result.optional[call] = true; + result.exprs[call].optional = true; return false; } }; +struct RequireTracer : AstVisitor +{ + RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName) + : result(result) + , fileResolver(fileResolver) + , currentModuleName(currentModuleName) + , locals(nullptr) + { + LUAU_ASSERT(FFlag::LuauNewRequireTrace); + } + + bool visit(AstExprTypeAssertion* expr) override + { + // suppress `require() :: any` + return false; + } + + bool visit(AstExprCall* expr) override + { + AstExprGlobal* global = expr->func->as(); + + if (global && global->name == "require" && expr->args.size >= 1) + requires.push_back(expr); + + return true; + } + + bool visit(AstStatLocal* stat) override + { + for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i) + { + AstLocal* local = stat->vars.data[i]; + AstExpr* expr = stat->values.data[i]; + + // track initializing expression to be able to trace modules through locals + locals[local] = expr; + } + + return true; + } + + bool visit(AstStatAssign* stat) override + { + for (size_t i = 0; i < stat->vars.size; ++i) + { + // locals that are assigned don't have a known expression + if (AstExprLocal* expr = stat->vars.data[i]->as()) + locals[expr->local] = nullptr; + } + + return true; + } + + bool visit(AstType* node) override + { + // allow resolving require inside `typeof` annotations + return true; + } + + AstExpr* getDependent(AstExpr* node) + { + if (AstExprLocal* expr = node->as()) + return locals[expr->local]; + else if (AstExprIndexName* expr = node->as()) + return expr->expr; + else if (AstExprIndexExpr* expr = node->as()) + return expr->expr; + else if (AstExprCall* expr = node->as(); expr && expr->self) + return expr->func->as()->expr; + else + return nullptr; + } + + void process() + { + ModuleInfo moduleContext{currentModuleName}; + + // seed worklist with require arguments + work.reserve(requires.size()); + + for (AstExprCall* require: requires) + work.push_back(require->args.data[0]); + + // push all dependent expressions to the work stack; note that the vector is modified during traversal + for (size_t i = 0; i < work.size(); ++i) + if (AstExpr* dep = getDependent(work[i])) + work.push_back(dep); + + // resolve all expressions to a module info + for (size_t i = work.size(); i > 0; --i) + { + AstExpr* expr = work[i - 1]; + + // when multiple expressions depend on the same one we push it to work queue multiple times + if (result.exprs.contains(expr)) + continue; + + std::optional info; + + if (AstExpr* dep = getDependent(expr)) + { + const ModuleInfo* context = result.exprs.find(dep); + + // locals just inherit their dependent context, no resolution required + if (expr->is()) + info = context ? std::optional(*context) : std::nullopt; + else + info = fileResolver->resolveModule(context, expr); + } + else + { + info = fileResolver->resolveModule(&moduleContext, expr); + } + + if (info) + result.exprs[expr] = std::move(*info); + } + + // resolve all requires according to their argument + result.requires.reserve(requires.size()); + + for (AstExprCall* require : requires) + { + AstExpr* arg = require->args.data[0]; + + if (const ModuleInfo* info = result.exprs.find(arg)) + { + result.requires.push_back({info->name, require->location}); + + ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info! + result.exprs[require] = std::move(infoCopy); + } + else + { + result.exprs[require] = {}; // mark require as unresolved + } + } + } + + RequireTraceResult& result; + FileResolver* fileResolver; + ModuleName currentModuleName; + + DenseHashMap locals; + std::vector work; + std::vector requires; +}; + } // anonymous namespace -RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName) +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - RequireTracer tracer{fileResolver, std::move(currentModuleName)}; - root->visit(&tracer); - return tracer.result; + if (FFlag::LuauNewRequireTrace) + { + RequireTraceResult result; + RequireTracer tracer{result, fileResolver, currentModuleName}; + root->visit(&tracer); + tracer.process(); + return result; + } + else + { + RequireTracerOld tracer{fileResolver, currentModuleName}; + root->visit(&tracer); + return tracer.result; + } } } // namespace Luau diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp new file mode 100644 index 000000000..c30db9c25 --- /dev/null +++ b/Analysis/src/Scope.cpp @@ -0,0 +1,123 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" + +namespace Luau +{ + +Scope::Scope(TypePackId returnType) + : parent(nullptr) + , returnType(returnType) + , level(TypeLevel()) +{ +} + +Scope::Scope(const ScopePtr& parent, int subLevel) + : parent(parent) + , returnType(parent->returnType) + , level(parent->level.incr()) +{ + level.subLevel = subLevel; +} + +std::optional Scope::lookup(const Symbol& name) +{ + Scope* scope = this; + + while (scope) + { + auto it = scope->bindings.find(name); + if (it != scope->bindings.end()) + return it->second.typeId; + + scope = scope->parent.get(); + } + + return std::nullopt; +} + +std::optional Scope::lookupType(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->exportedTypeBindings.find(name); + if (it != scope->exportedTypeBindings.end()) + return it->second; + + it = scope->privateTypeBindings.find(name); + if (it != scope->privateTypeBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) +{ + const Scope* scope = this; + while (scope) + { + auto it = scope->importedTypeBindings.find(moduleAlias); + if (it == scope->importedTypeBindings.end()) + { + scope = scope->parent.get(); + continue; + } + + auto it2 = it->second.find(name); + if (it2 == it->second.end()) + { + scope = scope->parent.get(); + continue; + } + + return it2->second; + } + + return std::nullopt; +} + +std::optional Scope::lookupPack(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->privateTypePackBindings.find(name); + if (it != scope->privateTypePackBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) +{ + Scope* scope = this; + + while (scope) + { + for (const auto& [n, binding] : scope->bindings) + { + if (n.local && n.local->name == name.c_str()) + return binding; + else if (n.global.value && n.global == name.c_str()) + return binding; + } + + scope = scope->parent.get(); + + if (!traverseScopeChain) + break; + } + + return std::nullopt; +} + +} // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 7223998a3..d861eb3da 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -6,9 +6,11 @@ #include #include -LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 0) +LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) +LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauRankNTypes) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -35,8 +37,15 @@ void Tarjan::visitChildren(TypeId ty, int index) visitChild(ttv->indexer->indexType); visitChild(ttv->indexer->indexResultType); } + for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp); + } } else if (const MetatableTypeVar* mtv = get(ty)) { @@ -332,9 +341,11 @@ std::optional Substitution::substitute(TypeId ty) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - replaceChildren(newTy); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - replaceChildren(newTp); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + replaceChildren(newTp); TypeId newTy = replace(ty); return newTy; } @@ -350,9 +361,11 @@ std::optional Substitution::substitute(TypePackId tp) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - replaceChildren(newTy); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - replaceChildren(newTp); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + replaceChildren(newTp); TypePackId newTp = replace(tp); return newTp; } @@ -382,6 +395,10 @@ TypeId Substitution::clone(TypeId ty) clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; clone.instantiatedTypeParams = ttv->instantiatedTypeParams; + + if (FFlag::LuauTypeAliasPacks) + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; + if (FFlag::LuauSecondTypecheckKnowsTheDataModel) clone.tags = ttv->tags; result = addType(std::move(clone)); @@ -487,8 +504,15 @@ void Substitution::replaceChildren(TypeId ty) ttv->indexer->indexType = replace(ttv->indexer->indexType); ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType); } + for (TypeId& itp : ttv->instantiatedTypeParams) itp = replace(itp); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId& itp : ttv->instantiatedTypePackParams) + itp = replace(itp); + } } else if (MetatableTypeVar* mtv = getMutable(ty)) { diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 9d2f47ba6..5651af7e9 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ToString.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -9,10 +10,10 @@ #include #include -LUAU_FASTFLAG(LuauToStringFollowsBoundTo) LUAU_FASTFLAG(LuauExtraNilRecovery) LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -59,6 +60,13 @@ struct FindCyclicTypes { for (TypeId itp : ttv.instantiatedTypeParams) visitTypeVar(itp, *this, seen); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv.instantiatedTypePackParams) + visitTypeVar(itp, *this, seen); + } + return exhaustive; } @@ -258,23 +266,60 @@ struct TypeVarStringifier void stringify(TypePackId tp); void stringify(TypePackId tpid, const std::vector>& names); - void stringify(const std::vector& types) + void stringify(const std::vector& types, const std::vector& typePacks) { - if (types.size() == 0) + if (types.size() == 0 && (!FFlag::LuauTypeAliasPacks || typePacks.size() == 0)) return; - if (types.size()) + if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) state.emit("<"); - for (size_t i = 0; i < types.size(); ++i) + if (FFlag::LuauTypeAliasPacks) { - if (i > 0) - state.emit(", "); + bool first = true; + + for (TypeId ty : types) + { + if (!first) + state.emit(", "); + first = false; + + stringify(ty); + } + + bool singleTp = typePacks.size() == 1; + + for (TypePackId tp : typePacks) + { + if (isEmpty(tp) && singleTp) + continue; + + if (!first) + state.emit(", "); + else + first = false; + + if (!singleTp) + state.emit("("); + + stringify(tp); - stringify(types[i]); + if (!singleTp) + state.emit(")"); + } } + else + { + for (size_t i = 0; i < types.size(); ++i) + { + if (i > 0) + state.emit(", "); - if (types.size()) + stringify(types[i]); + } + } + + if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) state.emit(">"); } @@ -388,7 +433,7 @@ struct TypeVarStringifier void operator()(TypeId, const TableTypeVar& ttv) { - if (FFlag::LuauToStringFollowsBoundTo && ttv.boundTo) + if (ttv.boundTo) return stringify(*ttv.boundTo); if (!state.exhaustive) @@ -411,14 +456,14 @@ struct TypeVarStringifier } state.emit(*ttv.name); - stringify(ttv.instantiatedTypeParams); + stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams); return; } if (ttv.syntheticName) { state.result.invalid = true; state.emit(*ttv.syntheticName); - stringify(ttv.instantiatedTypeParams); + stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams); return; } } @@ -900,13 +945,26 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (ttv->instantiatedTypeParams.empty()) + if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) return result; std::vector params; for (TypeId tp : ttv->instantiatedTypeParams) params.push_back(toString(tp)); + if (FFlag::LuauTypeAliasPacks) + { + // Doesn't preserve grouping of multiple type packs + // But this is under a parent block of code that is being removed later + for (TypePackId tp : ttv->instantiatedTypePackParams) + { + std::string content = toString(tp); + + if (!content.empty()) + params.push_back(std::move(content)); + } + } + result.name += "<" + join(params, ", ") + ">"; return result; } @@ -950,30 +1008,37 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (ttv->instantiatedTypeParams.empty()) - return result; - - result.name += "<"; - - bool first = true; - for (TypeId ty : ttv->instantiatedTypeParams) + if (FFlag::LuauTypeAliasPacks) { - if (!first) - result.name += ", "; - else - first = false; - - tvs.stringify(ty); - } - - if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) - { - result.truncated = true; - result.name += "... "; + tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); } else { - result.name += ">"; + if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) + return result; + + result.name += "<"; + + bool first = true; + for (TypeId ty : ttv->instantiatedTypeParams) + { + if (!first) + result.name += ", "; + else + first = false; + + tvs.stringify(ty); + } + + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + { + result.truncated = true; + result.name += "... "; + } + else + { + result.name += ">"; + } } return result; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 462c70ffc..1b83ccdc2 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace { @@ -280,10 +281,19 @@ struct Printer void visualizeTypePackAnnotation(const AstTypePack& annotation) { - if (const AstTypePackVariadic* variadic = annotation.as()) + if (const AstTypePackVariadic* variadicTp = annotation.as()) { writer.symbol("..."); - visualizeTypeAnnotation(*variadic->variadicType); + visualizeTypeAnnotation(*variadicTp->variadicType); + } + else if (const AstTypePackGeneric* genericTp = annotation.as()) + { + writer.symbol(genericTp->genericName.value); + writer.symbol("..."); + } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + visualizeTypeList(explicitTp->typeList, true); } else { @@ -807,7 +817,7 @@ struct Printer writer.keyword("type"); writer.identifier(a->name.value); - if (a->generics.size > 0) + if (a->generics.size > 0 || (FFlag::LuauTypeAliasPacks && a->genericPacks.size > 0)) { writer.symbol("<"); CommaSeparatorInserter comma(writer); @@ -817,6 +827,17 @@ struct Printer comma(); writer.identifier(o.value); } + + if (FFlag::LuauTypeAliasPacks) + { + for (auto o : a->genericPacks) + { + comma(); + writer.identifier(o.value); + writer.symbol("..."); + } + } + writer.symbol(">"); } writer.maybeSpace(a->type->location.begin, 2); @@ -960,15 +981,20 @@ struct Printer if (const auto& a = typeAnnotation.as()) { writer.write(a->name.value); - if (a->generics.size > 0) + if (a->parameters.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); - for (auto o : a->generics) + for (auto o : a->parameters) { comma(); - visualizeTypeAnnotation(*o); + + if (o.type) + visualizeTypeAnnotation(*o.type); + else + visualizeTypePackAnnotation(*o.typePack); } + writer.symbol(">"); } } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 17c57c848..266c19865 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -5,6 +5,7 @@ #include "Luau/Module.h" #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -12,6 +13,7 @@ #include LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAG(LuauTypeAliasPacks) static char* allocateString(Luau::Allocator& allocator, std::string_view contents) { @@ -33,7 +35,6 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data namespace Luau { - class TypeRehydrationVisitor { mutable std::map seen; @@ -57,6 +58,8 @@ class TypeRehydrationVisitor { } + AstTypePack* rehydrate(TypePackId tp) const; + AstType* operator()(const PrimitiveTypeVar& ptv) const { switch (ptv.type) @@ -85,16 +88,24 @@ class TypeRehydrationVisitor if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end()) { - AstArray generics; - generics.size = ttv.instantiatedTypeParams.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstType*) * generics.size)); + AstArray parameters; + parameters.size = ttv.instantiatedTypeParams.size(); + parameters.data = static_cast(allocator->allocate(sizeof(AstTypeOrPack) * parameters.size)); for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i) { - generics.data[i] = Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty); + parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}}; + } + + if (FFlag::LuauTypeAliasPacks) + { + for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) + { + parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; + } } - return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), generics); + return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters); } if (hasSeen(&ttv)) @@ -222,10 +233,17 @@ class TypeRehydrationVisitor AstTypePack* argTailAnnotation = nullptr; if (argTail) { - TypePackId tail = *argTail; - if (const VariadicTypePack* vtp = get(tail)) + if (FFlag::LuauTypeAliasPacks) { - argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + argTailAnnotation = rehydrate(*argTail); + } + else + { + TypePackId tail = *argTail; + if (const VariadicTypePack* vtp = get(tail)) + { + argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + } } } @@ -255,10 +273,17 @@ class TypeRehydrationVisitor AstTypePack* retTailAnnotation = nullptr; if (retTail) { - TypePackId tail = *retTail; - if (const VariadicTypePack* vtp = get(tail)) + if (FFlag::LuauTypeAliasPacks) + { + retTailAnnotation = rehydrate(*retTail); + } + else { - retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + TypePackId tail = *retTail; + if (const VariadicTypePack* vtp = get(tail)) + { + retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + } } } @@ -313,6 +338,68 @@ class TypeRehydrationVisitor const TypeRehydrationOptions& options; }; +class TypePackRehydrationVisitor +{ +public: + TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor) + : allocator(allocator) + , typeVisitor(typeVisitor) + { + } + + AstTypePack* operator()(const BoundTypePack& btp) const + { + return Luau::visit(*this, btp.boundTo->ty); + } + + AstTypePack* operator()(const TypePack& tp) const + { + AstArray head; + head.size = tp.head.size(); + head.data = static_cast(allocator->allocate(sizeof(AstType*) * tp.head.size())); + + for (size_t i = 0; i < tp.head.size(); i++) + head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty); + + AstTypePack* tail = nullptr; + + if (tp.tail) + tail = Luau::visit(*this, (*tp.tail)->ty); + + return allocator->alloc(Location(), AstTypeList{head, tail}); + } + + AstTypePack* operator()(const VariadicTypePack& vtp) const + { + return allocator->alloc(Location(), Luau::visit(typeVisitor, vtp.ty->ty)); + } + + AstTypePack* operator()(const GenericTypePack& gtp) const + { + return allocator->alloc(Location(), AstName(gtp.name.c_str())); + } + + AstTypePack* operator()(const FreeTypePack& gtp) const + { + return allocator->alloc(Location(), AstName("free")); + } + + AstTypePack* operator()(const Unifiable::Error&) const + { + return allocator->alloc(Location(), AstName("Unifiable")); + } + +private: + Allocator* allocator; + const TypeRehydrationVisitor& typeVisitor; +}; + +AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const +{ + TypePackRehydrationVisitor tprv(allocator, *this); + return Luau::visit(tprv, tp->ty); +} + class TypeAttacher : public AstVisitor { public: @@ -406,9 +493,16 @@ class TypeAttacher : public AstVisitor if (tail) { - TypePackId tailPack = *tail; - if (const VariadicTypePack* vtp = get(tailPack)) - variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); + if (FFlag::LuauTypeAliasPacks) + { + variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail); + } + else + { + TypePackId tailPack = *tail; + if (const VariadicTypePack* vtp = get(tailPack)) + variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); + } } fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation}; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 2216881b7..3a1fdfff5 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5,21 +5,22 @@ #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" #include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" +#include "Luau/TimeTrace.h" #include #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 0) -LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 0) +LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) -LUAU_FASTFLAGVARIABLE(LuauIndexTablesWithIndexers, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauGenericVariadicsUnification, false) LUAU_FASTFLAG(LuauKnowsTheDataModel3) @@ -27,14 +28,11 @@ LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(LuauImprovedTypeGuardPredicate2, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) -LUAU_FASTFLAG(DebugLuauTrackOwningArena) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauFixTableTypeAliasClone, false) LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false) LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false) LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) @@ -45,6 +43,10 @@ LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) +LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) +LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) +LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -216,9 +218,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , nilType(singletonTypes.nilType) , numberType(singletonTypes.numberType) , stringType(singletonTypes.stringType) - , booleanType( - FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.booleanType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Boolean))) - , threadType(FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.threadType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Thread))) + , booleanType(singletonTypes.booleanType) + , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) , errorType(singletonTypes.errorType) , optionalNumberType(globalTypes.addType(UnionTypeVar{{numberType, nilType}})) @@ -237,6 +238,9 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) { + LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + currentModule.reset(new Module()); currentModule->type = module.type; @@ -1177,44 +1181,61 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { Location location = scope->typeAliasLocations[name]; reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); - bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + + if (FFlag::LuauTypeAliasPacks) + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType}; + else + bindingsMap[name] = TypeFun{binding->typeParams, errorType}; } else { ScopePtr aliasScope = childScope(scope, typealias.location); - std::vector generics; - for (AstName generic : typealias.generics) + if (FFlag::LuauTypeAliasPacks) { - Name n = generic.value; + auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks); - // These generics are the only thing that will ever be added to aliasScope, so we can be certain that - // a collision can only occur when two generic typevars have the same name. - if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) + TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; + } + else + { + std::vector generics; + for (AstName generic : typealias.generics) { - // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. - reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); - } + Name n = generic.value; - TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypeId& cached = scope->typeAliasParameters[n]; - if (!cached) - cached = addType(GenericTypeVar{aliasScope->level, n}); - g = cached; + // These generics are the only thing that will ever be added to aliasScope, so we can be certain that + // a collision can only occur when two generic typevars have the same name. + if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) + { + // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. + reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); + } + + TypeId g; + if (FFlag::LuauRecursiveTypeParameterRestriction) + { + TypeId& cached = scope->typeAliasTypeParameters[n]; + if (!cached) + cached = addType(GenericTypeVar{aliasScope->level, n}); + g = cached; + } + else + g = addType(GenericTypeVar{aliasScope->level, n}); + generics.push_back(g); + aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; } - else - g = addType(GenericTypeVar{aliasScope->level, n}); - generics.push_back(g); - aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; - } - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); - FreeTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), ty}; + TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), ty}; + } } } else @@ -1231,6 +1252,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; } + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId tp : binding->typePackParams) + { + auto generic = get(tp); + LUAU_ASSERT(generic); + aliasScope->privateTypePackBindings[generic->name] = tp; + } + } + TypeId ty = (FFlag::LuauRankNTypes ? resolveType(aliasScope, *typealias.type) : resolveType(aliasScope, *typealias.type, true)); if (auto ttv = getMutable(follow(ty))) { @@ -1238,7 +1269,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (ttv->name) { // Copy can be skipped if this is an identical alias - if (!FFlag::LuauFixTableTypeAliasClone || ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams) + if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || + (FFlag::LuauTypeAliasPacks && ttv->instantiatedTypePackParams != binding->typePackParams)) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1249,6 +1281,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias clone.name = name; clone.instantiatedTypeParams = binding->typeParams; + if (FFlag::LuauTypeAliasPacks) + clone.instantiatedTypePackParams = binding->typePackParams; + ty = addType(std::move(clone)); } } @@ -1256,6 +1291,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { ttv->name = name; ttv->instantiatedTypeParams = binding->typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = binding->typePackParams; } } else if (auto mtv = getMutable(follow(ty))) @@ -1280,7 +1318,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0); + LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0)); superTy = lookupType->type; if (FFlag::LuauAddMissingFollow) @@ -1465,7 +1503,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& if (FFlag::LuauStoreMatchingOverloadFnType) { - currentModule->astTypes.try_emplace(&expr, result.type); + if (!currentModule->astTypes.find(&expr)) + currentModule->astTypes[&expr] = result.type; } else { @@ -2193,7 +2232,7 @@ TypeId TypeChecker::checkRelationalOperation( * have a better, more descriptive error teed up. */ Unifier state = mkUnifier(expr.location); - if (!FFlag::LuauEqConstraint || !isEquality) + if (!isEquality) state.tryUnify(lhsType, rhsType); bool needsMetamethod = !isEquality; @@ -2262,7 +2301,7 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && (!FFlag::LuauEqConstraint || !isEquality)) + if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && !isEquality) { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); @@ -2276,18 +2315,6 @@ TypeId TypeChecker::checkRelationalOperation( return errorType; } - if (!FFlag::LuauEqConstraint) - { - if (isEquality) - { - ErrorVec errVec = tryUnify(rhsType, lhsType, expr.location); - if (!state.errors.empty() && !errVec.empty()) - reportError(expr.location, TypeMismatch{lhsType, rhsType}); - } - else - reportErrors(state.errors); - } - return booleanType; } @@ -2443,7 +2470,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates); return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } - else if (FFlag::LuauEqConstraint && (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe)) + else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; @@ -2466,14 +2493,6 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi } else { - // Once we have EqPredicate, we should break this else branch into its' own branch. - // For now, fall through is intentional. - if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } - ExprResult lhs = checkExpr(scope, *expr.left); ExprResult rhs = checkExpr(scope, *expr.right); @@ -2755,12 +2774,6 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return std::pair(resultType, nullptr); } - else if (FFlag::LuauIndexTablesWithIndexers) - { - // We allow t[x] where x:string for tables without an indexer - unify(indexType, stringType, expr.location); - return std::pair(anyType, nullptr); - } else { TypeId resultType = freshType(scope); @@ -3076,6 +3089,13 @@ static Location getEndLocation(const AstExprFunction& function) void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstExprFunction& function) { + LUAU_TIMETRACE_SCOPE("TypeChecker::checkFunctionBody", "TypeChecker"); + + if (function.debugname.value) + LUAU_TIMETRACE_ARGUMENT("name", function.debugname.value); + else + LUAU_TIMETRACE_ARGUMENT("line", std::to_string(function.location.begin.line).c_str()); + if (FunctionTypeVar* funTy = getMutable(ty)) { check(scope, *function.body); @@ -3885,6 +3905,20 @@ std::optional TypeChecker::matchRequire(const AstExprCall& call) TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location) { + LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); + LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); + + if (FFlag::LuauNewRequireTrace && moduleInfo.name.empty()) + { + if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) + { + reportError(TypeError{location, UnknownRequire{}}); + return errorType; + } + + return anyType; + } + ModulePtr module = resolver->getModule(moduleInfo.name); if (!module) { @@ -4472,7 +4506,7 @@ TypeId TypeChecker::freshType(const ScopePtr& scope) TypeId TypeChecker::freshType(TypeLevel level) { - return currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level))); + return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric) @@ -4482,11 +4516,7 @@ TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneri TypeId TypeChecker::DEPRECATED_freshType(TypeLevel level, bool canBeGeneric) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level, canBeGeneric))); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level, canBeGeneric))); } std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) @@ -4506,20 +4536,12 @@ TypeId TypeChecker::addType(const UnionTypeVar& utv) TypeId TypeChecker::addTV(TypeVar&& tv) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(std::move(tv)); } TypePackId TypeChecker::addTypePack(TypePackVar&& tv) { - TypePackId allocated = currentModule->internalTypes.typePacks.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addTypePack(std::move(tv)); } TypePackId TypeChecker::addTypePack(TypePack&& tp) @@ -4578,7 +4600,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_print") { - if (lit->generics.size != 1) + if (lit->parameters.size != 1 || !lit->parameters.data[0].type) { reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); return addType(ErrorTypeVar{}); @@ -4588,7 +4610,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation opts.exhaustive = true; opts.maxTableLength = 0; - TypeId param = resolveType(scope, *lit->generics.data[0]); + TypeId param = resolveType(scope, *lit->parameters.data[0].type); luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); return param; } @@ -4614,18 +4636,86 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return addType(ErrorTypeVar{}); } - if (lit->generics.size == 0 && tf->typeParams.empty()) + if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) + { return tf->type; - else if (lit->generics.size != tf->typeParams.size()) + } + else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) { - reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->generics.size}}); + reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); return addType(ErrorTypeVar{}); } + else if (FFlag::LuauTypeAliasPacks) + { + if (!lit->hasParameterList && !tf->typePackParams.empty()) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + return addType(ErrorTypeVar{}); + } + + std::vector typeParams; + std::vector extraTypes; + std::vector typePackParams; + + for (size_t i = 0; i < lit->parameters.size; ++i) + { + if (AstType* type = lit->parameters.data[i].type) + { + TypeId ty = resolveType(scope, *type); + + if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) + typeParams.push_back(ty); + else if (typePackParams.empty()) + extraTypes.push_back(ty); + else + reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); + } + else if (AstTypePack* typePack = lit->parameters.data[i].typePack) + { + TypePackId tp = resolveTypePack(scope, *typePack); + + // If we have collected an implicit type pack, materialize it + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we need more regular types, we can use single element type packs to fill those in + if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) + typeParams.push_back(*first(tp)); + else + typePackParams.push_back(tp); + } + } + + // If we still haven't meterialized an implicit type pack, do it now + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack + if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) + typePackParams.push_back(addTypePack({})); + + if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) + { + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + return addType(ErrorTypeVar{}); + } + + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + { + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + return tf->type; + } + + return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); + } else { std::vector typeParams; - for (AstType* paramAnnot : lit->generics) - typeParams.push_back(resolveType(scope, *paramAnnot)); + + for (const auto& param : lit->parameters) + typeParams.push_back(resolveType(scope, *param.type)); if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) { @@ -4634,7 +4724,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return tf->type; } - return instantiateTypeFun(scope, *tf, typeParams, annotation.location); + return instantiateTypeFun(scope, *tf, typeParams, {}, annotation.location); } } else if (const auto& table = annotation.as()) @@ -4765,6 +4855,18 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack return *genericTy; } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + std::vector types; + + for (auto type : explicitTp->typeList.types) + types.push_back(resolveType(scope, *type)); + + if (auto tailType = explicitTp->typeList.tailType) + return addTypePack(types, resolveTypePack(scope, *tailType)); + + return addTypePack(types); + } else { ice("Unknown AstTypePack kind"); @@ -4799,12 +4901,28 @@ bool ApplyTypeFunction::isDirty(TypePackId tp) return false; } +bool ApplyTypeFunction::ignoreChildren(TypeId ty) +{ + if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(ty)) + return true; + else + return false; +} + +bool ApplyTypeFunction::ignoreChildren(TypePackId tp) +{ + if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(tp)) + return true; + else + return false; +} + TypeId ApplyTypeFunction::clean(TypeId ty) { // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - TypeId& arg = arguments[ty]; + TypeId& arg = typeArguments[ty]; if (arg) return arg; else @@ -4816,17 +4934,37 @@ TypePackId ApplyTypeFunction::clean(TypePackId tp) // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - return addTypePack(FreeTypePack{level}); + if (FFlag::LuauTypeAliasPacks) + { + TypePackId& arg = typePackArguments[tp]; + if (arg) + return arg; + else + return addTypePack(FreeTypePack{level}); + } + else + { + return addTypePack(FreeTypePack{level}); + } } -TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location) +TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, + const std::vector& typePackParams, const Location& location) { - if (tf.typeParams.empty()) + if (tf.typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf.typePackParams.empty())) return tf.type; - applyTypeFunction.arguments.clear(); + applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) - applyTypeFunction.arguments[tf.typeParams[i]] = typeParams[i]; + applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; + + if (FFlag::LuauTypeAliasPacks) + { + applyTypeFunction.typePackArguments.clear(); + for (size_t i = 0; i < tf.typePackParams.size(); ++i) + applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; + } + applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; applyTypeFunction.encounteredForwardedType = false; @@ -4875,6 +5013,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (ttv) { ttv->instantiatedTypeParams = typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = typePackParams; } } else @@ -4890,6 +5031,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } ttv->instantiatedTypeParams = typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = typePackParams; } } @@ -4899,6 +5043,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, std::pair, std::vector> TypeChecker::createGenericTypes( const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { + LUAU_ASSERT(scope->parent); + std::vector generics; for (const AstName& generic : genericNames) { @@ -4912,7 +5058,19 @@ std::pair, std::vector> TypeChecker::createGener reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypeId g = addType(Unifiable::Generic{scope->level, n}); + TypeId g; + if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + { + TypeId& cached = scope->parent->typeAliasTypeParameters[n]; + if (!cached) + cached = addType(GenericTypeVar{scope->level, n}); + g = cached; + } + else + { + g = addType(Unifiable::Generic{scope->level, n}); + } + generics.push_back(g); scope->privateTypeBindings[n] = TypeFun{{}, g}; } @@ -4930,7 +5088,19 @@ std::pair, std::vector> TypeChecker::createGener reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypePackId g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + TypePackId g; + if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + { + TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; + if (!cached) + cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + g = cached; + } + else + { + g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + } + genericPacks.push_back(g); scope->privateTypePackBindings[n] = g; } @@ -5013,13 +5183,8 @@ void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, Refineme else if (auto isaP = get(predicate)) resolve(*isaP, errVec, refis, scope, sense); else if (auto typeguardP = get(predicate)) - { - if (FFlag::LuauImprovedTypeGuardPredicate2) - resolve(*typeguardP, errVec, refis, scope, sense); - else - DEPRECATED_resolve(*typeguardP, errVec, refis, scope, sense); - } - else if (auto eqP = get(predicate); eqP && FFlag::LuauEqConstraint) + resolve(*typeguardP, errVec, refis, scope, sense); + else if (auto eqP = get(predicate)) resolve(*eqP, errVec, refis, scope, sense); else ice("Unhandled predicate kind"); @@ -5145,7 +5310,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } } - else if (FFlag::LuauImprovedTypeGuardPredicate2) + else { auto lctv = get(option); auto rctv = get(isaP.ty); @@ -5159,19 +5324,6 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement if (canUnify(option, isaP.ty, isaP.location).empty() == sense) return isaP.ty; } - else - { - auto lctv = get(option); - auto rctv = get(isaP.ty); - - if (lctv && rctv) - { - if (isSubclass(lctv, rctv) == sense) - return option; - else if (isSubclass(rctv, lctv) == sense) - return isaP.ty; - } - } return std::nullopt; }; @@ -5266,7 +5418,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); auto typeFun = globalScope->lookupType(typeguardP.kind); - if (!typeFun || !typeFun->typeParams.empty()) + if (!typeFun || !typeFun->typeParams.empty() || (FFlag::LuauTypeAliasPacks && !typeFun->typePackParams.empty())) return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); TypeId type = follow(typeFun->type); @@ -5292,7 +5444,8 @@ void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, Error "userdata", // no op. Requires special handling. }; - if (auto typeFun = globalScope->lookupType(typeguardP.kind); typeFun && typeFun->typeParams.empty()) + if (auto typeFun = globalScope->lookupType(typeguardP.kind); + typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty())) { if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) addRefinement(refis, typeguardP.lvalue, typeFun->type); @@ -5319,38 +5472,41 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa return; } - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; + if (FFlag::LuauEqConstraint) + { + std::optional ty = resolveLValue(refis, scope, eqP.lvalue); + if (!ty) + return; - std::vector lhs = options(*ty); - std::vector rhs = options(eqP.type); + std::vector lhs = options(*ty); + std::vector rhs = options(eqP.type); - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) - { - addRefinement(refis, eqP.lvalue, eqP.type); - return; - } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) + { + addRefinement(refis, eqP.lvalue, eqP.type); + return; + } + else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - std::unordered_set set; - for (TypeId left : lhs) - { - for (TypeId right : rhs) + std::unordered_set set; + for (TypeId left : lhs) { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); + for (TypeId right : rhs) + { + // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. + if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) + set.insert(left); + } } - } - if (set.empty()) - return; + if (set.empty()) + return; - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); + std::vector viable(set.begin(), set.end()); + TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); + addRefinement(refis, eqP.lvalue, result); + } } bool TypeChecker::isNonstrictMode() const @@ -5379,119 +5535,4 @@ std::vector> TypeChecker::getScopes() const return currentModule->scopes; } -Scope::Scope(TypePackId returnType) - : parent(nullptr) - , returnType(returnType) - , level(TypeLevel()) -{ -} - -Scope::Scope(const ScopePtr& parent, int subLevel) - : parent(parent) - , returnType(parent->returnType) - , level(parent->level.incr()) -{ - level.subLevel = subLevel; -} - -std::optional Scope::lookup(const Symbol& name) -{ - Scope* scope = this; - - while (scope) - { - auto it = scope->bindings.find(name); - if (it != scope->bindings.end()) - return it->second.typeId; - - scope = scope->parent.get(); - } - - return std::nullopt; -} - -std::optional Scope::lookupType(const Name& name) -{ - const Scope* scope = this; - while (true) - { - auto it = scope->exportedTypeBindings.find(name); - if (it != scope->exportedTypeBindings.end()) - return it->second; - - it = scope->privateTypeBindings.find(name); - if (it != scope->privateTypeBindings.end()) - return it->second; - - if (scope->parent) - scope = scope->parent.get(); - else - return std::nullopt; - } -} - -std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) -{ - const Scope* scope = this; - while (scope) - { - auto it = scope->importedTypeBindings.find(moduleAlias); - if (it == scope->importedTypeBindings.end()) - { - scope = scope->parent.get(); - continue; - } - - auto it2 = it->second.find(name); - if (it2 == it->second.end()) - { - scope = scope->parent.get(); - continue; - } - - return it2->second; - } - - return std::nullopt; -} - -std::optional Scope::lookupPack(const Name& name) -{ - const Scope* scope = this; - while (true) - { - auto it = scope->privateTypePackBindings.find(name); - if (it != scope->privateTypePackBindings.end()) - return it->second; - - if (scope->parent) - scope = scope->parent.get(); - else - return std::nullopt; - } -} - -std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) -{ - Scope* scope = this; - - while (scope) - { - for (const auto& [n, binding] : scope->bindings) - { - if (n.local && n.local->name == name.c_str()) - return binding; - else if (n.global.value && n.global == name.c_str()) - return binding; - } - - scope = scope->parent.get(); - - if (!traverseScopeChain) - break; - } - - return std::nullopt; -} - } // namespace Luau diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 5970f304d..68a16ef04 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -209,6 +209,19 @@ size_t size(TypePackId tp) return 0; } +bool finite(TypePackId tp) +{ + tp = follow(tp); + + if (auto pack = get(tp)) + return pack->tail ? finite(*pack->tail) : true; + + if (auto pack = get(tp)) + return false; + + return true; +} + size_t size(const TypePack& tp) { size_t result = tp.head.size(); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index b9f509788..0d9d91e0c 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -1,11 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeUtils.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" -LUAU_FASTFLAG(LuauStringMetatable) - namespace Luau { @@ -13,21 +12,6 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa { type = follow(type); - if (!FFlag::LuauStringMetatable) - { - if (const PrimitiveTypeVar* primType = get(type)) - { - if (primType->type != PrimitiveTypeVar::String || "__index" != entry) - return std::nullopt; - - auto it = globalScope->bindings.find(AstName{"string"}); - if (it != globalScope->bindings.end()) - return it->second.typeId; - else - return std::nullopt; - } - } - std::optional metatable = getMetatable(type); if (!metatable) return std::nullopt; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 111f4f53d..e963fc74e 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,11 +19,9 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) -LUAU_FASTFLAGVARIABLE(LuauToStringFollowsBoundTo, false) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAGVARIABLE(LuauStringMetatable, false) LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -193,27 +191,11 @@ bool isOptional(TypeId ty) bool isTableIntersection(TypeId ty) { - if (FFlag::LuauImprovedTypeGuardPredicate2) - { - if (!get(follow(ty))) - return false; - - std::vector parts = flattenIntersection(ty); - return std::all_of(parts.begin(), parts.end(), getTableType); - } - else - { - if (const IntersectionTypeVar* itv = get(ty)) - { - for (TypeId part : itv->parts) - { - if (getTableType(follow(part))) - return true; - } - } - + if (!get(follow(ty))) return false; - } + + std::vector parts = flattenIntersection(ty); + return std::all_of(parts.begin(), parts.end(), getTableType); } bool isOverloadedFunction(TypeId ty) @@ -236,7 +218,7 @@ std::optional getMetatable(TypeId type) else if (const ClassTypeVar* classType = get(type)) return classType->metatable; else if (const PrimitiveTypeVar* primitiveType = get(type); - FFlag::LuauStringMetatable && primitiveType && primitiveType->metatable) + primitiveType && primitiveType->metatable) { LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); return primitiveType->metatable; @@ -871,6 +853,12 @@ void StateDot::visitChildren(TypeId ty, int index) } for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp, index, "typeParam"); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } } else if (const MetatableTypeVar* mtv = get(ty)) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 89c3f80ca..117cbc289 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -3,23 +3,25 @@ #include "Luau/Common.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" +#include "Luau/TimeTrace.h" #include LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 0); -LUAU_FASTFLAGVARIABLE(LuauLogTableTypeVarBoundTo, false) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); LUAU_FASTFLAGVARIABLE(LuauDontMutatePersistentFunctions, false) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauStringMetatable) LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) +LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) namespace Luau { @@ -43,21 +45,23 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , globalScope(std::move(globalScope)) , location(location) , variance(variance) - , counters(std::make_shared()) + , counters(&countersData) + , counters_DEPRECATED(std::make_shared()) , iceHandler(iceHandler) { LUAU_ASSERT(iceHandler); } Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters) + Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) : types(types) , mode(mode) , globalScope(std::move(globalScope)) , log(seen) , location(location) , variance(variance) - , counters(counters ? counters : std::make_shared()) + , counters(counters ? counters : &countersData) + , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) , iceHandler(iceHandler) { LUAU_ASSERT(iceHandler); @@ -65,16 +69,26 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::v void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - counters->iterationCount = 0; + if (FFlag::LuauTypecheckOpts) + counters->iterationCount = 0; + else + counters_DEPRECATED->iterationCount = 0; + return tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + + if (FFlag::LuauTypecheckOpts) + ++counters->iterationCount; + else + ++counters_DEPRECATED->iterationCount; - ++counters->iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + if (FInt::LuauTypeInferIterationLimit > 0 && + FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -440,7 +454,11 @@ ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunction void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - counters->iterationCount = 0; + if (FFlag::LuauTypecheckOpts) + counters->iterationCount = 0; + else + counters_DEPRECATED->iterationCount = 0; + return tryUnify_(superTp, subTp, isFunctionCall); } @@ -450,10 +468,16 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall */ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + + if (FFlag::LuauTypecheckOpts) + ++counters->iterationCount; + else + ++counters_DEPRECATED->iterationCount; - ++counters->iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + if (FInt::LuauTypeInferIterationLimit > 0 && + FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -762,9 +786,210 @@ struct Resetter void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - std::unique_ptr resetter; + if (!FFlag::LuauTableSubtypingVariance) + return DEPRECATED_tryUnifyTables(left, right, isIntersection); - resetter.reset(new Resetter{&variance}); + TableTypeVar* lt = getMutable(left); + TableTypeVar* rt = getMutable(right); + if (!lt || !rt) + ice("passed non-table types to unifyTables"); + + std::vector missingProperties; + std::vector extraProperties; + + // Reminder: left is the supertype, right is the subtype. + // Width subtyping: any property in the supertype must be in the subtype, + // and the types must agree. + for (const auto& [name, prop] : lt->props) + { + const auto& r = rt->props.find(name); + if (r != rt->props.end()) + { + // TODO: read-only properties don't need invariance + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, r->second.type); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (rt->indexer && isString(rt->indexer->indexType)) + { + // TODO: read-only indexers don't need invariance + // TODO: really we should only allow this if prop.type is optional. + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, rt->indexer->indexResultType); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (isOptional(prop.type) || get(follow(prop.type))) + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + {} + else if (rt->state == TableState::Free) + { + log(rt); + rt->props[name] = prop; + } + else + missingProperties.push_back(name); + } + + for (const auto& [name, prop] : rt->props) + { + if (lt->props.count(name)) + { + // If both lt and rt contain the property, then + // we're done since we already unified them above + } + else if (lt->indexer && isString(lt->indexer->indexType)) + { + // TODO: read-only indexers don't need invariance + // TODO: really we should only allow this if prop.type is optional. + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, lt->indexer->indexResultType); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (lt->state == TableState::Unsealed) + { + // TODO: this case is unsound when variance is Invariant, but without it lua-apps fails to typecheck. + // TODO: file a JIRA + // TODO: hopefully readonly/writeonly properties will fix this. + Property clone = prop; + clone.type = deeplyOptional(clone.type); + log(lt); + lt->props[name] = clone; + } + else if (variance == Covariant) + {} + else if (isOptional(prop.type) || get(follow(prop.type))) + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + {} + else if (lt->state == TableState::Free) + { + log(lt); + lt->props[name] = prop; + } + else + extraProperties.push_back(name); + } + + // Unify indexers + if (lt->indexer && rt->indexer) + { + // TODO: read-only indexers don't need invariance + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify(*lt->indexer, *rt->indexer); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (lt->indexer) + { + if (rt->state == TableState::Unsealed || rt->state == TableState::Free) + { + // passing/assigning a table without an indexer to something that has one + // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. + // TODO: we only need to do this if the supertype's indexer is read/write + // since that can add indexed elements. + log(rt); + rt->indexer = lt->indexer; + } + } + else if (rt->indexer && variance == Invariant) + { + // Symmetric if we are invariant + if (lt->state == TableState::Unsealed || lt->state == TableState::Free) + { + log(lt); + lt->indexer = rt->indexer; + } + } + + if (!missingProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + return; + } + + if (!extraProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + return; + } + + /* + * TypeVars are commonly cyclic, so it is entirely possible + * for unifying a property of a table to change the table itself! + * We need to check for this and start over if we notice this occurring. + * + * I believe this is guaranteed to terminate eventually because this will + * only happen when a free table is bound to another table. + */ + if (lt->boundTo || rt->boundTo) + return tryUnify_(left, right); + + if (lt->state == TableState::Free) + { + log(lt); + lt->boundTo = right; + } + else if (rt->state == TableState::Free) + { + log(rt); + rt->boundTo = left; + } +} + +TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) +{ + ty = follow(ty); + if (get(ty)) + return ty; + else if (isOptional(ty)) + return ty; + else if (const TableTypeVar* ttv = get(ty)) + { + TypeId& result = seen[ty]; + if (result) + return result; + result = types->addType(*ttv); + TableTypeVar* resultTtv = getMutable(result); + for (auto& [name, prop] : resultTtv->props) + prop.type = deeplyOptional(prop.type, seen); + return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});; + } + else + return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }}); +} + +void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +{ + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance); + Resetter resetter{&variance}; variance = Invariant; TableTypeVar* lt = getMutable(left); @@ -894,10 +1119,7 @@ void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) if (!freeTable->boundTo && otherTable->state != TableState::Free) { - if (FFlag::LuauLogTableTypeVarBoundTo) - log(freeTable); - else - log(freeTypeId); + log(freeTable); freeTable->boundTo = otherTypeId; } } @@ -1196,9 +1418,11 @@ void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& sub tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); } -static void queueTypePack( +static void queueTypePack_DEPRECATED( std::vector& queue, std::unordered_set& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { + LUAU_ASSERT(!FFlag::LuauTypecheckOpts); + while (true) { if (FFlag::LuauAddMissingFollow) @@ -1244,6 +1468,55 @@ static void queueTypePack( } } +static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOpts); + + while (true) + { + if (FFlag::LuauAddMissingFollow) + a = follow(a); + + if (seenTypePacks.find(a)) + break; + seenTypePacks.insert(a); + + if (FFlag::LuauAddMissingFollow) + { + if (get(a)) + { + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + else if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } + } + else + { + if (get(a)) + { + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + + if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } + } + } +} + void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool reversed, int subOffset) { const VariadicTypePack* lv = get(superTp); @@ -1297,9 +1570,11 @@ void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool rever } } -static void tryUnifyWithAny( +static void tryUnifyWithAny_DEPRECATED( std::vector& queue, Unifier& state, std::unordered_set& seenTypePacks, TypeId anyType, TypePackId anyTypePack) { + LUAU_ASSERT(!FFlag::LuauTypecheckOpts); + std::unordered_set seen; while (!queue.empty()) @@ -1310,6 +1585,59 @@ static void tryUnifyWithAny( continue; seen.insert(ty); + if (get(ty)) + { + state.log(ty); + *asMutable(ty) = BoundTypeVar{anyType}; + } + else if (auto fun = get(ty)) + { + queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = get(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = get(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (get(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = get(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = get(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. + } +} + +static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, + TypeId anyType, TypePackId anyTypePack) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOpts); + + while (!queue.empty()) + { + TypeId ty = follow(queue.back()); + queue.pop_back(); + if (seen.find(ty)) + continue; + seen.insert(ty); + if (get(ty)) { state.log(ty); @@ -1354,14 +1682,33 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { LUAU_ASSERT(get(any) || get(any)); + if (FFlag::LuauTypecheckOpts) + { + // These types are not visited in general loop below + if (get(ty) || get(ty) || get(ty)) + return; + } + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - std::unordered_set seenTypePacks; - std::vector queue = {ty}; + if (FFlag::LuauTypecheckOpts) + { + std::vector queue = {ty}; + + tempSeenTy.clear(); + tempSeenTp.clear(); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP); + } + else + { + std::unordered_set seenTypePacks; + std::vector queue = {ty}; - Luau::tryUnifyWithAny(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); + Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); + } } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) @@ -1370,12 +1717,26 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) const TypeId anyTy = singletonTypes.errorType; - std::unordered_set seenTypePacks; - std::vector queue; + if (FFlag::LuauTypecheckOpts) + { + std::vector queue; - queueTypePack(queue, seenTypePacks, *this, ty, any); + tempSeenTy.clear(); + tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, seenTypePacks, anyTy, any); + queueTypePack(queue, tempSeenTp, *this, ty, any); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any); + } + else + { + std::unordered_set seenTypePacks; + std::vector queue; + + queueTypePack_DEPRECATED(queue, seenTypePacks, *this, ty, any); + + Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, anyTy, any); + } } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) @@ -1387,21 +1748,6 @@ std::optional Unifier::findMetatableEntry(TypeId type, std::string entry { type = follow(type); - if (!FFlag::LuauStringMetatable) - { - if (const PrimitiveTypeVar* primType = get(type)) - { - if (primType->type != PrimitiveTypeVar::String || "__index" != entry) - return std::nullopt; - - auto found = globalScope->bindings.find(AstName{"string"}); - if (found == globalScope->bindings.end()) - return std::nullopt; - else - return found->second.typeId; - } - } - std::optional metatable = getMetatable(type); if (!metatable) return std::nullopt; @@ -1427,21 +1773,36 @@ std::optional Unifier::findMetatableEntry(TypeId type, std::string entry void Unifier::occursCheck(TypeId needle, TypeId haystack) { - std::unordered_set seen; - return occursCheck(seen, needle, haystack); + std::unordered_set seen_DEPRECATED; + + if (FFlag::LuauTypecheckOpts) + tempSeenTy.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack) +void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); needle = follow(needle); haystack = follow(haystack); - if (seen.end() != seen.find(haystack)) - return; + if (FFlag::LuauTypecheckOpts) + { + if (seen.find(haystack)) + return; + + seen.insert(haystack); + } + else + { + if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) + return; - seen.insert(haystack); + seen_DEPRECATED.insert(haystack); + } if (get(needle)) return; @@ -1458,7 +1819,7 @@ void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeI } auto check = [&](TypeId tv) { - occursCheck(seen, needle, tv); + occursCheck(seen_DEPRECATED, seen, needle, tv); }; if (get(haystack)) @@ -1488,19 +1849,33 @@ void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeI void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { - std::unordered_set seen; - return occursCheck(seen, needle, haystack); + std::unordered_set seen_DEPRECATED; + + if (FFlag::LuauTypecheckOpts) + tempSeenTp.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack) +void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) { needle = follow(needle); haystack = follow(haystack); - if (seen.find(haystack) != seen.end()) - return; + if (FFlag::LuauTypecheckOpts) + { + if (seen.find(haystack)) + return; - seen.insert(haystack); + seen.insert(haystack); + } + else + { + if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) + return; + + seen_DEPRECATED.insert(haystack); + } if (get(needle)) return; @@ -1508,7 +1883,8 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl if (!get(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); while (!get(haystack)) { @@ -1528,8 +1904,8 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl { if (auto f = get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); + occursCheck(seen_DEPRECATED, seen, needle, f->argTypes); + occursCheck(seen_DEPRECATED, seen, needle, f->retType); } } } @@ -1546,7 +1922,7 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters}; + return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index df38cfecf..a2189f7b7 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -264,6 +264,10 @@ class AstVisitor { return false; } + virtual bool visit(class AstTypePackExplicit* node) + { + return visit((class AstTypePack*)node); + } virtual bool visit(class AstTypePackVariadic* node) { return visit((class AstTypePack*)node); @@ -930,12 +934,14 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported); + AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, + AstType* type, bool exported); void visit(AstVisitor* visitor) override; AstName name; AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -1007,19 +1013,28 @@ class AstType : public AstNode } }; +// Don't have Luau::Variant available, it's a bit of an overhead, but a plain struct is nice to use +struct AstTypeOrPack +{ + AstType* type = nullptr; + AstTypePack* typePack = nullptr; +}; + class AstTypeReference : public AstType { public: LUAU_RTTI(AstTypeReference) - AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics = {}); + AstTypeReference(const Location& location, std::optional prefix, AstName name, bool hasParameterList = false, + const AstArray& parameters = {}); void visit(AstVisitor* visitor) override; bool hasPrefix; + bool hasParameterList; AstName prefix; AstName name; - AstArray generics; + AstArray parameters; }; struct AstTableProp @@ -1152,6 +1167,18 @@ class AstTypePack : public AstNode } }; +class AstTypePackExplicit : public AstTypePack +{ +public: + LUAU_RTTI(AstTypePackExplicit) + + AstTypePackExplicit(const Location& location, AstTypeList typeList); + + void visit(AstVisitor* visitor) override; + + AstTypeList typeList; +}; + class AstTypePackVariadic : public AstTypePack { public: diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index 02924e883..a7b2515a6 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -136,7 +136,10 @@ class DenseHashTable const Key& key = ItemInterface::getKey(data[i]); if (!eq(key, empty_key)) - *newtable.insert_unsafe(key) = data[i]; + { + Item* item = newtable.insert_unsafe(key); + *item = std::move(data[i]); + } } LUAU_ASSERT(count == newtable.count); diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index e6ebd503c..42c64dc92 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -218,13 +218,14 @@ class Parser AstTableIndexer* parseTableIndexerAnnotation(); - AstType* parseFunctionTypeAnnotation(); + AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); AstType* parseTableTypeAnnotation(); - AstType* parseSimpleTypeAnnotation(); + AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack); + AstTypeOrPack parseTypeOrPackAnnotation(); AstType* parseTypeAnnotation(TempVector& parts, const Location& begin); AstType* parseTypeAnnotation(); @@ -284,7 +285,7 @@ class Parser std::pair, AstArray> parseGenericTypeListIfFFlagParseGenericFunctions(); // `<' typeAnnotation[, ...] `>' - AstArray parseTypeParams(); + AstArray parseTypeParams(); AstExpr* parseString(); @@ -413,6 +414,7 @@ class Parser std::vector scratchLocal; std::vector scratchTableTypeProps; std::vector scratchAnnotation; + std::vector scratchTypeOrPackAnnotation; std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h new file mode 100644 index 000000000..641dfd3c3 --- /dev/null +++ b/Ast/include/Luau/TimeTrace.h @@ -0,0 +1,223 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Common.h" + +#include + +#include + +LUAU_FASTFLAG(DebugLuauTimeTracing) + +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ +uint32_t getClockMicroseconds(); + +struct Token +{ + const char* name; + const char* category; +}; + +enum class EventType : uint8_t +{ + Enter, + Leave, + + ArgName, + ArgValue, +}; + +struct Event +{ + EventType type; + uint16_t token; + + union + { + uint32_t microsec; // 1 hour trace limit + uint32_t dataPos; + } data; +}; + +struct GlobalContext; +struct ThreadContext; + +GlobalContext& getGlobalContext(); + +uint16_t createToken(GlobalContext& context, const char* name, const char* category); +uint32_t createThread(GlobalContext& context, ThreadContext* threadContext); +void releaseThread(GlobalContext& context, ThreadContext* threadContext); +void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& events, const std::vector& data); + +struct ThreadContext +{ + ThreadContext() + : globalContext(getGlobalContext()) + { + threadId = createThread(globalContext, this); + } + + ~ThreadContext() + { + if (!events.empty()) + flushEvents(); + + releaseThread(globalContext, this); + } + + void flushEvents() + { + static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace"); + + events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}}); + + TimeTrace::flushEvents(globalContext, threadId, events, data); + + events.clear(); + data.clear(); + + events.push_back({EventType::Leave, 0, {getClockMicroseconds()}}); + } + + void eventEnter(uint16_t token) + { + eventEnter(token, getClockMicroseconds()); + } + + void eventEnter(uint16_t token, uint32_t microsec) + { + events.push_back({EventType::Enter, token, {microsec}}); + } + + void eventLeave() + { + eventLeave(getClockMicroseconds()); + } + + void eventLeave(uint32_t microsec) + { + events.push_back({EventType::Leave, 0, {microsec}}); + + if (events.size() > kEventFlushLimit) + flushEvents(); + } + + void eventArgument(const char* name, const char* value) + { + uint32_t pos = uint32_t(data.size()); + data.insert(data.end(), name, name + strlen(name) + 1); + events.push_back({EventType::ArgName, 0, {pos}}); + + pos = uint32_t(data.size()); + data.insert(data.end(), value, value + strlen(value) + 1); + events.push_back({EventType::ArgValue, 0, {pos}}); + } + + GlobalContext& globalContext; + uint32_t threadId; + std::vector events; + std::vector data; + + static constexpr size_t kEventFlushLimit = 8192; +}; + +ThreadContext& getThreadContext(); + +struct Scope +{ + explicit Scope(ThreadContext& context, uint16_t token) + : context(context) + { + if (!FFlag::DebugLuauTimeTracing) + return; + + context.eventEnter(token); + } + + ~Scope() + { + if (!FFlag::DebugLuauTimeTracing) + return; + + context.eventLeave(); + } + + ThreadContext& context; +}; + +struct OptionalTailScope +{ + explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold) + : context(context) + , token(token) + , threshold(threshold) + { + if (!FFlag::DebugLuauTimeTracing) + return; + + pos = uint32_t(context.events.size()); + microsec = getClockMicroseconds(); + } + + ~OptionalTailScope() + { + if (!FFlag::DebugLuauTimeTracing) + return; + + if (pos == context.events.size()) + { + uint32_t curr = getClockMicroseconds(); + + if (curr - microsec > threshold) + { + context.eventEnter(token, microsec); + context.eventLeave(curr); + } + } + } + + ThreadContext& context; + uint16_t token; + uint32_t threshold; + uint32_t microsec; + uint32_t pos; +}; + +LUAU_NOINLINE std::pair createScopeData(const char* name, const char* category); + +} // namespace TimeTrace +} // namespace Luau + +// Regular scope +#define LUAU_TIMETRACE_SCOPE(name, category) \ + static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) + +// A scope without nested scopes that may be skipped if the time it took is less than the threshold +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ + static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) + +// Extra key/value data can be added to regular scopes +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + if (FFlag::DebugLuauTimeTracing) \ + lttScopeStatic.second.eventArgument(name, value); \ + } while (false) + +#else + +#define LUAU_TIMETRACE_SCOPE(name, category) +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + } while (false) + +#endif diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index fff1537dd..b1209faa1 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -641,10 +641,12 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported) : AstStat(ClassIndex(), location) , name(name) , generics(generics) + , genericPacks(genericPacks) , type(type) , exported(exported) { @@ -729,12 +731,14 @@ void AstStatError::visit(AstVisitor* visitor) } } -AstTypeReference::AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics) +AstTypeReference::AstTypeReference( + const Location& location, std::optional prefix, AstName name, bool hasParameterList, const AstArray& parameters) : AstType(ClassIndex(), location) , hasPrefix(bool(prefix)) + , hasParameterList(hasParameterList) , prefix(prefix ? *prefix : AstName()) , name(name) - , generics(generics) + , parameters(parameters) { } @@ -742,8 +746,13 @@ void AstTypeReference::visit(AstVisitor* visitor) { if (visitor->visit(this)) { - for (AstType* generic : generics) - generic->visit(visitor); + for (const AstTypeOrPack& param : parameters) + { + if (param.type) + param.type->visit(visitor); + else + param.typePack->visit(visitor); + } } } @@ -849,6 +858,24 @@ void AstTypeError::visit(AstVisitor* visitor) } } +AstTypePackExplicit::AstTypePackExplicit(const Location& location, AstTypeList typeList) + : AstTypePack(ClassIndex(), location) + , typeList(typeList) +{ +} + +void AstTypePackExplicit::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstType* type : typeList.types) + type->visit(visitor); + + if (typeList.tailType) + typeList.tailType->visit(visitor); + } +} + AstTypePackVariadic::AstTypePackVariadic(const Location& location, AstType* variadicType) : AstTypePack(ClassIndex(), location) , variadicType(variadicType) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 6672efe8d..40026d8be 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/TimeTrace.h" + #include // Warning: If you are introducing new syntax, ensure that it is behind a separate @@ -13,6 +15,8 @@ LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) +LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) +LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) namespace Luau { @@ -148,6 +152,8 @@ static bool shouldParseTypePackAnnotation(Lexer& lexer) ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options) { + LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser"); + Parser p(buffer, bufferSize, names, allocator); try @@ -769,14 +775,14 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - // TODO: support generic type pack parameters in type aliases CLI-39907 auto [generics, genericPacks] = parseGenericTypeList(); expectAndConsume('=', "type alias"); AstType* type = parseTypeAnnotation(); - return allocator.alloc(Location(start, type->location), name->name, generics, type, exported); + return allocator.alloc( + Location(start, type->location), name->name, generics, FFlag::LuauTypeAliasPacks ? genericPacks : AstArray{}, type, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -1333,7 +1339,7 @@ AstType* Parser::parseTableTypeAnnotation() // ReturnType ::= TypeAnnotation | `(' TypeList `)' // FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstType* Parser::parseFunctionTypeAnnotation() +AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1364,14 +1370,23 @@ AstType* Parser::parseFunctionTypeAnnotation() matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; - // Not a function at all. Just a parenthesized type. + AstArray paramTypes = copy(params); + + // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) - return params[0]; + { + if (allowPack) + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; + else + return {params[0], {}}; + } + + if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack) + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; - AstArray paramTypes = copy(params); AstArray> paramNames = copy(names); - return parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation); + return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, @@ -1421,7 +1436,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); isUnion = true; } else if (c == '?') @@ -1434,7 +1449,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); isIntersection = true; } else @@ -1462,6 +1477,30 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location ParseError::raise(begin, "Composite type was not an intersection or union."); } +AstTypeOrPack Parser::parseTypeOrPackAnnotation() +{ + unsigned int oldRecursionCount = recursionCounter; + incrementRecursionCounter("type annotation"); + + Location begin = lexer.current().location; + + TempVector parts(scratchAnnotation); + + auto [type, typePack] = parseSimpleTypeAnnotation(true); + + if (typePack) + { + LUAU_ASSERT(!type); + return {{}, typePack}; + } + + parts.push_back(type); + + recursionCounter = oldRecursionCount; + + return {parseTypeAnnotation(parts, begin), {}}; +} + AstType* Parser::parseTypeAnnotation() { unsigned int oldRecursionCount = recursionCounter; @@ -1470,7 +1509,7 @@ AstType* Parser::parseTypeAnnotation() Location begin = lexer.current().location; TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); recursionCounter = oldRecursionCount; @@ -1479,7 +1518,7 @@ AstType* Parser::parseTypeAnnotation() // typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstType* Parser::parseSimpleTypeAnnotation() +AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1488,7 +1527,7 @@ AstType* Parser::parseSimpleTypeAnnotation() if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); - return allocator.alloc(begin, std::nullopt, nameNil); + return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } else if (lexer.current().type == Lexeme::Name) { @@ -1514,22 +1553,41 @@ AstType* Parser::parseSimpleTypeAnnotation() expectMatchAndConsume(')', typeofBegin); - return allocator.alloc(Location(begin, end), expr); + return {allocator.alloc(Location(begin, end), expr), {}}; } - AstArray generics = parseTypeParams(); + if (FFlag::LuauParseTypePackTypeParameters) + { + bool hasParameters = false; + AstArray parameters{}; + + if (lexer.current().type == '<') + { + hasParameters = true; + parameters = parseTypeParams(); + } + + Location end = lexer.previousLocation(); + + return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; + } + else + { + AstArray generics = parseTypeParams(); - Location end = lexer.previousLocation(); + Location end = lexer.previousLocation(); - return allocator.alloc(Location(begin, end), prefix, name.name, generics); + // false in 'hasParameterList' as it is not used without FFlagLuauTypeAliasPacks + return {allocator.alloc(Location(begin, end), prefix, name.name, false, generics), {}}; + } } else if (lexer.current().type == '{') { - return parseTableTypeAnnotation(); + return {parseTableTypeAnnotation(), {}}; } else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<')) { - return parseFunctionTypeAnnotation(); + return parseFunctionTypeAnnotation(allowPack); } else { @@ -1538,7 +1596,7 @@ AstType* Parser::parseSimpleTypeAnnotation() // For a missing type annoation, capture 'space' between last token and the next one location = Location(lexer.previousLocation().end, lexer.current().location.begin); - return reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()); + return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } @@ -2312,18 +2370,59 @@ std::pair, AstArray> Parser::parseGenericTypeList() return {generics, genericPacks}; } -AstArray Parser::parseTypeParams() +AstArray Parser::parseTypeParams() { - TempVector result{scratchAnnotation}; + TempVector parameters{scratchTypeOrPackAnnotation}; if (lexer.current().type == '<') { Lexeme begin = lexer.current(); nextLexeme(); + bool seenPack = false; while (true) { - result.push_back(parseTypeAnnotation()); + if (FFlag::LuauParseTypePackTypeParameters) + { + if (shouldParseTypePackAnnotation(lexer)) + { + seenPack = true; + + auto typePack = parseTypePackAnnotation(); + + if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them + parameters.push_back({{}, typePack}); + } + else if (lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (typePack) + { + seenPack = true; + + if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them + parameters.push_back({{}, typePack}); + } + else + { + parameters.push_back({type, {}}); + } + } + else if (lexer.current().type == '>' && parameters.empty()) + { + break; + } + else + { + parameters.push_back({parseTypeAnnotation(), {}}); + } + } + else + { + parameters.push_back({parseTypeAnnotation(), {}}); + } + if (lexer.current().type == ',') nextLexeme(); else @@ -2333,7 +2432,7 @@ AstArray Parser::parseTypeParams() expectMatchAndConsume('>', begin); } - return copy(result); + return copy(parameters); } AstExpr* Parser::parseString() diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp new file mode 100644 index 000000000..e6aab20e5 --- /dev/null +++ b/Ast/src/TimeTrace.cpp @@ -0,0 +1,248 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TimeTrace.h" + +#include "Luau/StringUtils.h" + +#include +#include + +#include + +#ifdef _WIN32 +#include +#endif + +#ifdef __APPLE__ +#include +#include +#endif + +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false) + +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ +static double getClockPeriod() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceFrequency(&result); + return 1.0 / double(result.QuadPart); +#elif defined(__APPLE__) + mach_timebase_info_data_t result = {}; + mach_timebase_info(&result); + return double(result.numer) / double(result.denom) * 1e-9; +#elif defined(__linux__) + return 1e-9; +#else + return 1.0 / double(CLOCKS_PER_SEC); +#endif +} + +static double getClockTimestamp() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceCounter(&result); + return double(result.QuadPart); +#elif defined(__APPLE__) + return double(mach_absolute_time()); +#elif defined(__linux__) + timespec now; + clock_gettime(CLOCK_MONOTONIC, &now); + return now.tv_sec * 1e9 + now.tv_nsec; +#else + return double(clock()); +#endif +} + +uint32_t getClockMicroseconds() +{ + static double period = getClockPeriod() * 1e6; + static double start = getClockTimestamp(); + + return uint32_t((getClockTimestamp() - start) * period); +} + +struct GlobalContext +{ + GlobalContext() = default; + ~GlobalContext() + { + // Ideally we would want all ThreadContext destructors to run + // But in VS, not all thread_local object instances are destroyed + for (ThreadContext* context : threads) + context->flushEvents(); + + if (traceFile) + fclose(traceFile); + } + + std::mutex mutex; + std::vector threads; + uint32_t nextThreadId = 0; + std::vector tokens; + FILE* traceFile = nullptr; +}; + +GlobalContext& getGlobalContext() +{ + static GlobalContext context; + return context; +} + +uint16_t createToken(GlobalContext& context, const char* name, const char* category) +{ + std::scoped_lock lock(context.mutex); + + LUAU_ASSERT(context.tokens.size() < 64 * 1024); + + context.tokens.push_back({name, category}); + return uint16_t(context.tokens.size() - 1); +} + +uint32_t createThread(GlobalContext& context, ThreadContext* threadContext) +{ + std::scoped_lock lock(context.mutex); + + context.threads.push_back(threadContext); + + return ++context.nextThreadId; +} + +void releaseThread(GlobalContext& context, ThreadContext* threadContext) +{ + std::scoped_lock lock(context.mutex); + + if (auto it = std::find(context.threads.begin(), context.threads.end(), threadContext); it != context.threads.end()) + context.threads.erase(it); +} + +void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& events, const std::vector& data) +{ + std::scoped_lock lock(context.mutex); + + if (!context.traceFile) + { + context.traceFile = fopen("trace.json", "w"); + + if (!context.traceFile) + return; + + fprintf(context.traceFile, "[\n"); + } + + std::string temp; + const unsigned tempReserve = 64 * 1024; + temp.reserve(tempReserve); + + const char* rawData = data.data(); + + // Formatting state + bool unfinishedEnter = false; + bool unfinishedArgs = false; + + for (const Event& ev : events) + { + switch (ev.type) + { + case EventType::Enter: + { + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + Token& token = context.tokens[ev.token]; + + formatAppend(temp, R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)", token.name, token.category, + ev.data.microsec, threadId); + unfinishedEnter = true; + } + break; + case EventType::Leave: + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + formatAppend(temp, + R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)" + "\n", + ev.data.microsec, threadId); + break; + case EventType::ArgName: + LUAU_ASSERT(unfinishedEnter); + + if (!unfinishedArgs) + { + formatAppend(temp, R"(, "args": { "%s": )", rawData + ev.data.dataPos); + unfinishedArgs = true; + } + else + { + formatAppend(temp, R"(, "%s": )", rawData + ev.data.dataPos); + } + break; + case EventType::ArgValue: + LUAU_ASSERT(unfinishedArgs); + formatAppend(temp, R"("%s")", rawData + ev.data.dataPos); + break; + } + + // Don't want to hit the string capacity and reallocate + if (temp.size() > tempReserve - 1024) + { + fwrite(temp.data(), 1, temp.size(), context.traceFile); + temp.clear(); + } + } + + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + fwrite(temp.data(), 1, temp.size(), context.traceFile); + fflush(context.traceFile); +} + +ThreadContext& getThreadContext() +{ + thread_local ThreadContext context; + return context; +} + +std::pair createScopeData(const char* name, const char* category) +{ + uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category); + return {token, Luau::TimeTrace::getThreadContext()}; +} +} // namespace TimeTrace +} // namespace Luau + +#endif diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 920502b82..ed0552d74 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -111,11 +111,24 @@ struct CliFileResolver : Luau::FileResolver return Luau::SourceCode{*source, Luau::SourceCode::Module}; } + std::optional resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override + { + if (Luau::AstExprConstantString* expr = node->as()) + { + Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".lua"; + + return {{name}}; + } + + return std::nullopt; + } + bool moduleExists(const Luau::ModuleName& name) const override { return !!readFile(name); } + std::optional fromAstFragment(Luau::AstExpr* expr) const override { return std::nullopt; @@ -130,11 +143,6 @@ struct CliFileResolver : Luau::FileResolver { return std::nullopt; } - - std::optional getEnvironmentForModule(const Luau::ModuleName& name) const override - { - return std::nullopt; - } }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 022eccb70..797ee20dc 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -4,6 +4,7 @@ #include "Luau/Parser.h" #include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" +#include "Luau/TimeTrace.h" #include #include @@ -137,6 +138,11 @@ struct Compiler uint32_t compileFunction(AstExprFunction* func) { + LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); + + if (func->debugname.value) + LUAU_TIMETRACE_ARGUMENT("name", func->debugname.value); + LUAU_ASSERT(!functions.contains(func)); LUAU_ASSERT(regTop == 0 && stackSize == 0 && localStack.empty() && upvals.empty()); @@ -3686,6 +3692,8 @@ struct Compiler void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) { + LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler"); + Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block table imports @@ -3748,6 +3756,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder) { + LUAU_TIMETRACE_SCOPE("compile", "Compiler"); + Allocator allocator; AstNameTable names(allocator); ParseResult result = Parser::parse(source.c_str(), source.size(), names, allocator, parseOptions); diff --git a/Sources.cmake b/Sources.cmake index 6f96f6aba..83ed52301 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -9,6 +9,7 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/ParseOptions.h Ast/include/Luau/Parser.h Ast/include/Luau/StringUtils.h + Ast/include/Luau/TimeTrace.h Ast/src/Ast.cpp Ast/src/Confusables.cpp @@ -16,6 +17,7 @@ target_sources(Luau.Ast PRIVATE Ast/src/Location.cpp Ast/src/Parser.cpp Ast/src/StringUtils.cpp + Ast/src/TimeTrace.cpp ) # Luau.Compiler Sources @@ -46,6 +48,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Predicate.h Analysis/include/Luau/RecursionCounter.h Analysis/include/Luau/RequireTracer.h + Analysis/include/Luau/Scope.h Analysis/include/Luau/Substitution.h Analysis/include/Luau/Symbol.h Analysis/include/Luau/TopoSortStatements.h @@ -75,6 +78,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Module.cpp Analysis/src/Predicate.cpp Analysis/src/RequireTracer.cpp + Analysis/src/Scope.cpp Analysis/src/Substitution.cpp Analysis/src/Symbol.cpp Analysis/src/TopoSortStatements.cpp @@ -188,6 +192,7 @@ if(TARGET Luau.UnitTest) tests/TopoSort.test.cpp tests/ToString.test.cpp tests/Transpiler.test.cpp + tests/TypeInfer.aliases.test.cpp tests/TypeInfer.annotations.test.cpp tests/TypeInfer.builtins.test.cpp tests/TypeInfer.classes.test.cpp diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index ee4962a50..39a615978 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -8,9 +8,13 @@ #include "lmem.h" #include "lvm.h" +#if LUA_USE_LONGJMP +#include +#include +#else #include +#endif -#include #include LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) @@ -51,8 +55,8 @@ l_noret luaD_throw(lua_State* L, int errcode) longjmp(jb->buf, 1); } - if (L->global->panic) - L->global->panic(L, errcode); + if (L->global->cb.panic) + L->global->cb.panic(L, errcode); abort(); } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 9b040fb50..510a9f548 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -16,6 +16,8 @@ LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false) LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false) LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false) +LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) + LUAU_FASTFLAG(LuauArrayBoundary) #define GC_SWEEPMAX 40 @@ -810,6 +812,133 @@ static size_t singlestep(lua_State* L) return cost; } +static size_t gcstep(lua_State* L, size_t limit) +{ + size_t cost = 0; + global_State* g = L->global; + switch (g->gcstate) + { + case GCSpause: + { + markroot(L); /* start a new collection */ + break; + } + case GCSpropagate: + { + if (FFlag::LuauRescanGrayAgain) + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) + { + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; + + g->gcstate = GCSpropagateagain; + } + } + else + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) /* no more `gray' objects */ + { + double starttimestamp = lua_clock(); + + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + atomic(L); /* finish mark phase */ + + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } + } + break; + } + case GCSpropagateagain: + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) /* no more `gray' objects */ + { + double starttimestamp = lua_clock(); + + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + atomic(L); /* finish mark phase */ + + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } + break; + } + case GCSsweepstring: + { + while (g->sweepstrgc < g->strt.size && cost < limit) + { + size_t traversedcount = 0; + sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount); + + g->gcstats.currcycle.sweepitems += traversedcount; + cost += GC_SWEEPCOST; + } + + // nothing more to sweep? + if (g->sweepstrgc >= g->strt.size) + { + // sweep string buffer list and preserve used string count + uint32_t nuse = L->global->strt.nuse; + + size_t traversedcount = 0; + sweepwholelist(L, &g->strbufgc, &traversedcount); + + L->global->strt.nuse = nuse; + + g->gcstats.currcycle.sweepitems += traversedcount; + g->gcstate = GCSsweep; // end sweep-string phase + } + break; + } + case GCSsweep: + { + while (*g->sweepgc && cost < limit) + { + size_t traversedcount = 0; + g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); + + g->gcstats.currcycle.sweepitems += traversedcount; + cost += GC_SWEEPMAX * GC_SWEEPCOST; + } + + if (*g->sweepgc == NULL) + { /* nothing more to sweep? */ + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ + } + break; + } + default: + LUAU_ASSERT(0); + } + return cost; +} + static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCycleStats* cyclestats) { // adjust for error using Proportional-Integral controller @@ -878,33 +1007,40 @@ void luaC_step(lua_State* L, bool assist) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (assist) - g->gcstats.currcycle.assistwork += lim; - else - g->gcstats.currcycle.explicitwork += lim; - int lastgcstate = g->gcstate; double lastttimestamp = lua_clock(); - // always perform at least one single step - do + if (FFlag::LuauConsolidatedStep) { - lim -= singlestep(L); + size_t work = gcstep(L, lim); - // if we have switched to a different state, capture the duration of last stage - // this way we reduce the number of timer calls we make - if (lastgcstate != g->gcstate) + if (assist) + g->gcstats.currcycle.assistwork += work; + else + g->gcstats.currcycle.explicitwork += work; + } + else + { + // always perform at least one single step + do { - GC_INTERRUPT(lastgcstate); + lim -= singlestep(L); + + // if we have switched to a different state, capture the duration of last stage + // this way we reduce the number of timer calls we make + if (lastgcstate != g->gcstate) + { + GC_INTERRUPT(lastgcstate); - double now = lua_clock(); + double now = lua_clock(); - recordGcStateTime(g, lastgcstate, now - lastttimestamp, assist); + recordGcStateTime(g, lastgcstate, now - lastttimestamp, assist); - lastttimestamp = now; - lastgcstate = g->gcstate; - } - } while (lim > 0 && g->gcstate != GCSpause); + lastttimestamp = now; + lastgcstate = g->gcstate; + } + } while (lim > 0 && g->gcstate != GCSpause); + } recordGcStateTime(g, lastgcstate, lua_clock() - lastttimestamp, assist); @@ -931,7 +1067,14 @@ void luaC_step(lua_State* L, bool assist) g->GCthreshold -= debt; } - GC_INTERRUPT(g->gcstate); + if (FFlag::LuauConsolidatedStep) + { + GC_INTERRUPT(lastgcstate); + } + else + { + GC_INTERRUPT(g->gcstate); + } } void luaC_fullgc(lua_State* L) @@ -957,7 +1100,10 @@ void luaC_fullgc(lua_State* L) while (g->gcstate != GCSpause) { LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); - singlestep(L); + if (FFlag::LuauConsolidatedStep) + gcstep(L, SIZE_MAX); + else + singlestep(L); } finishGcCycleStats(g); @@ -968,7 +1114,10 @@ void luaC_fullgc(lua_State* L) markroot(L); while (g->gcstate != GCSpause) { - singlestep(L); + if (FFlag::LuauConsolidatedStep) + gcstep(L, SIZE_MAX); + else + singlestep(L); } /* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */ shrinkbuffersfull(L); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 090e183ff..de5788eb6 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -9,14 +9,8 @@ #include "ldebug.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry, false) - LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false) -bool lua_telemetry_table_move_oob_src_from = false; -bool lua_telemetry_table_move_oob_src_to = false; -bool lua_telemetry_table_move_oob_dst = false; - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -202,22 +196,6 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); - if (DFFlag::LuauTableMoveTelemetry) - { - int nf = lua_objlen(L, 1); - int nt = lua_objlen(L, tt); - - // source index range must be in bounds in source table unless the table is empty (permits 1..#t moves) - if (!(f == 1 || (f >= 1 && f <= nf))) - lua_telemetry_table_move_oob_src_from = true; - if (!(e == nf || (e >= 1 && e <= nf))) - lua_telemetry_table_move_oob_src_to = true; - - // destination index must be in bounds in dest table or be exactly at the first empty element (permits concats) - if (!(t == nt + 1 || (t >= 1 && t <= nt + 1))) - lua_telemetry_table_move_oob_dst = true; - } - if (e >= f) { /* otherwise, nothing to move */ luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move"); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 5f0ee922a..eed2862b1 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,8 +16,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauLoopUseSafeenv, false) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -292,10 +290,6 @@ inline bool luau_skipstep(uint8_t op) return op == LOP_PREPVARARGS || op == LOP_BREAK; } -// declared in lbaselib.cpp, needed to support cases when pairs/ipairs have been replaced via setfenv -LUAI_FUNC int luaB_inext(lua_State* L); -LUAI_FUNC int luaB_next(lua_State* L); - template static void luau_execute(lua_State* L) { @@ -2223,8 +2217,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // fast-path: ipairs/inext - bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_inext; - if (safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) + if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } @@ -2304,8 +2297,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // fast-path: pairs/next - bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_next; - if (safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) + if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 0a2323425..b932a85b3 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -12,7 +12,32 @@ #include -#include +// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens +template +struct TempBuffer +{ + lua_State* L; + T* data; + size_t count; + + TempBuffer(lua_State* L, size_t count) + : L(L) + , data(luaM_newarray(L, count, T, 0)) + , count(count) + { + } + + ~TempBuffer() + { + luaM_freearray(L, data, count, T, 0); + } + + T& operator[](size_t index) + { + LUAU_ASSERT(index < count); + return data[index]; + } +}; void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) { @@ -67,7 +92,7 @@ static unsigned int readVarInt(const char* data, size_t size, size_t& offset) return result; } -static TString* readString(std::vector& strings, const char* data, size_t size, size_t& offset) +static TString* readString(TempBuffer& strings, const char* data, size_t size, size_t& offset) { unsigned int id = readVarInt(data, size, offset); @@ -133,6 +158,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size } // pause GC for the duration of deserialization - some objects we're creating aren't rooted + // TODO: if an allocation error happens mid-load, we do not unpause GC! size_t GCthreshold = L->global->GCthreshold; L->global->GCthreshold = SIZE_MAX; @@ -144,7 +170,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size // string table unsigned int stringCount = readVarInt(data, size, offset); - std::vector strings(stringCount); + TempBuffer strings(L, stringCount); for (unsigned int i = 0; i < stringCount; ++i) { @@ -156,7 +182,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size // proto table unsigned int protoCount = readVarInt(data, size, offset); - std::vector protos(protoCount); + TempBuffer protos(L, protoCount); for (unsigned int i = 0; i < protoCount; ++i) { diff --git a/bench/tests/deltablue.lua b/bench/tests/deltablue.lua deleted file mode 100644 index ecf246d30..000000000 --- a/bench/tests/deltablue.lua +++ /dev/null @@ -1,934 +0,0 @@ -local bench = script and require(script.Parent.bench_support) or require("bench_support") - --- Copyright 2008 the V8 project authors. All rights reserved. --- Copyright 1996 John Maloney and Mario Wolczko. - --- This program is free software; you can redistribute it and/or modify --- it under the terms of the GNU General Public License as published by --- the Free Software Foundation; either version 2 of the License, or --- (at your option) any later version. --- --- This program is distributed in the hope that it will be useful, --- but WITHOUT ANY WARRANTY; without even the implied warranty of --- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the --- GNU General Public License for more details. --- --- You should have received a copy of the GNU General Public License --- along with this program; if not, write to the Free Software --- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - - --- This implementation of the DeltaBlue benchmark is derived --- from the Smalltalk implementation by John Maloney and Mario --- Wolczko. Some parts have been translated directly, whereas --- others have been modified more aggresively to make it feel --- more like a JavaScript program. - - --- --- A JavaScript implementation of the DeltaBlue constraint-solving --- algorithm, as described in: --- --- "The DeltaBlue Algorithm: An Incremental Constraint Hierarchy Solver" --- Bjorn N. Freeman-Benson and John Maloney --- January 1990 Communications of the ACM, --- also available as University of Washington TR 89-08-06. --- --- Beware: this benchmark is written in a grotesque style where --- the constraint model is built by side-effects from constructors. --- I've kept it this way to avoid deviating too much from the original --- implementation. --- - -function class(base) - local T = {} - T.__index = T - - if base then - T.super = base - setmetatable(T, base) - end - - function T.new(...) - local O = {} - setmetatable(O, T) - O:constructor(...) - return O - end - - return T -end - -local planner - ---- O b j e c t M o d e l --- - -local function alert (...) print(...) end - -local OrderedCollection = class() - -function OrderedCollection:constructor() - self.elms = {} -end - -function OrderedCollection:add(elm) - self.elms[#self.elms + 1] = elm -end - -function OrderedCollection:at (index) - return self.elms[index] -end - -function OrderedCollection:size () - return #self.elms -end - -function OrderedCollection:removeFirst () - local e = self.elms[#self.elms] - self.elms[#self.elms] = nil - return e -end - -function OrderedCollection:remove (elm) - local index = 0 - local skipped = 0 - - for i = 1, #self.elms do - local value = self.elms[i] - if value ~= elm then - self.elms[index] = value - index = index + 1 - else - skipped = skipped + 1 - end - end - - local l = #self.elms - for i = 1, skipped do self.elms[l - i + 1] = nil end -end - --- --- S t r e n g t h --- - --- --- Strengths are used to measure the relative importance of constraints. --- New strengths may be inserted in the strength hierarchy without --- disrupting current constraints. Strengths cannot be created outside --- this class, so pointer comparison can be used for value comparison. --- - -local Strength = class() - -function Strength:constructor(strengthValue, name) - self.strengthValue = strengthValue - self.name = name -end - -function Strength.stronger (s1, s2) - return s1.strengthValue < s2.strengthValue -end - -function Strength.weaker (s1, s2) - return s1.strengthValue > s2.strengthValue -end - -function Strength.weakestOf (s1, s2) - return Strength.weaker(s1, s2) and s1 or s2 -end - -function Strength.strongest (s1, s2) - return Strength.stronger(s1, s2) and s1 or s2 -end - -function Strength:nextWeaker () - local v = self.strengthValue - if v == 0 then return Strength.WEAKEST - elseif v == 1 then return Strength.WEAK_DEFAULT - elseif v == 2 then return Strength.NORMAL - elseif v == 3 then return Strength.STRONG_DEFAULT - elseif v == 4 then return Strength.PREFERRED - elseif v == 5 then return Strength.REQUIRED - end -end - --- Strength constants. -Strength.REQUIRED = Strength.new(0, "required"); -Strength.STONG_PREFERRED = Strength.new(1, "strongPreferred"); -Strength.PREFERRED = Strength.new(2, "preferred"); -Strength.STRONG_DEFAULT = Strength.new(3, "strongDefault"); -Strength.NORMAL = Strength.new(4, "normal"); -Strength.WEAK_DEFAULT = Strength.new(5, "weakDefault"); -Strength.WEAKEST = Strength.new(6, "weakest"); - --- --- C o n s t r a i n t --- - --- --- An abstract class representing a system-maintainable relationship --- (or "constraint") between a set of variables. A constraint supplies --- a strength instance variable; concrete subclasses provide a means --- of storing the constrained variables and other information required --- to represent a constraint. --- - -local Constraint = class () - -function Constraint:constructor(strength) - self.strength = strength -end - --- --- Activate this constraint and attempt to satisfy it. --- -function Constraint:addConstraint () - self:addToGraph() - planner:incrementalAdd(self) -end - --- --- Attempt to find a way to enforce this constraint. If successful, --- record the solution, perhaps modifying the current dataflow --- graph. Answer the constraint that this constraint overrides, if --- there is one, or nil, if there isn't. --- Assume: I am not already satisfied. --- -function Constraint:satisfy (mark) - self:chooseMethod(mark) - if not self:isSatisfied() then - if self.strength == Strength.REQUIRED then - alert("Could not satisfy a required constraint!") - end - return nil - end - self:markInputs(mark) - local out = self:output() - local overridden = out.determinedBy - if overridden ~= nil then overridden:markUnsatisfied() end - out.determinedBy = self - if not planner:addPropagate(self, mark) then alert("Cycle encountered") end - out.mark = mark - return overridden -end - -function Constraint:destroyConstraint () - if self:isSatisfied() - then planner:incrementalRemove(self) - else self:removeFromGraph() - end -end - --- --- Normal constraints are not input constraints. An input constraint --- is one that depends on external state, such as the mouse, the --- keybord, a clock, or some arbitraty piece of imperative code. --- -function Constraint:isInput () - return false -end - - --- --- U n a r y C o n s t r a i n t --- - --- --- Abstract superclass for constraints having a single possible output --- variable. --- - -local UnaryConstraint = class(Constraint) - -function UnaryConstraint:constructor (v, strength) - UnaryConstraint.super.constructor(self, strength) - self.myOutput = v - self.satisfied = false - self:addConstraint() -end - --- --- Adds this constraint to the constraint graph --- -function UnaryConstraint:addToGraph () - self.myOutput:addConstraint(self) - self.satisfied = false -end - --- --- Decides if this constraint can be satisfied and records that --- decision. --- -function UnaryConstraint:chooseMethod (mark) - self.satisfied = (self.myOutput.mark ~= mark) - and Strength.stronger(self.strength, self.myOutput.walkStrength); -end - --- --- Returns true if this constraint is satisfied in the current solution. --- -function UnaryConstraint:isSatisfied () - return self.satisfied; -end - -function UnaryConstraint:markInputs (mark) - -- has no inputs -end - --- --- Returns the current output variable. --- -function UnaryConstraint:output () - return self.myOutput -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this constraint. Assume --- this constraint is satisfied. --- -function UnaryConstraint:recalculate () - self.myOutput.walkStrength = self.strength - self.myOutput.stay = not self:isInput() - if self.myOutput.stay then - self:execute() -- Stay optimization - end -end - --- --- Records that this constraint is unsatisfied --- -function UnaryConstraint:markUnsatisfied () - self.satisfied = false -end - -function UnaryConstraint:inputsKnown () - return true -end - -function UnaryConstraint:removeFromGraph () - if self.myOutput ~= nil then - self.myOutput:removeConstraint(self) - end - self.satisfied = false -end - --- --- S t a y C o n s t r a i n t --- - --- --- Variables that should, with some level of preference, stay the same. --- Planners may exploit the fact that instances, if satisfied, will not --- change their output during plan execution. This is called "stay --- optimization". --- - -local StayConstraint = class(UnaryConstraint) - -function StayConstraint:constructor(v, str) - StayConstraint.super.constructor(self, v, str) -end - -function StayConstraint:execute () - -- Stay constraints do nothing -end - --- --- E d i t C o n s t r a i n t --- - --- --- A unary input constraint used to mark a variable that the client --- wishes to change. --- - -local EditConstraint = class (UnaryConstraint) - -function EditConstraint:constructor(v, str) - EditConstraint.super.constructor(self, v, str) -end - --- --- Edits indicate that a variable is to be changed by imperative code. --- -function EditConstraint:isInput () - return true -end - -function EditConstraint:execute () - -- Edit constraints do nothing -end - --- --- B i n a r y C o n s t r a i n t --- - -local Direction = {} -Direction.NONE = 0 -Direction.FORWARD = 1 -Direction.BACKWARD = -1 - --- --- Abstract superclass for constraints having two possible output --- variables. --- - -local BinaryConstraint = class(Constraint) - -function BinaryConstraint:constructor(var1, var2, strength) - BinaryConstraint.super.constructor(self, strength); - self.v1 = var1 - self.v2 = var2 - self.direction = Direction.NONE - self:addConstraint() -end - - --- --- Decides if this constraint can be satisfied and which way it --- should flow based on the relative strength of the variables related, --- and record that decision. --- -function BinaryConstraint:chooseMethod (mark) - if self.v1.mark == mark then - self.direction = (self.v2.mark ~= mark and Strength.stronger(self.strength, self.v2.walkStrength)) and Direction.FORWARD or Direction.NONE - end - if self.v2.mark == mark then - self.direction = (self.v1.mark ~= mark and Strength.stronger(self.strength, self.v1.walkStrength)) and Direction.BACKWARD or Direction.NONE - end - if Strength.weaker(self.v1.walkStrength, self.v2.walkStrength) then - self.direction = Strength.stronger(self.strength, self.v1.walkStrength) and Direction.BACKWARD or Direction.NONE - else - self.direction = Strength.stronger(self.strength, self.v2.walkStrength) and Direction.FORWARD or Direction.BACKWARD - end -end - --- --- Add this constraint to the constraint graph --- -function BinaryConstraint:addToGraph () - self.v1:addConstraint(self) - self.v2:addConstraint(self) - self.direction = Direction.NONE -end - --- --- Answer true if this constraint is satisfied in the current solution. --- -function BinaryConstraint:isSatisfied () - return self.direction ~= Direction.NONE -end - --- --- Mark the input variable with the given mark. --- -function BinaryConstraint:markInputs (mark) - self:input().mark = mark -end - --- --- Returns the current input variable --- -function BinaryConstraint:input () - return (self.direction == Direction.FORWARD) and self.v1 or self.v2 -end - --- --- Returns the current output variable --- -function BinaryConstraint:output () - return (self.direction == Direction.FORWARD) and self.v2 or self.v1 -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this --- constraint. Assume this constraint is satisfied. --- -function BinaryConstraint:recalculate () - local ihn = self:input() - local out = self:output() - out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength); - out.stay = ihn.stay - if out.stay then self:execute() end -end - --- --- Record the fact that self constraint is unsatisfied. --- -function BinaryConstraint:markUnsatisfied () - self.direction = Direction.NONE -end - -function BinaryConstraint:inputsKnown (mark) - local i = self:input() - return i.mark == mark or i.stay or i.determinedBy == nil -end - -function BinaryConstraint:removeFromGraph () - if (self.v1 ~= nil) then self.v1:removeConstraint(self) end - if (self.v2 ~= nil) then self.v2:removeConstraint(self) end - self.direction = Direction.NONE -end - --- --- S c a l e C o n s t r a i n t --- - --- --- Relates two variables by the linear scaling relationship: "v2 = --- (v1 * scale) + offset". Either v1 or v2 may be changed to maintain --- this relationship but the scale factor and offset are considered --- read-only. --- - -local ScaleConstraint = class (BinaryConstraint) - -function ScaleConstraint:constructor(src, scale, offset, dest, strength) - self.direction = Direction.NONE - self.scale = scale - self.offset = offset - ScaleConstraint.super.constructor(self, src, dest, strength) -end - - --- --- Adds this constraint to the constraint graph. --- -function ScaleConstraint:addToGraph () - ScaleConstraint.super.addToGraph(self) - self.scale:addConstraint(self) - self.offset:addConstraint(self) -end - -function ScaleConstraint:removeFromGraph () - ScaleConstraint.super.removeFromGraph(self) - if (self.scale ~= nil) then self.scale:removeConstraint(self) end - if (self.offset ~= nil) then self.offset:removeConstraint(self) end -end - -function ScaleConstraint:markInputs (mark) - ScaleConstraint.super.markInputs(self, mark); - self.offset.mark = mark - self.scale.mark = mark -end - --- --- Enforce this constraint. Assume that it is satisfied. --- -function ScaleConstraint:execute () - if self.direction == Direction.FORWARD then - self.v2.value = self.v1.value * self.scale.value + self.offset.value - else - self.v1.value = (self.v2.value - self.offset.value) / self.scale.value - end -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this constraint. Assume --- this constraint is satisfied. --- -function ScaleConstraint:recalculate () - local ihn = self:input() - local out = self:output() - out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength) - out.stay = ihn.stay and self.scale.stay and self.offset.stay - if out.stay then self:execute() end -end - --- --- E q u a l i t y C o n s t r a i n t --- - --- --- Constrains two variables to have the same value. --- - -local EqualityConstraint = class (BinaryConstraint) - -function EqualityConstraint:constructor(var1, var2, strength) - EqualityConstraint.super.constructor(self, var1, var2, strength) -end - - --- --- Enforce this constraint. Assume that it is satisfied. --- -function EqualityConstraint:execute () - self:output().value = self:input().value -end - --- --- V a r i a b l e --- - --- --- A constrained variable. In addition to its value, it maintain the --- structure of the constraint graph, the current dataflow graph, and --- various parameters of interest to the DeltaBlue incremental --- constraint solver. --- -local Variable = class () - -function Variable:constructor(name, initialValue) - self.value = initialValue or 0 - self.constraints = OrderedCollection.new() - self.determinedBy = nil - self.mark = 0 - self.walkStrength = Strength.WEAKEST - self.stay = true - self.name = name -end - --- --- Add the given constraint to the set of all constraints that refer --- this variable. --- -function Variable:addConstraint (c) - self.constraints:add(c) -end - --- --- Removes all traces of c from this variable. --- -function Variable:removeConstraint (c) - self.constraints:remove(c) - if self.determinedBy == c then - self.determinedBy = nil - end -end - --- --- P l a n n e r --- - --- --- The DeltaBlue planner --- -local Planner = class() -function Planner:constructor() - self.currentMark = 0 -end - --- --- Attempt to satisfy the given constraint and, if successful, --- incrementally update the dataflow graph. Details: If satifying --- the constraint is successful, it may override a weaker constraint --- on its output. The algorithm attempts to resatisfy that --- constraint using some other method. This process is repeated --- until either a) it reaches a variable that was not previously --- determined by any constraint or b) it reaches a constraint that --- is too weak to be satisfied using any of its methods. The --- variables of constraints that have been processed are marked with --- a unique mark value so that we know where we've been. This allows --- the algorithm to avoid getting into an infinite loop even if the --- constraint graph has an inadvertent cycle. --- -function Planner:incrementalAdd (c) - local mark = self:newMark() - local overridden = c:satisfy(mark) - while overridden ~= nil do - overridden = overridden:satisfy(mark) - end -end - --- --- Entry point for retracting a constraint. Remove the given --- constraint and incrementally update the dataflow graph. --- Details: Retracting the given constraint may allow some currently --- unsatisfiable downstream constraint to be satisfied. We therefore collect --- a list of unsatisfied downstream constraints and attempt to --- satisfy each one in turn. This list is traversed by constraint --- strength, strongest first, as a heuristic for avoiding --- unnecessarily adding and then overriding weak constraints. --- Assume: c is satisfied. --- -function Planner:incrementalRemove (c) - local out = c:output() - c:markUnsatisfied() - c:removeFromGraph() - local unsatisfied = self:removePropagateFrom(out) - local strength = Strength.REQUIRED - repeat - for i = 1, unsatisfied:size() do - local u = unsatisfied:at(i) - if u.strength == strength then - self:incrementalAdd(u) - end - end - strength = strength:nextWeaker() - until strength == Strength.WEAKEST -end - --- --- Select a previously unused mark value. --- -function Planner:newMark () - self.currentMark = self.currentMark + 1 - return self.currentMark -end - --- --- Extract a plan for resatisfaction starting from the given source --- constraints, usually a set of input constraints. This method --- assumes that stay optimization is desired; the plan will contain --- only constraints whose output variables are not stay. Constraints --- that do no computation, such as stay and edit constraints, are --- not included in the plan. --- Details: The outputs of a constraint are marked when it is added --- to the plan under construction. A constraint may be appended to --- the plan when all its input variables are known. A variable is --- known if either a) the variable is marked (indicating that has --- been computed by a constraint appearing earlier in the plan), b) --- the variable is 'stay' (i.e. it is a constant at plan execution --- time), or c) the variable is not determined by any --- constraint. The last provision is for past states of history --- variables, which are not stay but which are also not computed by --- any constraint. --- Assume: sources are all satisfied. --- -local Plan -- FORWARD DECLARATION -function Planner:makePlan (sources) - local mark = self:newMark() - local plan = Plan.new() - local todo = sources - while todo:size() > 0 do - local c = todo:removeFirst() - if c:output().mark ~= mark and c:inputsKnown(mark) then - plan:addConstraint(c) - c:output().mark = mark - self:addConstraintsConsumingTo(c:output(), todo) - end - end - return plan -end - --- --- Extract a plan for resatisfying starting from the output of the --- given constraints, usually a set of input constraints. --- -function Planner:extractPlanFromConstraints (constraints) - local sources = OrderedCollection.new() - for i = 1, constraints:size() do - local c = constraints:at(i) - if c:isInput() and c:isSatisfied() then - -- not in plan already and eligible for inclusion - sources:add(c) - end - end - return self:makePlan(sources) -end - --- --- Recompute the walkabout strengths and stay flags of all variables --- downstream of the given constraint and recompute the actual --- values of all variables whose stay flag is true. If a cycle is --- detected, remove the given constraint and answer --- false. Otherwise, answer true. --- Details: Cycles are detected when a marked variable is --- encountered downstream of the given constraint. The sender is --- assumed to have marked the inputs of the given constraint with --- the given mark. Thus, encountering a marked node downstream of --- the output constraint means that there is a path from the --- constraint's output to one of its inputs. --- -function Planner:addPropagate (c, mark) - local todo = OrderedCollection.new() - todo:add(c) - while todo:size() > 0 do - local d = todo:removeFirst() - if d:output().mark == mark then - self:incrementalRemove(c) - return false - end - d:recalculate() - self:addConstraintsConsumingTo(d:output(), todo) - end - return true -end - - --- --- Update the walkabout strengths and stay flags of all variables --- downstream of the given constraint. Answer a collection of --- unsatisfied constraints sorted in order of decreasing strength. --- -function Planner:removePropagateFrom (out) - out.determinedBy = nil - out.walkStrength = Strength.WEAKEST - out.stay = true - local unsatisfied = OrderedCollection.new() - local todo = OrderedCollection.new() - todo:add(out) - while todo:size() > 0 do - local v = todo:removeFirst() - for i = 1, v.constraints:size() do - local c = v.constraints:at(i) - if not c:isSatisfied() then unsatisfied:add(c) end - end - local determining = v.determinedBy - for i = 1, v.constraints:size() do - local next = v.constraints:at(i); - if next ~= determining and next:isSatisfied() then - next:recalculate() - todo:add(next:output()) - end - end - end - return unsatisfied -end - -function Planner:addConstraintsConsumingTo (v, coll) - local determining = v.determinedBy - local cc = v.constraints - for i = 1, cc:size() do - local c = cc:at(i) - if c ~= determining and c:isSatisfied() then - coll:add(c) - end - end -end - --- --- P l a n --- - --- --- A Plan is an ordered list of constraints to be executed in sequence --- to resatisfy all currently satisfiable constraints in the face of --- one or more changing inputs. --- -Plan = class() -function Plan:constructor() - self.v = OrderedCollection.new() -end - -function Plan:addConstraint (c) - self.v:add(c) -end - -function Plan:size () - return self.v:size() -end - -function Plan:constraintAt (index) - return self.v:at(index) -end - -function Plan:execute () - for i = 1, self:size() do - local c = self:constraintAt(i) - c:execute() - end -end - --- --- M a i n --- - --- --- This is the standard DeltaBlue benchmark. A long chain of equality --- constraints is constructed with a stay constraint on one end. An --- edit constraint is then added to the opposite end and the time is --- measured for adding and removing this constraint, and extracting --- and executing a constraint satisfaction plan. There are two cases. --- In case 1, the added constraint is stronger than the stay --- constraint and values must propagate down the entire length of the --- chain. In case 2, the added constraint is weaker than the stay --- constraint so it cannot be accomodated. The cost in this case is, --- of course, very low. Typical situations lie somewhere between these --- two extremes. --- -local function chainTest(n) - planner = Planner.new() - local prev = nil - local first = nil - local last = nil - - -- Build chain of n equality constraints - for i = 0, n do - local name = "v" .. i; - local v = Variable.new(name) - if prev ~= nil then EqualityConstraint.new(prev, v, Strength.REQUIRED) end - if i == 0 then first = v end - if i == n then last = v end - prev = v - end - - StayConstraint.new(last, Strength.STRONG_DEFAULT) - local edit = EditConstraint.new(first, Strength.PREFERRED) - local edits = OrderedCollection.new() - edits:add(edit) - local plan = planner:extractPlanFromConstraints(edits) - for i = 0, 99 do - first.value = i - plan:execute() - if last.value ~= i then - alert("Chain test failed.") - end - end -end - -local function change(v, newValue) - local edit = EditConstraint.new(v, Strength.PREFERRED) - local edits = OrderedCollection.new() - edits:add(edit) - local plan = planner:extractPlanFromConstraints(edits) - for i = 1, 10 do - v.value = newValue - plan:execute() - end - edit:destroyConstraint() -end - --- --- This test constructs a two sets of variables related to each --- other by a simple linear transformation (scale and offset). The --- time is measured to change a variable on either side of the --- mapping and to change the scale and offset factors. --- -local function projectionTest(n) - planner = Planner.new(); - local scale = Variable.new("scale", 10); - local offset = Variable.new("offset", 1000); - local src = nil - local dst = nil; - - local dests = OrderedCollection.new(); - for i = 0, n - 1 do - src = Variable.new("src" .. i, i); - dst = Variable.new("dst" .. i, i); - dests:add(dst); - StayConstraint.new(src, Strength.NORMAL); - ScaleConstraint.new(src, scale, offset, dst, Strength.REQUIRED); - end - - change(src, 17) - if dst.value ~= 1170 then alert("Projection 1 failed") end - change(dst, 1050) - if src.value ~= 5 then alert("Projection 2 failed") end - change(scale, 5) - for i = 0, n - 2 do - if dests:at(i + 1).value ~= i * 5 + 1000 then - alert("Projection 3 failed") - end - end - change(offset, 2000) - for i = 0, n - 2 do - if dests:at(i + 1).value ~= i * 5 + 2000 then - alert("Projection 4 failed") - end - end -end - -function test() - local t0 = os.clock() - chainTest(1000); - projectionTest(1000); - local t1 = os.clock() - return t1-t0 -end - -bench.runCode(test, "deltablue") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 9cd642cfd..07910a0ac 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -23,19 +23,17 @@ static std::optional nullCallback(std::string tag, std::op return std::nullopt; } -struct ACFixture : Fixture +template +struct ACFixtureImpl : BaseType { AutocompleteResult autocomplete(unsigned row, unsigned column) { - return Luau::autocomplete(frontend, "MainModule", Position{row, column}, nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); } AutocompleteResult autocomplete(char marker) { - auto i = markerPosition.find(marker); - LUAU_ASSERT(i != markerPosition.end()); - const Position& pos = i->second; - return Luau::autocomplete(frontend, "MainModule", pos, nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), nullCallback); } CheckResult check(const std::string& source) @@ -45,16 +43,18 @@ struct ACFixture : Fixture filteredSource.reserve(source.size()); Position curPos(0, 0); + char prevChar{}; for (char c : source) { - if (c == '@' && !filteredSource.empty()) + if (prevChar == '@') + { + LUAU_ASSERT("Illegal marker character" && c >= '0' && c <= '9'); + LUAU_ASSERT("Duplicate marker found" && markerPosition.count(c) == 0); + markerPosition.insert(std::pair{c, curPos}); + } + else if (c == '@') { - char prevChar = filteredSource.back(); - filteredSource.pop_back(); - curPos.column--; // Adjust column position since we removed a character from the output - LUAU_ASSERT("Illegal marker character" && prevChar >= '0' && prevChar <= '9'); - LUAU_ASSERT("Duplicate marker found" && markerPosition.count(prevChar) == 0); - markerPosition.insert(std::pair{prevChar, curPos}); + // skip the '@' character } else { @@ -69,22 +69,39 @@ struct ACFixture : Fixture curPos.column++; } } + prevChar = c; } + LUAU_ASSERT("Digit expected after @ symbol" && prevChar != '@'); return Fixture::check(filteredSource); } + const Position& getPosition(char marker) const + { + auto i = markerPosition.find(marker); + LUAU_ASSERT(i != markerPosition.end()); + return i->second; + } + // Maps a marker character (0-9 inclusive) to a position in the source code. std::map markerPosition; }; +struct ACFixture : ACFixtureImpl +{ +}; + +struct UnfrozenACFixture : ACFixtureImpl +{ +}; + TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") { - check(" "); + check(" @1"); - auto ac = autocomplete(0, 1); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.empty()); CHECK(ac.entryMap.count("table")); @@ -93,26 +110,26 @@ TEST_CASE_FIXTURE(ACFixture, "empty_program") TEST_CASE_FIXTURE(ACFixture, "local_initializer") { - check("local a = "); + check("local a = @1"); - auto ac = autocomplete(0, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); CHECK(ac.entryMap.count("math")); } TEST_CASE_FIXTURE(ACFixture, "leave_numbers_alone") { - check("local a = 3.1"); + check("local a = 3.@11"); - auto ac = autocomplete(0, 12); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "user_defined_globals") { - check("local myLocal = 4; "); + check("local myLocal = 4; @1"); - auto ac = autocomplete(0, 19); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); CHECK(ac.entryMap.count("table")); @@ -124,20 +141,20 @@ TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") check(R"( local myLocal = 4 function abc() - local myInnerLocal = 1 - +@1 local myInnerLocal = 1 +@2 end - )"); +@3 )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); CHECK(!ac.entryMap.count("myInnerLocal")); - ac = autocomplete(4, 0); + ac = autocomplete('2'); CHECK(ac.entryMap.count("myLocal")); CHECK(ac.entryMap.count("myInnerLocal")); - ac = autocomplete(6, 0); + ac = autocomplete('3'); CHECK(ac.entryMap.count("myLocal")); CHECK(!ac.entryMap.count("myInnerLocal")); } @@ -146,10 +163,10 @@ TEST_CASE_FIXTURE(ACFixture, "recursive_function") { check(R"( function foo() - end +@1 end )"); - auto ac = autocomplete(2, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("foo")); } @@ -158,11 +175,11 @@ TEST_CASE_FIXTURE(ACFixture, "nested_recursive_function") check(R"( local function outer() local function inner() - end +@1 end end )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("inner")); CHECK(ac.entryMap.count("outer")); } @@ -171,11 +188,11 @@ TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") { check(R"( local function abc() - +@1 end )"); - auto ac = autocomplete(2, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); CHECK(ac.entryMap.count("table")); @@ -183,11 +200,11 @@ TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") check(R"( local abc = function() - +@1 end )"); - ac = autocomplete(2, 0); + ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); // FIXME: This is actually incorrect! CHECK(ac.entryMap.count("table")); @@ -202,9 +219,9 @@ TEST_CASE_FIXTURE(ACFixture, "global_functions_are_not_scoped_lexically") end end - )"); +@1 )"); - auto ac = autocomplete(6, 0); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.empty()); CHECK(ac.entryMap.count("abc")); @@ -220,9 +237,9 @@ TEST_CASE_FIXTURE(ACFixture, "local_functions_fall_out_of_scope") end end - )"); +@1 )"); - auto ac = autocomplete(6, 0); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(!ac.entryMap.count("abc")); @@ -233,10 +250,10 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") check(R"( function abc(test) - end +@1 end )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("test")); } @@ -244,11 +261,10 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") TEST_CASE_FIXTURE(ACFixture, "get_member_completions") { check(R"( - local a = table. -- Line 1 - -- | Column 23 + local a = table.@1 )"); - auto ac = autocomplete(1, 24); + auto ac = autocomplete('1'); CHECK_EQ(16, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); @@ -260,10 +276,10 @@ TEST_CASE_FIXTURE(ACFixture, "nested_member_completions") { check(R"( local tbl = { abc = { def = 1234, egh = false } } - tbl.abc. + tbl.abc. @1 )"); - auto ac = autocomplete(2, 17); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("def")); CHECK(ac.entryMap.count("egh")); @@ -274,10 +290,10 @@ TEST_CASE_FIXTURE(ACFixture, "unsealed_table") check(R"( local tbl = {} tbl.prop = 5 - tbl. + tbl.@1 )"); - auto ac = autocomplete(3, 12); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("prop")); } @@ -288,10 +304,10 @@ TEST_CASE_FIXTURE(ACFixture, "unsealed_table_2") local tbl = {} local inner = { prop = 5 } tbl.inner = inner - tbl.inner. + tbl.inner. @1 )"); - auto ac = autocomplete(4, 19); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("prop")); } @@ -302,10 +318,10 @@ TEST_CASE_FIXTURE(ACFixture, "cyclic_table") local abc = {} local def = { abc = abc } abc.def = def - abc.def. + abc.def. @1 )"); - auto ac = autocomplete(4, 17); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); } @@ -315,11 +331,11 @@ TEST_CASE_FIXTURE(ACFixture, "table_union") type t1 = { a1 : string, b2 : number } type t2 = { b2 : string, c3 : string } function func(abc : t1 | t2) - abc. + abc. @1 end )"); - auto ac = autocomplete(4, 18); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("b2")); } @@ -330,11 +346,11 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") type t1 = { a1 : string, b2 : number } type t2 = { b2 : string, c3 : string } function func(abc : t1 & t2) - abc. + abc. @1 end )"); - auto ac = autocomplete(4, 18); + auto ac = autocomplete('1'); CHECK_EQ(3, ac.entryMap.size()); CHECK(ac.entryMap.count("a1")); CHECK(ac.entryMap.count("b2")); @@ -344,20 +360,19 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") TEST_CASE_FIXTURE(ACFixture, "get_string_completions") { check(R"( - local a = ("foo"): -- Line 1 - -- | Column 26 + local a = ("foo"):@1 )"); - auto ac = autocomplete(1, 26); + auto ac = autocomplete('1'); CHECK_EQ(17, ac.entryMap.size()); } TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_new_statement") { - check(""); + check("@1"); - auto ac = autocomplete(0, 0); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -366,12 +381,12 @@ TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_new_statement") TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_the_very_start_of_the_script") { - check(R"( + check(R"(@1 function aaa() end )"); - auto ac = autocomplete(0, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); } @@ -382,11 +397,11 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") local game = { GetService=function(s) return 'hello' end } function a() - game: + game: @1 end )"); - auto ac = autocomplete(4, 19); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -396,10 +411,10 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") TEST_CASE_FIXTURE(ACFixture, "method_call_inside_if_conditional") { check(R"( - if table: + if table: @1 )"); - auto ac = autocomplete(1, 19); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(ac.entryMap.count("concat")); @@ -411,12 +426,12 @@ TEST_CASE_FIXTURE(ACFixture, "statement_between_two_statements") check(R"( function getmyscripts() end - g + g@1 getmyscripts() )"); - auto ac = autocomplete(3, 9); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -431,11 +446,11 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") function B() local A = {two=2} - A + A @1 end )"); - auto ac = autocomplete(6, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("A")); @@ -448,12 +463,12 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") TEST_CASE_FIXTURE(ACFixture, "recommend_statement_starting_keywords") { - check(""); - auto ac = autocomplete(0, 0); + check("@1"); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("local")); - check("local i = "); - auto ac2 = autocomplete(0, 10); + check("local i = @1"); + auto ac2 = autocomplete('1'); CHECK(!ac2.entryMap.count("local")); } @@ -464,9 +479,9 @@ TEST_CASE_FIXTURE(ACFixture, "do_not_overwrite_context_sensitive_kws") end - )"); +@1 )"); - auto ac = autocomplete(5, 0); + auto ac = autocomplete('1'); AutocompleteEntry entry = ac.entryMap["continue"]; CHECK(entry.kind == AutocompleteEntryKind::Binding); @@ -480,11 +495,11 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") function foo:bar() end --[[ - foo: + foo:@1 ]] )"); - auto ac = autocomplete(6, 16); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -492,10 +507,10 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_the_end_of_a_comment") { check(R"( - --!strict + --!strict@1 )"); - auto ac = autocomplete(1, 17); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -505,10 +520,10 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; check(R"( - --[[ + --[[ @1 )"); - auto ac = autocomplete(1, 13); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -517,129 +532,129 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co { ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check("--[["); + check("--[[@1"); - auto ac = autocomplete(0, 4); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") { check(R"( - for x = + for x @1= )"); - auto ac1 = autocomplete(1, 14); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - for x = 1 + for x =@1 1 )"); - auto ac2 = autocomplete(1, 15); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("do"), 0); CHECK_EQ(ac2.entryMap.count("end"), 0); check(R"( - for x = 1, 2 + for x = 1,@1 2 )"); - auto ac3 = autocomplete(1, 18); + auto ac3 = autocomplete('1'); CHECK_EQ(1, ac3.entryMap.size()); CHECK_EQ(ac3.entryMap.count("do"), 1); check(R"( - for x = 1, 2, + for x = 1, @12, )"); - auto ac4 = autocomplete(1, 19); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.count("do"), 0); CHECK_EQ(ac4.entryMap.count("end"), 0); check(R"( - for x = 1, 2, 5 + for x = 1, 2, @15 )"); - auto ac5 = autocomplete(1, 22); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.count("do"), 1); CHECK_EQ(ac5.entryMap.count("end"), 0); check(R"( - for x = 1, 2, 5 f + for x = 1, 2, 5 f@1 )"); - auto ac6 = autocomplete(1, 25); + auto ac6 = autocomplete('1'); CHECK_EQ(ac6.entryMap.size(), 1); CHECK_EQ(ac6.entryMap.count("do"), 1); check(R"( - for x = 1, 2, 5 do + for x = 1, 2, 5 do @1 )"); - auto ac7 = autocomplete(1, 32); + auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("end"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") { check(R"( - for + for @1 )"); - auto ac1 = autocomplete(1, 12); + auto ac1 = autocomplete('1'); CHECK_EQ(0, ac1.entryMap.size()); check(R"( - for x + for x@1 @2 )"); - auto ac2 = autocomplete(1, 13); + auto ac2 = autocomplete('1'); CHECK_EQ(0, ac2.entryMap.size()); - auto ac2a = autocomplete(1, 14); + auto ac2a = autocomplete('2'); CHECK_EQ(1, ac2a.entryMap.size()); CHECK_EQ(1, ac2a.entryMap.count("in")); check(R"( - for x in y + for x in y@1 )"); - auto ac3 = autocomplete(1, 18); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("table"), 1); CHECK_EQ(ac3.entryMap.count("do"), 0); check(R"( - for x in y + for x in y @1 )"); - auto ac4 = autocomplete(1, 19); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.size(), 1); CHECK_EQ(ac4.entryMap.count("do"), 1); check(R"( - for x in f f + for x in f f@1 )"); - auto ac5 = autocomplete(1, 20); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.size(), 1); CHECK_EQ(ac5.entryMap.count("do"), 1); check(R"( - for x in y do + for x in y do @1 )"); - auto ac6 = autocomplete(1, 23); + auto ac6 = autocomplete('1'); CHECK_EQ(ac6.entryMap.count("in"), 0); CHECK_EQ(ac6.entryMap.count("table"), 1); CHECK_EQ(ac6.entryMap.count("end"), 1); CHECK_EQ(ac6.entryMap.count("function"), 1); check(R"( - for x in y do e + for x in y do e@1 )"); - auto ac7 = autocomplete(1, 23); + auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("in"), 0); CHECK_EQ(ac7.entryMap.count("table"), 1); CHECK_EQ(ac7.entryMap.count("end"), 1); @@ -649,33 +664,33 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") { check(R"( - while + while@1 )"); - auto ac1 = autocomplete(1, 13); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - while true + while true @1 )"); - auto ac2 = autocomplete(1, 19); + auto ac2 = autocomplete('1'); CHECK_EQ(1, ac2.entryMap.size()); CHECK_EQ(ac2.entryMap.count("do"), 1); check(R"( - while true do + while true do @1 )"); - auto ac3 = autocomplete(1, 23); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("end"), 1); check(R"( - while true d + while true d@1 )"); - auto ac4 = autocomplete(1, 20); + auto ac4 = autocomplete('1'); CHECK_EQ(1, ac4.entryMap.size()); CHECK_EQ(ac4.entryMap.count("do"), 1); } @@ -683,10 +698,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") { check(R"( - if + if @1 )"); - auto ac1 = autocomplete(1, 13); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("then"), 0); CHECK_EQ(ac1.entryMap.count("function"), 1); // FIXME: This is kind of dumb. It is technically syntactically valid but you can never do anything interesting with this. @@ -696,10 +711,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - if x + if x @1 )"); - auto ac2 = autocomplete(1, 14); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("then"), 1); CHECK_EQ(ac2.entryMap.count("function"), 0); CHECK_EQ(ac2.entryMap.count("else"), 0); @@ -707,20 +722,20 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac2.entryMap.count("end"), 0); check(R"( - if x t + if x t@1 )"); - auto ac3 = autocomplete(1, 14); + auto ac3 = autocomplete('1'); CHECK_EQ(1, ac3.entryMap.size()); CHECK_EQ(ac3.entryMap.count("then"), 1); check(R"( if x then - +@1 end )"); - auto ac4 = autocomplete(2, 0); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.count("then"), 0); CHECK_EQ(ac4.entryMap.count("else"), 1); CHECK_EQ(ac4.entryMap.count("function"), 1); @@ -729,11 +744,11 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") check(R"( if x then - t + t@1 end )"); - auto ac4a = autocomplete(2, 13); + auto ac4a = autocomplete('1'); CHECK_EQ(ac4a.entryMap.count("then"), 0); CHECK_EQ(ac4a.entryMap.count("table"), 1); CHECK_EQ(ac4a.entryMap.count("else"), 1); @@ -741,12 +756,12 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") check(R"( if x then - +@1 elseif x then end )"); - auto ac5 = autocomplete(2, 0); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.count("then"), 0); CHECK_EQ(ac5.entryMap.count("function"), 1); CHECK_EQ(ac5.entryMap.count("else"), 0); @@ -757,10 +772,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_in_repeat") { check(R"( - repeat + repeat @1 )"); - auto ac = autocomplete(1, 16); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); CHECK_EQ(ac.entryMap.count("until"), 1); } @@ -769,48 +784,48 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_expression") { check(R"( repeat - until + until @1 )"); - auto ac = autocomplete(2, 16); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); } TEST_CASE_FIXTURE(ACFixture, "local_names") { check(R"( - local ab + local ab@1 )"); - auto ac1 = autocomplete(1, 16); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.size(), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); check(R"( - local ab, cd + local ab, cd@1 )"); - auto ac2 = autocomplete(1, 20); + auto ac2 = autocomplete('1'); CHECK(ac2.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_fn_exprs") { check(R"( - local function f() + local function f() @1 )"); - auto ac = autocomplete(1, 28); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("end"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_lambda") { check(R"( - local a = function() local bar = foo en + local a = function() local bar = foo en@1 )"); - auto ac = autocomplete(1, 47); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("end"), 1); } @@ -818,10 +833,10 @@ TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") { check(R"( repeat - for x + for x @1 )"); - auto ac1 = autocomplete(2, 18); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("in"), 1); CHECK_EQ(ac1.entryMap.count("until"), 0); } @@ -829,112 +844,112 @@ TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_repeat_middle_keyword") { check(R"( - repeat + repeat @1 )"); - auto ac1 = autocomplete(1, 15); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); CHECK_EQ(ac1.entryMap.count("until"), 1); check(R"( - repeat f f + repeat f f@1 )"); - auto ac2 = autocomplete(1, 18); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("function"), 1); CHECK_EQ(ac2.entryMap.count("until"), 1); check(R"( repeat - u + u@1 until )"); - auto ac3 = autocomplete(2, 13); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("until"), 0); } TEST_CASE_FIXTURE(ACFixture, "local_function") { check(R"( - local f + local f@1 )"); - auto ac1 = autocomplete(1, 15); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.size(), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); check(R"( - local f, cd + local f@1, cd )"); - auto ac2 = autocomplete(1, 15); + auto ac2 = autocomplete('1'); CHECK(ac2.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "local_function") { check(R"( - local function + local function @1 )"); - auto ac = autocomplete(1, 23); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( - local function s + local function @1s@2 )"); - ac = autocomplete(1, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); - ac = autocomplete(1, 24); + ac = autocomplete('2'); CHECK(ac.entryMap.empty()); check(R"( - local function () + local function @1()@2 )"); - ac = autocomplete(1, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); - ac = autocomplete(1, 25); + ac = autocomplete('2'); CHECK(ac.entryMap.count("end")); check(R"( - local function something + local function something@1 )"); - ac = autocomplete(1, 32); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( local tbl = {} - function tbl.something() end + function tbl.something@1() end )"); - ac = autocomplete(2, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "local_function_params") { check(R"( - local function abc(def) + local function @1a@2bc(@3d@4ef)@5 @6 )"); - CHECK(autocomplete(1, 23).entryMap.empty()); - CHECK(autocomplete(1, 24).entryMap.empty()); - CHECK(autocomplete(1, 27).entryMap.empty()); - CHECK(autocomplete(1, 28).entryMap.empty()); - CHECK(!autocomplete(1, 31).entryMap.empty()); + CHECK(autocomplete('1').entryMap.empty()); + CHECK(autocomplete('2').entryMap.empty()); + CHECK(autocomplete('3').entryMap.empty()); + CHECK(autocomplete('4').entryMap.empty()); + CHECK(!autocomplete('5').entryMap.empty()); - CHECK(!autocomplete(1, 32).entryMap.empty()); + CHECK(!autocomplete('6').entryMap.empty()); check(R"( local function abc(def) - end +@1 end )"); for (unsigned int i = 23; i < 31; ++i) @@ -943,16 +958,16 @@ TEST_CASE_FIXTURE(ACFixture, "local_function_params") } CHECK(!autocomplete(1, 32).entryMap.empty()); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("abc"), 1); CHECK_EQ(ac2.entryMap.count("def"), 1); check(R"( - local function abc(def, ghi) + local function abc(def, ghi@1) end )"); - auto ac3 = autocomplete(1, 35); + auto ac3 = autocomplete('1'); CHECK(ac3.entryMap.empty()); } @@ -981,48 +996,48 @@ TEST_CASE_FIXTURE(ACFixture, "global_function_params") check(R"( function abc(def) - +@1 end )"); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("abc"), 1); CHECK_EQ(ac2.entryMap.count("def"), 1); check(R"( - function abc(def, ghi) + function abc(def, ghi@1) end )"); - auto ac3 = autocomplete(1, 29); + auto ac3 = autocomplete('1'); CHECK(ac3.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "arguments_to_global_lambda") { check(R"( - abc = function(def, ghi) + abc = function(def, ghi@1) end )"); - auto ac = autocomplete(1, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "function_expr_params") { check(R"( - abc = function(def) + abc = function(def) @1 )"); for (unsigned int i = 20; i < 27; ++i) { CHECK(autocomplete(1, i).entryMap.empty()); } - CHECK(!autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete('1').entryMap.empty()); check(R"( - abc = function(def) + abc = function(def) @1 end )"); @@ -1030,25 +1045,25 @@ TEST_CASE_FIXTURE(ACFixture, "function_expr_params") { CHECK(autocomplete(1, i).entryMap.empty()); } - CHECK(!autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete('1').entryMap.empty()); check(R"( abc = function(def) - +@1 end )"); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("def"), 1); } TEST_CASE_FIXTURE(ACFixture, "local_initializer") { check(R"( - local a = t + local a = t@1 )"); - auto ac = autocomplete(1, 19); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); CHECK_EQ(ac.entryMap.count("true"), 1); } @@ -1056,20 +1071,20 @@ TEST_CASE_FIXTURE(ACFixture, "local_initializer") TEST_CASE_FIXTURE(ACFixture, "local_initializer_2") { check(R"( - local a= + local a=@1 )"); - auto ac = autocomplete(1, 16); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); } TEST_CASE_FIXTURE(ACFixture, "get_member_completions") { check(R"( - local a = 12.3 + local a = 12.@13 )"); - auto ac = autocomplete(1, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } @@ -1083,21 +1098,21 @@ TEST_CASE_FIXTURE(ACFixture, "sometimes_the_metatable_is_an_error") return setmetatable({x=6}, X) -- oops! end local t = T.new() - t. + t. @1 )"); - autocomplete(8, 12); + autocomplete('1'); // Don't crash! } TEST_CASE_FIXTURE(ACFixture, "local_types_builtin") { check(R"( -local a: n +local a: n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1108,23 +1123,23 @@ TEST_CASE_FIXTURE(ACFixture, "private_types") check(R"( do type num = number - local a: nu - local b: num + local a: n@1u + local b: nu@2m end -local a: nu +local a: nu@3 )"); - auto ac = autocomplete(3, 14); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(4, 15); + ac = autocomplete('2'); CHECK(ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(6, 11); + ac = autocomplete('3'); CHECK(!ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); @@ -1136,11 +1151,11 @@ TEST_CASE_FIXTURE(ACFixture, "type_scoping_easy") type Table = { a: number, b: number } do type Table = { x: string, y: string } - local a: T + local a: T@1 end )"); - auto ac = autocomplete(4, 14); + auto ac = autocomplete('1'); REQUIRE(ac.entryMap.count("Table")); REQUIRE(ac.entryMap["Table"].type); @@ -1198,11 +1213,11 @@ local a: aaa. TEST_CASE_FIXTURE(ACFixture, "argument_types") { check(R"( -local function f(a: n +local function f(a: n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1211,11 +1226,11 @@ local b: string = "don't trip" TEST_CASE_FIXTURE(ACFixture, "return_types") { check(R"( -local function f(a: number): n +local function f(a: number): n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 30); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1225,10 +1240,10 @@ TEST_CASE_FIXTURE(ACFixture, "as_types") { check(R"( local a: any = 5 -local b: number = (a :: n +local b: number = (a :: n@1 )"); - auto ac = autocomplete(2, 25); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1237,34 +1252,34 @@ local b: number = (a :: n TEST_CASE_FIXTURE(ACFixture, "function_type_types") { check(R"( -local a: (n -local b: (number, (n -local c: (number, (number) -> n -local d: (number, (number) -> (number, n -local e: (n: n +local a: (n@1 +local b: (number, (n@2 +local c: (number, (number) -> n@3 +local d: (number, (number) -> (number, n@4 +local e: (n: n@5 )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(2, 20); + ac = autocomplete('2'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(3, 31); + ac = autocomplete('3'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(4, 40); + ac = autocomplete('4'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(5, 14); + ac = autocomplete('5'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1276,11 +1291,11 @@ TEST_CASE_FIXTURE(ACFixture, "generic_types") ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); check(R"( -function f(a: T +function f(a: T@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 25); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("Tee")); } @@ -1293,10 +1308,10 @@ local function target(a: number, b: string) return a + #b end local one = 4 local two = "hello" -return target(o +return target(o@1 )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1307,10 +1322,10 @@ local function target(a: number, b: string) return a + #b end local one = 4 local two = "hello" -return target(one, t +return target(one, t@1 )"); - ac = autocomplete(5, 20); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1321,10 +1336,10 @@ return target(one, t local function target(a: number, b: string) return a + #b end local a = { one = 4, two = "hello" } -return target(a. +return target(a.@1 )"); - ac = autocomplete(4, 16); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1334,10 +1349,10 @@ return target(a. local function target(a: number, b: string) return a + #b end local a = { one = 4, two = "hello" } -return target(a.one, a. +return target(a.one, a.@1 )"); - ac = autocomplete(4, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1348,10 +1363,10 @@ return target(a.one, a. local function target(a: string?) return #b end local a = { one = 4, two = "hello" } -return target(a. +return target(a.@1 )"); - ac = autocomplete(4, 16); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1363,10 +1378,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_in_table") check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } -local b: Foo = { a = a. +local b: Foo = { a = a.@1 )"); - auto ac = autocomplete(3, 23); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1375,10 +1390,10 @@ local b: Foo = { a = a. check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } -local b: Foo = { b = a. +local b: Foo = { b = a.@1 )"); - ac = autocomplete(3, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1392,10 +1407,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1406,10 +1421,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end local function bar2(a: string) return a .. 'x' end -return target(bar1, b +return target(bar1, b@1 )"); - ac = autocomplete(5, 21); + ac = autocomplete('1'); CHECK(ac.entryMap.count("bar2")); CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1420,10 +1435,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number): (...number) return -a, a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - ac = autocomplete(5, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1433,69 +1448,69 @@ return target(b TEST_CASE_FIXTURE(ACFixture, "type_correct_local_type_suggestion") { check(R"( -local b: s = "str" +local b: s@1 = "str" )"); - auto ac = autocomplete(1, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f() return "str" end -local b: s = f() +local b: s@1 = f() )"); - ac = autocomplete(2, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: s, c: n = "str", 2 +local b: s@1, c: n@2 = "str", 2 )"); - ac = autocomplete(1, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(1, 16); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f() return 1, "str", 3 end -local a: b, b: n, c: s, d: n = false, f() +local a: b@1, b: n@2, c: s@3, d: n@4 = false, f() )"); - ac = autocomplete(2, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("boolean")); CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 16); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 22); + ac = autocomplete('3'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 28); + ac = autocomplete('4'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f(): ...number return 1, 2, 3 end -local a: boolean, b: n = false, f() +local a: boolean, b: n@1 = false, f() )"); - ac = autocomplete(2, 22); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1504,46 +1519,46 @@ local a: boolean, b: n = false, f() TEST_CASE_FIXTURE(ACFixture, "type_correct_function_type_suggestion") { check(R"( -local b: (n) -> number = function(a: number, b: string) return a + #b end +local b: (n@1) -> number = function(a: number, b: string) return a + #b end )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, s = function(a: number, b: string) return a + #b end +local b: (number, s@1 = function(a: number, b: string) return a + #b end )"); - ac = autocomplete(1, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, string) -> b = function(a: number, b: string): boolean return a + #b == 0 end +local b: (number, string) -> b@1 = function(a: number, b: string): boolean return a + #b == 0 end )"); - ac = autocomplete(1, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("boolean")); CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, ...s) = function(a: number, ...: string) return a end +local b: (number, ...s@1) = function(a: number, ...: string) return a end )"); - ac = autocomplete(1, 22); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" end +local b: (number) -> ...s@1 = function(a: number): ...string return "a", "b", "c" end )"); - ac = autocomplete(1, 25); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1552,24 +1567,24 @@ local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" TEST_CASE_FIXTURE(ACFixture, "type_correct_full_type_suggestion") { check(R"( -local b: = "str" +local b:@1 @2= "str" )"); - auto ac = autocomplete(1, 8); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(1, 9); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: = function(a: number) return -a end +local b: @1= function(a: number) return -a end )"); - ac = autocomplete(1, 9); + ac = autocomplete('1'); CHECK(ac.entryMap.count("(number) -> number")); CHECK(ac.entryMap["(number) -> number"].typeCorrect == TypeCorrectKind::Correct); @@ -1580,12 +1595,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion") check(R"( local function target(a: number, b: string) return a + #b end -local function d(a: n, b) +local function d(a: n@1, b) return target(a, b) end )"); - auto ac = autocomplete(3, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1593,12 +1608,12 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a, b: s) +local function d(a, b: s@1) return target(a, b) end )"); - ac = autocomplete(3, 24); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1606,17 +1621,17 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a: , b) +local function d(a:@1 @2, b) return target(a, b) end )"); - ac = autocomplete(3, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 20); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1624,17 +1639,17 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a, b: ): number +local function d(a, b: @1)@2: number return target(a, b) end )"); - ac = autocomplete(3, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 24); + ac = autocomplete('2'); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::None); } @@ -1644,10 +1659,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion") check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: +local x = target(function(a: @1 )"); - auto ac = autocomplete(3, 29); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1655,10 +1670,10 @@ local x = target(function(a: check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: n +local x = target(function(a: n@1 )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1666,17 +1681,17 @@ local x = target(function(a: n check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: n, b: ) +local x = target(function(a: n@1, b: @2) return a + #b end) )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 35); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1684,12 +1699,12 @@ end) check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(a: n) +local x = target(function(a: n@1) return a end )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1700,12 +1715,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(...:n) +local x = target(function(...:n@1) return a end )"); - auto ac = autocomplete(3, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1713,12 +1728,12 @@ end check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(a:number, b:number, ...:) +local x = target(function(a:number, b:number, ...:@1) return a + b end )"); - ac = autocomplete(3, 50); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1729,12 +1744,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion") check(R"( local function target(callback: () -> number) return callback() end -local x = target(function(): n +local x = target(function(): n@1 return 1 end )"); - auto ac = autocomplete(3, 30); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1742,12 +1757,12 @@ end check(R"( local function target(callback: () -> (number, number)) return callback() end -local x = target(function(): (number, n +local x = target(function(): (number, n@1 return 1, 2 end )"); - ac = autocomplete(3, 39); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1758,12 +1773,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion" check(R"( local function target(callback: () -> ...number) return callback() end -local x = target(function(): ...n +local x = target(function(): ...n@1 return 1, 2, 3 end )"); - auto ac = autocomplete(3, 33); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1771,12 +1786,12 @@ end check(R"( local function target(callback: () -> ...number) return callback() end -local x = target(function(): (number, number, ...n +local x = target(function(): (number, number, ...n@1 return 1, 2, 3 end )"); - ac = autocomplete(3, 50); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1787,10 +1802,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion_opt check(R"( local function target(callback: nil | (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: +local x = target(function(a: @1 )"); - auto ac = autocomplete(3, 29); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1803,21 +1818,21 @@ local t = {} t.x = 5 function t:target(callback: (a: number, b: string) -> number) return callback(self.x, "hello") end -local x = t:target(function(a: , b: ) end) -local y = t.target(t, function(a: number, b: ) end) +local x = t:target(function(a: @1, b:@2 ) end) +local y = t.target(t, function(a: number, b: @3) end) )"); - auto ac = autocomplete(5, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(5, 35); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(6, 45); + ac = autocomplete('3'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1899,26 +1914,26 @@ TEST_CASE_FIXTURE(ACFixture, "do_not_suggest_synthetic_table_name") { check(R"( local foo = { a = 1, b = 2 } -local bar: = foo +local bar: @1= foo )"); - auto ac = autocomplete(2, 11); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.count("foo")); } -// CLI-45692: Remove UnfrozenFixture here -TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_function_no_parenthesis") +// CLI-45692: Remove UnfrozenACFixture here +TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis") { check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - auto ac = autocomplete(frontend, "MainModule", Position{5, 15}, nullCallback); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); @@ -1930,16 +1945,16 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_sealed_table") { check(R"( local function f(a: { x: number, y: number }) return a.x + a.y end -local fp: = f +local fp: @1= f )"); - auto ac = autocomplete(2, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// CLI-45692: Remove UnfrozenFixture here -TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_keywords") +// CLI-45692: Remove UnfrozenACFixture here +TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords") { check(R"( local function a(x: boolean) end @@ -1951,33 +1966,33 @@ local function e(x: ((number) -> string) & ((boolean) -> number)) end local tru = {} local ni = false -local ac = a(t) -local bc = b(n) -local cc = c(f) -local dc = d(f) -local ec = e(f) +local ac = a(t@1) +local bc = b(n@2) +local cc = c(f@3) +local dc = d(f@4) +local ec = e(f@5) )"); - auto ac = autocomplete(frontend, "MainModule", Position{10, 14}, nullCallback); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("tru")); CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{11, 14}, nullCallback); + ac = autocomplete('2'); CHECK(ac.entryMap.count("ni")); CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{12, 14}, nullCallback); + ac = autocomplete('3'); CHECK(ac.entryMap.count("false")); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{13, 14}, nullCallback); + ac = autocomplete('4'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{14, 14}, nullCallback); + ac = autocomplete('5'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); } @@ -1988,10 +2003,10 @@ local target: ((number) -> string) & ((string) -> number)) local one = 4 local two = "hello" -return target(o) +return target(o@1) )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2002,10 +2017,10 @@ local target: ((number) -> string) & ((number) -> number)) local one = 4 local two = "hello" -return target(o) +return target(o@1) )"); - ac = autocomplete(5, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2016,10 +2031,10 @@ local target: ((number, number) -> string) & ((string) -> number)) local one = 4 local two = "hello" -return target(1, o) +return target(1, o@1) )"); - ac = autocomplete(5, 18); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2032,10 +2047,10 @@ TEST_CASE_FIXTURE(ACFixture, "optional_members") local a = { x = 2, y = 3 } type A = typeof(a) local b: A? = a -return b. +return b.@1 )"); - auto ac = autocomplete(4, 9); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); @@ -2045,10 +2060,10 @@ return b. local a = { x = 2, y = 3 } type A = typeof(a) local b: nil | A = a -return b. +return b.@1 )"); - ac = autocomplete(4, 9); + ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); @@ -2056,10 +2071,10 @@ return b. check(R"( local b: nil | nil -return b. +return b.@1 )"); - ac = autocomplete(2, 9); + ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -2067,26 +2082,26 @@ return b. TEST_CASE_FIXTURE(ACFixture, "no_function_name_suggestions") { check(R"( -function na +function na@1 )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( -local function +local function @1 )"); - ac = autocomplete(1, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( -local function na +local function na@1 )"); - ac = autocomplete(1, 17); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } @@ -2095,20 +2110,20 @@ TEST_CASE_FIXTURE(ACFixture, "skip_current_local") { check(R"( local other = 1 -local name = na +local name = na@1 )"); - auto ac = autocomplete(2, 15); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.count("name")); CHECK(ac.entryMap.count("other")); check(R"( local other = 1 -local name, test = na +local name, test = na@1 )"); - ac = autocomplete(2, 21); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("name")); CHECK(!ac.entryMap.count("test")); @@ -2119,26 +2134,26 @@ TEST_CASE_FIXTURE(ACFixture, "keyword_members") { check(R"( local a = { done = 1, forever = 2 } -local b = a.do -local c = a.for -local d = a. +local b = a.do@1 +local c = a.for@2 +local d = a.@3 do end )"); - auto ac = autocomplete(2, 14); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); CHECK(ac.entryMap.count("forever")); - ac = autocomplete(3, 15); + ac = autocomplete('2'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); CHECK(ac.entryMap.count("forever")); - ac = autocomplete(4, 12); + ac = autocomplete('3'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2150,10 +2165,10 @@ TEST_CASE_FIXTURE(ACFixture, "keyword_methods") check(R"( local a = {} function a:done() end -local b = a:do +local b = a:do@1 )"); - auto ac = autocomplete(3, 14); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2247,29 +2262,29 @@ local elsewhere = false local doover = false local endurance = true -if 1 then -else +if 1 then@1 +else@2 end -while false do +while false do@3 end -repeat +repeat@4 until )"); - auto ac = autocomplete(6, 9); + auto ac = autocomplete('1'); CHECK(ac.entryMap.size() == 1); CHECK(ac.entryMap.count("then")); - ac = autocomplete(7, 4); + ac = autocomplete('2'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); - ac = autocomplete(10, 14); + ac = autocomplete('3'); CHECK(ac.entryMap.count("do")); - ac = autocomplete(13, 6); + ac = autocomplete('4'); CHECK(ac.entryMap.count("do")); // FIXME: ideally we want to handle start and end of all statements as well @@ -2284,11 +2299,11 @@ local elsewhere = false if true then return 1 -el +el@1 end )"); - auto ac = autocomplete(5, 2); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); CHECK(ac.entryMap.count("elsewhere") == 0); @@ -2300,11 +2315,11 @@ if true then return 1 else return 2 -el +el@1 end )"); - ac = autocomplete(7, 2); + ac = autocomplete('1'); CHECK(ac.entryMap.count("else") == 0); CHECK(ac.entryMap.count("elseif") == 0); CHECK(ac.entryMap.count("elsewhere")); @@ -2316,10 +2331,10 @@ if true then print("1") elif true then print("2") -el +el@1 end )"); - ac = autocomplete(7, 2); + ac = autocomplete('1'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); CHECK(ac.entryMap.count("elsewhere")); @@ -2360,30 +2375,30 @@ TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") { check(R"( type Test = { first: number, second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - auto ac = autocomplete(2, 19); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Intersection check(R"( type Test = { first: number } & { second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Union check(R"( type Test = { first: number, second: number } | { second: number, third: number } -local t: Test = { s } +local t: Test = { s@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("second")); CHECK(!ac.entryMap.count("first")); CHECK(!ac.entryMap.count("third")); @@ -2391,60 +2406,60 @@ local t: Test = { s } // No parenthesis suggestion check(R"( type Test = { first: (number) -> number, second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap["first"].parens == ParenthesesRecommendation::None); // When key is changed check(R"( type Test = { first: number, second: number } -local t: Test = { f = 2 } +local t: Test = { f@1 = 2 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Alternative key syntax check(R"( type Test = { first: number, second: number } -local t: Test = { ["f"] } +local t: Test = { ["f@1"] } )"); - ac = autocomplete(2, 21); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Not an alternative key syntax check(R"( type Test = { first: number, second: number } -local t: Test = { "f" } +local t: Test = { "f@1" } )"); - ac = autocomplete(2, 20); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("first")); CHECK(!ac.entryMap.count("second")); // Skip keys that are already defined check(R"( type Test = { first: number, second: number } -local t: Test = { first = 2, s } +local t: Test = { first = 2, s@1 } )"); - ac = autocomplete(2, 30); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Don't skip active key check(R"( type Test = { first: number, second: number } -local t: Test = { first } +local t: Test = { first@1 } )"); - ac = autocomplete(2, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); @@ -2452,22 +2467,22 @@ local t: Test = { first } check(R"( local t = { { first = 5, second = 10 }, - { f } + { f@1 } } )"); - ac = autocomplete(3, 7); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); check(R"( local t = { [2] = { first = 5, second = 10 }, - [5] = { f } + [5] = { f@1 } } )"); - ac = autocomplete(3, 13); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); } @@ -2502,15 +2517,15 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") local temp = false local even = true; local a = true -a = if t1@emp then t -a = if temp t2@ -a = if temp then e3@ -a = if temp then even e4@ -a = if temp then even elseif t5@ -a = if temp then even elseif true t6@ -a = if temp then even elseif true then t7@ -a = if temp then even elseif true then temp e8@ -a = if temp then even elseif true then temp else e9@ +a = if t@1emp then t +a = if temp t@2 +a = if temp then e@3 +a = if temp then even e@4 +a = if temp then even elseif t@5 +a = if temp then even elseif true t@6 +a = if temp then even elseif true then t@7 +a = if temp then even elseif true then temp e@8 +a = if temp then even elseif true then temp else e@9 )"); auto ac = autocomplete('1'); @@ -2573,4 +2588,20 @@ a = if temp then even elseif true then temp else e9@ } } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") +{ + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + check(R"( +type A = () -> T... +local a: A<(number, s@1> + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 26bc77f70..29c33f7c1 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -32,6 +32,55 @@ std::optional TestFileResolver::fromAstFragment(AstExpr* expr) const return std::nullopt; } +std::optional TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr) +{ + if (AstExprGlobal* g = expr->as()) + { + if (g->name == "game") + return ModuleInfo{"game"}; + if (g->name == "workspace") + return ModuleInfo{"workspace"}; + if (g->name == "script") + return context ? std::optional(*context) : std::nullopt; + } + else if (AstExprIndexName* i = expr->as(); i && context) + { + if (i->index == "Parent") + { + std::string_view view = context->name; + size_t lastSeparatorIndex = view.find_last_of('/'); + + if (lastSeparatorIndex == std::string_view::npos) + return std::nullopt; + + return ModuleInfo{ModuleName(view.substr(0, lastSeparatorIndex)), context->optional}; + } + else + { + return ModuleInfo{context->name + '/' + i->index.value, context->optional}; + } + } + else if (AstExprIndexExpr* i = expr->as(); i && context) + { + if (AstExprConstantString* index = i->index->as()) + { + return ModuleInfo{context->name + '/' + std::string(index->value.data, index->value.size), context->optional}; + } + } + else if (AstExprCall* call = expr->as(); call && call->self && call->args.size >= 1 && context) + { + if (AstExprConstantString* index = call->args.data[0]->as()) + { + AstName func = call->func->as()->index; + + if (func == "GetService" && context->name == "game") + return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)}; + } + } + + return std::nullopt; +} + ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const { return lhs + "/" + ModuleName(rhs); diff --git a/tests/Fixture.h b/tests/Fixture.h index c6294b014..1480a7f6a 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -65,6 +65,8 @@ struct TestFileResolver } std::optional fromAstFragment(AstExpr* expr) const override; + std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override; std::optional getParentModuleName(const ModuleName& name) const override; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 3f33a5d19..fbfec6367 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -58,6 +58,35 @@ struct NaiveFileResolver : NullFileResolver return std::nullopt; } + std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override + { + if (AstExprGlobal* g = expr->as()) + { + if (g->name == "Modules") + return ModuleInfo{"Modules"}; + + if (g->name == "game") + return ModuleInfo{"game"}; + } + else if (AstExprIndexName* i = expr->as()) + { + if (context) + return ModuleInfo{context->name + '/' + i->index.value, context->optional}; + } + else if (AstExprCall* call = expr->as(); call && call->self && call->args.size >= 1 && context) + { + if (AstExprConstantString* index = call->args.data[0]->as()) + { + AstName func = call->func->as()->index; + + if (func == "GetService" && context->name == "game") + return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)}; + } + } + + return std::nullopt; + } + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override { return lhs + "/" + ModuleName(rhs); @@ -528,7 +557,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "ignore_require_to_nonexistent_file") { fileResolver.source["Modules/A"] = R"( local Modules = script - local B = require(Modules.B :: any) + local B = require(Modules.B) :: any )"; CheckResult result = frontend.check("Modules/A"); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index c8eff3991..a9ed139f1 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1400,6 +1400,8 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { + ScopedFastFlag sff("LuauLinterTableMoveZero", true); + LintResult result = lintTyped(R"( local t = {} local tt = {} @@ -1417,9 +1419,12 @@ table.remove(t, 0) table.remove(t, #t-1) table.insert(t, string.find("hello", "h")) + +table.move(t, 0, #t, 1, tt) +table.move(t, 1, #t, 0, tt) )"); - REQUIRE_EQ(result.warnings.size(), 6); + REQUIRE_EQ(result.warnings.size(), 8); CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); @@ -1429,6 +1434,8 @@ table.insert(t, string.find("hello", "h")) "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[5].text, "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); + CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 1b146ed2a..18f55d2c1 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Fixture.h" diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index f3c76d55a..931a8403a 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index cb03a7bd9..a80718e47 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2519,4 +2519,19 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") } } +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + AstStat* stat = parse(R"( +type Packed = () -> T... + +type A = Packed +type B = Packed<...number> +type C = Packed<(number, X...)> + )"); + REQUIRE(stat != nullptr); +} + TEST_SUITE_END(); diff --git a/tests/RequireTracer.test.cpp b/tests/RequireTracer.test.cpp index cbd4af29a..b9fd04d69 100644 --- a/tests/RequireTracer.test.cpp +++ b/tests/RequireTracer.test.cpp @@ -57,6 +57,7 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") { AstStatBlock* block = parse(R"( local m = workspace.Foo.Bar.Baz + require(m) )"); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); @@ -70,22 +71,22 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") AstExprIndexName* value = loc->values.data[0]->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value]); + CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value].name); value = value->expr->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo/Bar", result.exprs[value]); + CHECK_EQ("workspace/Foo/Bar", result.exprs[value].name); value = value->expr->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo", result.exprs[value]); + CHECK_EQ("workspace/Foo", result.exprs[value].name); AstExprGlobal* workspace = value->expr->as(); REQUIRE(workspace); REQUIRE(result.exprs.contains(workspace)); - CHECK_EQ("workspace", result.exprs[workspace]); + CHECK_EQ("workspace", result.exprs[workspace].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") @@ -93,9 +94,10 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") AstStatBlock* block = parse(R"( local m = workspace.Foo.Bar.Baz local n = m.Quux + require(n) )"); - REQUIRE_EQ(2, block->body.size); + REQUIRE_EQ(3, block->body.size); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); @@ -104,13 +106,13 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") REQUIRE_EQ(1, local->vars.size); REQUIRE(result.exprs.contains(local->values.data[0])); - CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]]); + CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") { AstStatBlock* block = parse(R"( - local M = require(workspace.Game.Thing, workspace.Something.Else) + local M = require(workspace.Game.Thing) )"); REQUIRE_EQ(1, block->body.size); @@ -124,52 +126,9 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") AstExprCall* call = local->values.data[0]->as(); REQUIRE(call != nullptr); - REQUIRE_EQ(2, call->args.size); - - CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]]); - CHECK_EQ("workspace/Something/Else", result.exprs[call->args.data[1]]); -} - -TEST_CASE_FIXTURE(RequireTracerFixture, "follow_GetService_calls") -{ - AstStatBlock* block = parse(R"( - local R = game:GetService('ReplicatedStorage').Roact - local Roact = require(R) - )"); - REQUIRE_EQ(2, block->body.size); - - RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); - - AstStatLocal* local = block->body.data[0]->as(); - REQUIRE(local != nullptr); - - CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[local->values.data[0]]); - - AstStatLocal* local2 = block->body.data[1]->as(); - REQUIRE(local2 != nullptr); - REQUIRE_EQ(1, local2->values.size); - - AstExprCall* call = local2->values.data[0]->as(); - REQUIRE(call != nullptr); REQUIRE_EQ(1, call->args.size); - CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[call->args.data[0]]); -} - -TEST_CASE_FIXTURE(RequireTracerFixture, "follow_WaitForChild_calls") -{ - ScopedFastFlag luauTraceRequireLookupChild("LuauTraceRequireLookupChild", true); - - AstStatBlock* block = parse(R"( -local A = require(workspace:WaitForChild('ReplicatedStorage').Content) -local B = require(workspace:FindFirstChild('ReplicatedFirst').Data) - )"); - - RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); - - REQUIRE_EQ(2, result.requires.size()); - CHECK_EQ("workspace/ReplicatedStorage/Content", result.requires[0].first); - CHECK_EQ("workspace/ReplicatedFirst/Data", result.requires[1].first); + CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") @@ -200,22 +159,23 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") REQUIRE(call != nullptr); REQUIRE_EQ(1, call->args.size); - CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]]); + CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "follow_string_indexexpr") { AstStatBlock* block = parse(R"( local R = game["Test"] + require(R) )"); - REQUIRE_EQ(1, block->body.size); + REQUIRE_EQ(2, block->body.size); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); AstStatLocal* local = block->body.data[0]->as(); REQUIRE(local != nullptr); - CHECK_EQ("game/Test", result.exprs[local->values.data[0]]); + CHECK_EQ("game/Test", result.exprs[local->values.data[0]].name); } TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index d7d68c461..e18bf7cdd 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Fixture.h" @@ -416,8 +417,6 @@ function foo(a, b) return a(b) end TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_TypePack") { - ScopedFastFlag sff{"LuauToStringFollowsBoundTo", true}; - TypeVar tv1{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tv1); ttv->state = TableState::Sealed; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp new file mode 100644 index 000000000..045f0230d --- /dev/null +++ b/tests/TypeInfer.aliases.test.cpp @@ -0,0 +1,557 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeAliases"); + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type F = () -> F? + local function f() + return f + end + + local g: F = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") +{ + CheckResult result = check(R"( + --!strict + type Node = { Parent: Node?; } + local node: Node; + node.Parent = 1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Node?", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: a, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = "lo", i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: b, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = 5, i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_ERRORS(result); + + // We had a UAF in this example caused by not cloning type function arguments + ModulePtr module = frontend.moduleResolver.getModule("MainModule"); + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes); + freeze(module->interfaceTypes); + module->internalTypes.clear(); + module->astTypes.clear(); + + // Make sure the error strings don't include "VALUELESS" + for (auto error : module->errors) + CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); +} + +TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") +{ + CheckResult result = check(R"( + type Pair = {first: T, second: U} + local a: Pair + local b: Pair + + a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK_EQ("Pair", toString(tm->wantedType)); + CHECK_EQ("Pair", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") +{ + CheckResult result = check(R"( + type A = number + type A = string -- Redefinition of type 'A', previously defined at line 1 + local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = Table + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Wrapped", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = (Table) -> string + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +// Check that recursive intersection type doesn't generate an OOM +TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") +{ + CheckResult result = check(R"( + function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any + end + type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) + _(_) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") +{ + CheckResult result = check(R"( + local foo: Id = 1 + type Id = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") +{ + const std::string code = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb = aa + )"; + + const std::string expected = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type A = () -> (number, B) + type B = () -> (string, A) + local a: A + local b: B + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); + CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "generic_param_remap") +{ + const std::string code = R"( + -- An example of a forwarded use of a type that has different type arguments than parameters + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb = aa + )"; + + const std::string expected = R"( + + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") +{ + CheckResult result = check(R"( + export type Foo = number + type Foo = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto dtd = get(result.errors[0]); + REQUIRE(dtd); + CHECK_EQ(dtd->name, "Foo"); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") +{ + ScopedFastFlag sffs3{"LuauGenericFunctions", true}; + ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + type Node = { value: T, child: Node? } + + local function visitor(node: Node?) + local a: Node + + if node then + a = node.child -- Observe the output of the error message. + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto e = get(result.errors[0]); + CHECK_EQ("Node?", toString(e->givenType)); + CHECK_EQ("Node", toString(e->wantedType)); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") +{ + fileResolver.source["workspace/A"] = R"( + export type myvec2 = {x: number, y: number} + return {} + )"; + + fileResolver.source["workspace/B"] = R"( + export type myvec3 = {x: number, y: number, z: number} + return {} + )"; + + fileResolver.source["workspace/C"] = R"( + local Foo, Bar = require(workspace.A), require(workspace.B) + + local a: Foo.myvec2 + local b: Bar.myvec3 + )"; + + CheckResult result = frontend.check("workspace/C"); + LUAU_REQUIRE_NO_ERRORS(result); + ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; + + REQUIRE(m != nullptr); + + std::optional aTypeId = lookupName(m->getModuleScope(), "a"); + REQUIRE(aTypeId); + const Luau::TableTypeVar* aType = get(follow(*aTypeId)); + REQUIRE(aType); + REQUIRE(aType->props.size() == 2); + + std::optional bTypeId = lookupName(m->getModuleScope(), "b"); + REQUIRE(bTypeId); + const Luau::TableTypeVar* bType = get(follow(*bTypeId)); + REQUIRE(bType); + REQUIRE(bType->props.size() == 3); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") +{ + CheckResult result = check("type t10 = typeof(table)"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); + CHECK_EQ(toString(ty), "table"); + + const TableTypeVar* ttv = get(ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +local c: Cool = { a = 1, b = "s" } +type NotCool = Cool +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +type NotCool = Cool +local c: Cool = { a = 1, b = "s" } +local d: NotCool = { a = 1, b = "s" } +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + ty = requireType("d"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "NotCool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") +{ + CheckResult result = check(R"( +local c = { a = 1, b = "s" } +type Cool = typeof(c) +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK_EQ(ttv->name, "Cool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") +{ + fileResolver.source["game/A"] = R"( +export type X = { a: number, b: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(follow(*ty1), follow(*ty2)); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") +{ + fileResolver.source["game/A"] = R"( +export type X = { a: T, b: U, C: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); + + bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); + CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return function(obj) return true end +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return {a = 1, b = function(obj) return true end} +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") +{ + CheckResult result = check(R"( + type Tree = { data: T, children: Forest } + type Forest = {Tree} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- OK because forwarded types are used with their parameters. + type Tree = { data: T, children: Forest } + type Forest = {Tree<{T}>} + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- Not OK because forwarded types are used with different types than their parameters. + type Forest = {Tree<{T}>} + type Tree = { data: T, children: Forest } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") +{ + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") +{ + CheckResult result = check(R"( + function f(x) return x[1] end + -- x has type X? for a free type variable X + local x = f ({}) + type ContainsFree = { this: a, that: typeof(x) } + type ContainsContainsFree = { that: ContainsFree } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") +{ + ScopedFastFlag sff1{"LuauSubstitutionDontReplaceIgnoredTypes", true}; + + CheckResult result = check(R"( + type Array = { [number]: T } + type Tuple = Array + + local p: Tuple + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{number | string}", toString(requireType("p"), {true})); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 46496fdb1..8bcb02424 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -30,6 +30,8 @@ TEST_SUITE_BEGIN("ProvisionalTests"); */ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") { + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + const std::string code = R"( function f(a) if type(a) == "boolean" then @@ -41,11 +43,11 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") )"; const std::string expected = R"( - function f(a:{fn:()->(free)}): () + function f(a:{fn:()->(free,free...)}): () if type(a) == 'boolean'then local a1:boolean=a elseif a.fn()then - local a2:{fn:()->(free)}=a + local a2:{fn:()->(free,free...)}=a end end )"; @@ -231,16 +233,7 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") local r2 = b == a )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| x: string |}?' could not be converted into 'number | string'"); - CHECK_EQ(toString(result.errors[1]), "Type 'number | string' could not be converted into '{| x: string |}?'"); - } + LUAU_REQUIRE_NO_ERRORS(result); } // Belongs in TypeInfer.refinements.test.cpp. @@ -542,6 +535,25 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doct } } +TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") +{ + CheckResult result = check(R"( + --!strict + local function setNumber(t: { p: number? }, x:number) t.p = x end + local function getString(t: { p: string? }):string return t.p or "" end + -- This shouldn't type-check! + local function oh(x:number): string + local t: {} = {} + setNumber(t, x) + return getString(t) + end + local s: string = oh(37) + )"); + + // Really this should return an error, but it doesn't + LUAU_REQUIRE_NO_ERRORS(result); +} + // Should be in TypeInfer.tables.test.cpp // It's unsound to instantiate tables containing generic methods, // since mutating properties means table properties should be invariant. diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f2ba0ddc5..31739cdc7 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Fixture.h" @@ -6,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauWeakEqConstraint) -LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) LUAU_FASTFLAG(LuauOrPredicate) using namespace Luau; @@ -199,16 +199,8 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - if (FFlag::LuauImprovedTypeGuardPredicate2) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'string' could not be converted into 'boolean'", toString(result.errors[0])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -526,8 +518,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x) if type(x) == "vector" then @@ -544,8 +534,6 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local t = {"hello"} local v = t[2] @@ -573,8 +561,6 @@ TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true" TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | number | boolean) if type(x) ~= "string" then @@ -593,8 +579,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | {x: number} | {y: boolean}) if type(x) == "table" then @@ -613,8 +597,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function weird(x: string | ((number) -> string)) if type(x) == "function" then @@ -698,8 +680,6 @@ struct RefinementClassFixture : Fixture TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(vec) local X, Y, Z = vec.X, vec.Y, vec.Z @@ -726,8 +706,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: Instance | Vector3) if typeof(x) == "Vector3" then @@ -746,8 +724,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | number | Instance | Vector3) if type(x) == "userdata" then @@ -766,10 +742,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") { - ScopedFastFlag sffs[] = { - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; + ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; CheckResult result = check(R"( local function f(x: Part | Folder | string) @@ -789,10 +762,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") { - ScopedFastFlag sffs[] = { - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; + ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; CheckResult result = check(R"( local function f(x: Part | Folder | Instance | string | Vector3 | any) @@ -812,10 +782,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( --!nonstrict @@ -839,7 +806,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") { ScopedFastFlag sffs[] = { {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, {"LuauTypeGuardPeelsAwaySubclasses", true}, }; @@ -861,8 +827,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( type XYCoord = {x: number} & {y: number} local function f(t: XYCoord?) @@ -882,8 +846,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( type SomeOverloadedFunction = ((number) -> string) & ((string) -> number) local function f(g: SomeOverloadedFunction?) @@ -903,8 +865,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(t: {x: number}) if type(t) ~= "table" then @@ -999,10 +959,7 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") TEST_CASE_FIXTURE(Fixture, "either_number_or_string") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local function f(x: any) @@ -1036,10 +993,7 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local a: (number | string)? @@ -1057,10 +1011,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. CheckResult result = check(R"( @@ -1081,10 +1032,7 @@ TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local function f(a: string | number | boolean) diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 1d1b2fae2..b7f0dc7b0 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -46,6 +46,21 @@ TEST_CASE_FIXTURE(Fixture, "augment_table") CHECK(tType->props.find("foo") != tType->props.end()); } +TEST_CASE_FIXTURE(Fixture, "augment_nested_table") +{ + CheckResult result = check("local t = { p = {} } t.p.foo = 'bar'"); + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* tType = getMutable(requireType("t")); + REQUIRE(tType != nullptr); + + REQUIRE(tType->props.find("p") != tType->props.end()); + const TableTypeVar* pType = get(tType->props["p"].type); + REQUIRE(pType != nullptr); + + CHECK(pType->props.find("foo") != pType->props.end()); +} + TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); @@ -260,6 +275,8 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local a = {} a.x = 99 @@ -272,10 +289,11 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; - UnknownProperty* error = get(err); + MissingProperties* error = get(err); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); - CHECK_EQ(error->key, "y"); + CHECK_EQ("y", error->properties[0]); // TODO(rblanckaert): Revist when we can bind self at function creation time // CHECK_EQ(err.location, Location(Position{5, 19}, Position{5, 25})); @@ -328,6 +346,8 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( --!strict function foo(o) @@ -340,14 +360,17 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") LUAU_REQUIRE_ERROR_COUNT(1, result); - UnknownProperty* error = get(result.errors[0]); + MissingProperties* error = get(result.errors[0]); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); - CHECK_EQ("baz", error->key); + CHECK_EQ("baz", error->properties[0]); } TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local T = {} T.bar = 'hello' @@ -359,8 +382,11 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; - UnknownProperty* error = get(err); + MissingProperties* error = get(err); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); + + CHECK_EQ("baz", error->properties[0]); // TODO(rblanckaert): Revist when we can bind self at function creation time /* @@ -448,6 +474,73 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") dumpErrors(result); } +TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + local t = { u = {} } + t = { u = { p = 37 } } + t = { u = { q = "hi" } } + local x = t.u.p + local y = t.u.q + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number?", toString(requireType("x"))); + CHECK_EQ("string?", toString(requireType("y"))); +} + +TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_call") +{ + CheckResult result = check(R"( + --!strict + function get(x) return x.opts["MYOPT"] end + function set(x,y) x.opts["MYOPT"] = y end + local t = { opts = {} } + set(t,37) + local x = get(t) + )"); + + // Currently this errors but it shouldn't, since set only needs write access + // TODO: file a JIRA for this + LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ("number?", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "width_subtyping") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + function f(x : { q : number }) + x.q = 8 + end + local t : { q : number, r : string } = { q = 8, r = "hi" } + f(t) + local x : string = t.r + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "width_subtyping_needs_covariance") +{ + CheckResult result = check(R"( + --!strict + function f(x : { p : { q : number }}) + x.p = { q = 8, r = 5 } + end + local t : { p : { q : number, r : string } } = { p = { q = 8, r = "hi" } } + f(t) -- Shouldn't typecheck + local x : string = t.p.r -- x is 5 + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "infer_array") { CheckResult result = check(R"( @@ -676,16 +769,27 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sealed_table_value_must_not_infer_an_indexer") +TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_NO_ERRORS(result); +} - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm != nullptr); +TEST_CASE_FIXTURE(Fixture, "array_factory_function") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + function empty() return {} end + local array: {string} = empty() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "sealed_table_indexers_must_unify") @@ -756,37 +860,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_ CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_string") -{ - ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); - - CheckResult result = check(R"( - local t: { a: string } - function f(x: string) return t[x] end - local a = f("a") - local b = f("b") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.anyType, *requireType("a")); - CHECK_EQ(*typeChecker.anyType, *requireType("b")); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_number") -{ - ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); - - CheckResult result = check(R"( - local t = { a = true } - function f(x: number) return t[x] end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); -} - TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer") { CheckResult result = check(R"( @@ -1392,6 +1465,8 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end foo({ a = 1 }) @@ -1402,8 +1477,21 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("string", toString(tm->wantedType, o)); - CHECK_EQ("number", toString(tm->givenType, o)); + CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); + CHECK_EQ("{| a: number |}", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") +{ + CheckResult result = check(R"( + local function foo(a: {[string]: number, a: string}, i: string) + return a[i] + end + local hi: number = foo({ a = "hi" }, "a") -- shouldn't typecheck since at runtime hi is "hi" + )"); + + // This typechecks but shouldn't + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors") @@ -1446,22 +1534,32 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") { CheckResult result = check(R"( - local vec3 = {x = 1, y = 2, z = 3} - local vec1 = {x = 1} + local vec3 = {{x = 1, y = 2, z = 3}} + local vec1 = {{x = 1}} vec1 = vec3 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - MissingProperties* mp = get(result.errors[0]); - REQUIRE(mp); - CHECK_EQ(mp->context, MissingProperties::Extra); - REQUIRE_EQ(2, mp->properties.size()); - CHECK_EQ(mp->properties[0], "y"); - CHECK_EQ(mp->properties[1], "z"); - CHECK_EQ("vec1", toString(mp->superType)); - CHECK_EQ("vec3", toString(mp->subType)); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("vec1", toString(tm->wantedType)); + CHECK_EQ("vec3", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + local vec3 = {x = 1, y = 2, z = 3} + local vec1 = {x = 1} + + vec1 = vec3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short") @@ -1824,4 +1922,32 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + local buttons = {} + table.insert(buttons, { a = 1 }) + table.insert(buttons, { a = 2, b = true }) + table.insert(buttons, { a = 3 }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + local buttons = {} + table.insert(buttons, { a = 1 }) + table.insert(buttons, { a = 2, b = true }) + table.insert(buttons, { a = 3 }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 37333b193..b75878b7e 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -3,6 +3,7 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" @@ -978,23 +979,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - type F = () -> F? - local function f() - return f - end - - local g: F = f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); -} - // TODO: File a Jira about this /* TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") @@ -1257,23 +1241,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") REQUIRE_EQ(follow(*methodArg), follow(arg)); } -TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") -{ - CheckResult result = check(R"( - --!strict - type Node = { Parent: Node?; } - local node: Node; - node.Parent = 1 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("Node?", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") { CheckResult result = check(R"( @@ -2591,48 +2558,6 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") -{ - CheckResult result = check(R"( - --!strict - type T = { f: a, g: U } - type U = { h: a, i: T? } - local x: T = { f = 37, g = { h = 5, i = nil } } - x.g.i = x - local y: T = { f = "hi", g = { h = "lo", i = nil } } - y.g.i = y - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") -{ - CheckResult result = check(R"( - --!strict - type T = { f: a, g: U } - type U = { h: b, i: T? } - local x: T = { f = 37, g = { h = 5, i = nil } } - x.g.i = x - local y: T = { f = "hi", g = { h = 5, i = nil } } - y.g.i = y - )"); - - LUAU_REQUIRE_ERRORS(result); - - // We had a UAF in this example caused by not cloning type function arguments - ModulePtr module = frontend.moduleResolver.getModule("MainModule"); - unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); - freeze(module->interfaceTypes); - module->internalTypes.clear(); - module->astTypes.clear(); - - // Make sure the error strings don't include "VALUELESS" - for (auto error : module->errors) - CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); -} - TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") { // CLI-30902 @@ -3369,16 +3294,7 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") end )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable") @@ -3388,18 +3304,8 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable print((x == true and (x .. "y")) .. 1) )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ("Type 'boolean' could not be converted into 'number | string'", toString(result.errors[0])); - CHECK_EQ("Type 'boolean | string' could not be converted into 'number | string'", toString(result.errors[1])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") @@ -3511,25 +3417,6 @@ _(...)(...,setfenv,_):_G() )"); } -TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") -{ - CheckResult result = check(R"( - type Pair = {first: T, second: U} - local a: Pair - local b: Pair - - a = b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - - CHECK_EQ("Pair", toString(tm->wantedType)); - CHECK_EQ("Pair", toString(tm->givenType)); -} - TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") { // this has a risk of creating cyclic type packs, causing infinite loops / OOMs @@ -3639,17 +3526,6 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") )"); } -TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") -{ - CheckResult result = check(R"( - type A = number - type A = string -- Redefinition of type 'A', previously defined at line 1 - local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery") { CheckResult result = check(R"( @@ -3752,38 +3628,6 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); } -TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") -{ - CheckResult result = check(R"( - type Table = { a: T } - type Wrapped = Table - local l: Wrapped = 2 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("Wrapped", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") -{ - CheckResult result = check(R"( - type Table = { a: T } - type Wrapped = (Table) -> string - local l: Wrapped = 2 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") { CheckResult result = check(R"( @@ -3909,19 +3753,6 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); } -// Check that recursive intersection type doesn't generate an OOM -TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") -{ - CheckResult result = check(R"( - function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any - end - type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) - _(_) - )"); - - LUAU_REQUIRE_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") { // In non-strict mode, global definition is still allowed @@ -3974,16 +3805,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") LUAU_REQUIRE_ERROR_COUNT(2, result); } -TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") -{ - CheckResult result = check(R"( - local foo: Id = 1 - type Id = T - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") { CheckResult result = check(R"( @@ -4014,81 +3835,6 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") -{ - const std::string code = R"( - type A = {v:T, b:B} - type B = {v:T, a:A} - local aa:A - local bb = aa - )"; - - const std::string expected = R"( - type A = {v:T, b:B} - type B = {v:T, a:A} - local aa:A - local bb:A=aa - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - CheckResult result = check(code); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - type A = () -> (number, B) - type B = () -> (string, A) - local a: A - local b: B - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); - CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "generic_param_remap") -{ - const std::string code = R"( - -- An example of a forwarded use of a type that has different type arguments than parameters - type A = {t:T, u:U, next:A?} - local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } - local bb = aa - )"; - - const std::string expected = R"( - - type A = {t:T, u:U, next:A?} - local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } - local bb:A=aa - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - CheckResult result = check(code); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") -{ - CheckResult result = check(R"( - export type Foo = number - type Foo = number - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto dtd = get(result.errors[0]); - REQUIRE(dtd); - CHECK_EQ(dtd->name, "Foo"); -} - TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") { CheckResult result = check(R"( @@ -4193,30 +3939,6 @@ TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") REQUIRE_MESSAGE(get(e) != nullptr, "Expected UnknownSymbol, but got " << e); } -TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") -{ - ScopedFastFlag sffs3{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - - CheckResult result = check(R"( - type Node = { value: T, child: Node? } - - local function visitor(node: Node?) - local a: Node - - if node then - a = node.child -- Observe the output of the error message. - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto e = get(result.errors[0]); - CHECK_EQ("Node?", toString(e->givenType)); - CHECK_EQ("Node", toString(e->wantedType)); -} - TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") { CheckResult result = check(R"( @@ -4272,181 +3994,6 @@ local tbl: string = require(game.A) CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") -{ - fileResolver.source["workspace/A"] = R"( - export type myvec2 = {x: number, y: number} - return {} - )"; - - fileResolver.source["workspace/B"] = R"( - export type myvec3 = {x: number, y: number, z: number} - return {} - )"; - - fileResolver.source["workspace/C"] = R"( - local Foo, Bar = require(workspace.A), require(workspace.B) - - local a: Foo.myvec2 - local b: Bar.myvec3 - )"; - - CheckResult result = frontend.check("workspace/C"); - LUAU_REQUIRE_NO_ERRORS(result); - ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; - - REQUIRE(m != nullptr); - - std::optional aTypeId = lookupName(m->getModuleScope(), "a"); - REQUIRE(aTypeId); - const Luau::TableTypeVar* aType = get(follow(*aTypeId)); - REQUIRE(aType); - REQUIRE(aType->props.size() == 2); - - std::optional bTypeId = lookupName(m->getModuleScope(), "b"); - REQUIRE(bTypeId); - const Luau::TableTypeVar* bType = get(follow(*bTypeId)); - REQUIRE(bType); - REQUIRE(bType->props.size() == 3); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") -{ - CheckResult result = check("type t10 = typeof(table)"); - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); - CHECK_EQ(toString(ty), "table"); - - const TableTypeVar* ttv = get(ty); - REQUIRE(ttv); - - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") -{ - CheckResult result = check(R"( -type Cool = { a: number, b: string } -local c: Cool = { a = 1, b = "s" } -type NotCool = Cool -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "Cool"); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") -{ - CheckResult result = check(R"( -type Cool = { a: number, b: string } -type NotCool = Cool -local c: Cool = { a = 1, b = "s" } -local d: NotCool = { a = 1, b = "s" } -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "Cool"); - - ty = requireType("d"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "NotCool"); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") -{ - CheckResult result = check(R"( -local c = { a = 1, b = "s" } -type Cool = typeof(c) -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK_EQ(ttv->name, "Cool"); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") -{ - ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; - - fileResolver.source["game/A"] = R"( -export type X = { a: number, b: X? } -return {} - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - std::optional ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - std::optional ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(follow(*ty1), follow(*ty2)); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") -{ - ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; - - fileResolver.source["game/A"] = R"( -export type X = { a: T, b: U, C: X? } -return {} - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - std::optional ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - std::optional ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); - - bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); - CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); -} - TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") { CheckResult result = check(R"( @@ -4560,32 +4107,6 @@ local c = a(2) -- too many arguments CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") -{ - CheckResult result = check(R"( -function get() - return function(obj) return true end -end - -export type f = typeof(get()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") -{ - CheckResult result = check(R"( -function get() - return {a = 1, b = function(obj) return true end} -end - -export type f = typeof(get()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "custom_require_global") { CheckResult result = check(R"( @@ -4768,8 +4289,6 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") { - ScopedFastFlag sff{"LuauLogTableTypeVarBoundTo", true}; - fileResolver.source["Module/Backend/Types"] = R"( export type Fiber = { return_: Fiber? @@ -4849,8 +4368,8 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") ModulePtr module = getMainModule(); auto it = module->astOverloadResolvedTypes.find(parentExpr); - REQUIRE(it != module->astOverloadResolvedTypes.end()); - CHECK_EQ(toString(it->second), "(number) -> number"); + REQUIRE(it); + CHECK_EQ(toString(*it), "(number) -> number"); } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") @@ -5013,76 +4532,6 @@ g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") -{ - CheckResult result = check(R"( - type Tree = { data: T, children: Forest } - type Forest = {Tree} - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- OK because forwarded types are used with their parameters. - type Tree = { data: T, children: Forest } - type Forest = {Tree<{T}>} - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- Not OK because forwarded types are used with different types than their parameters. - type Forest = {Tree<{T}>} - type Tree = { data: T, children: Forest } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") -{ - CheckResult result = check(R"( - type Tree1 = { data: T, children: {Tree2} } - type Tree2 = { data: U, children: {Tree1} } - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - type Tree1 = { data: T, children: {Tree2} } - type Tree2 = { data: U, children: {Tree1} } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") -{ - CheckResult result = check(R"( - function f(x) return x[1] end - -- x has type X? for a free type variable X - local x = f ({}) - type ContainsFree = { this: a, that: typeof(x) } - type ContainsContainsFree = { that: ContainsFree } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 91ac9f062..1f4b63ef2 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 5f7f28474..3e1dedd47 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -294,4 +294,370 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed = (T...) -> T... +local a: Packed<> +local b: Packed +local c: Packed + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tf = lookupType("Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "(T...) -> (T...)"); + CHECK_EQ(toString(requireType("a")), "() -> ()"); + CHECK_EQ(toString(requireType("b")), "(number) -> number"); + CHECK_EQ(toString(requireType("c")), "(string, number) -> (string, number)"); + + result = check(R"( +-- (U..., T) cannot be parsed right now +type Packed = { f: (a: T, U...) -> (T, U...) } +local a: Packed +local b: Packed +local c: Packed + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + tf = lookupType("Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Packed"); + CHECK_EQ(toString(*tf, {true}), "{| f: (T, U...) -> (T, U...) |}"); + + auto ttvA = get(requireType("a")); + REQUIRE(ttvA); + CHECK_EQ(toString(requireType("a")), "Packed"); + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); + REQUIRE(ttvA->instantiatedTypeParams.size() == 1); + REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); + CHECK_EQ(toString(ttvA->instantiatedTypePackParams[0], {true}), ""); + + auto ttvB = get(requireType("b")); + REQUIRE(ttvB); + CHECK_EQ(toString(requireType("b")), "Packed"); + CHECK_EQ(toString(requireType("b"), {true}), "{| f: (string, number) -> (string, number) |}"); + REQUIRE(ttvB->instantiatedTypeParams.size() == 1); + REQUIRE(ttvB->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvB->instantiatedTypeParams[0], {true}), "string"); + CHECK_EQ(toString(ttvB->instantiatedTypePackParams[0], {true}), "number"); + + auto ttvC = get(requireType("c")); + REQUIRE(ttvC); + CHECK_EQ(toString(requireType("c")), "Packed"); + CHECK_EQ(toString(requireType("c"), {true}), "{| f: (string, number, boolean) -> (string, number, boolean) |}"); + REQUIRE(ttvC->instantiatedTypeParams.size() == 1); + REQUIRE(ttvC->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvC->instantiatedTypeParams[0], {true}), "string"); + CHECK_EQ(toString(ttvC->instantiatedTypePackParams[0], {true}), "number, boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + fileResolver.source["game/A"] = R"( +export type Packed = { a: T, b: (U...) -> () } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +local a: Import.Packed +local b: Import.Packed +local c: Import.Packed +local d: { a: typeof(c) } + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + auto tf = lookupImportedType("Import", "Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Packed"); + CHECK_EQ(toString(*tf, {true}), "{| a: T, b: (U...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: number, b: () -> () |}"); + CHECK_EQ(toString(requireType("b"), {true}), "{| a: string, b: (number) -> () |}"); + CHECK_EQ(toString(requireType("c"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + CHECK_EQ(toString(requireType("d")), "{| a: Packed |}"); +} + +TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + fileResolver.source["game/A"] = R"( +export type Packed = { a: T, b: (U...) -> () } +return {} + )"; + + CheckResult cResult = check(R"( +local Import = require(game.A) +type Alias = Import.Packed +local a: Alias + +type B = Import.Packed +type C = Import.Packed + )"); + LUAU_REQUIRE_NO_ERRORS(cResult); + + auto tf = lookupType("Alias"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Alias"); + CHECK_EQ(toString(*tf, {true}), "{| a: S, b: (T, R...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + + tf = lookupType("B"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "B"); + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (X...) -> () |}"); + + tf = lookupType("C"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "C"); + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (number, X...) -> () |}"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed1 = (T...) -> (T...) +type Packed2 = (Packed1, T...) -> (Packed1, T...) +type Packed3 = (Packed2, T...) -> (Packed2, T...) +type Packed4 = (Packed3, T...) -> (Packed3, T...) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tf = lookupType("Packed4"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...) -> " + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = (T...) -> (string, T...) + +type D = X<...number> +type E = X<(number, ...string)> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("D")), "(...number) -> (string, ...number)"); + CHECK_EQ(toString(*lookupType("E")), "(number, ...string) -> (string, number, ...string)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Y = (T...) -> (U...) +type A = Y +type B = Y<(number, ...string), S...> + +type Z = (T) -> (U...) +type E = Z +type F = Z + +type W = (T, U...) -> (T, V...) +type H = W +type I = W + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)"); + CHECK_EQ(toString(*lookupType("B")), "(number, ...string) -> (S...)"); + + CHECK_EQ(toString(*lookupType("E")), "(number) -> (S...)"); + CHECK_EQ(toString(*lookupType("F")), "(number) -> (string, S...)"); + + CHECK_EQ(toString(*lookupType("H")), "(number, S...) -> (number, R...)"); + CHECK_EQ(toString(*lookupType("I")), "(number, string, S...) -> (number, R...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = (T...) -> (T...) + +type A = X<(S...)> +type B = X<()> +type C = X<(number)> +type D = X<(number, string)> +type E = X<(...number)> +type F = X<(string, ...number)> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)"); + CHECK_EQ(toString(*lookupType("B")), "() -> ()"); + CHECK_EQ(toString(*lookupType("C")), "(number) -> number"); + CHECK_EQ(toString(*lookupType("D")), "(number, string) -> (number, string)"); + CHECK_EQ(toString(*lookupType("E")), "(...number) -> (...number)"); + CHECK_EQ(toString(*lookupType("F")), "(string, ...number) -> (string, ...number)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Y = (T...) -> (U...) + +type A = Y<(number, string), (boolean)> +type B = Y<(), ()> +type C = Y<...string, (number, S...)> +type D = Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(number, string) -> boolean"); + CHECK_EQ(toString(*lookupType("B")), "() -> ()"); + CHECK_EQ(toString(*lookupType("C")), "(...string) -> (number, S...)"); + CHECK_EQ(toString(*lookupType("D")), "(X...) -> (number, string, X...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + ScopedFastFlag luauInstantiatedTypeParamRecursion("LuauInstantiatedTypeParamRecursion", true); // For correct toString block + + CheckResult result = check(R"( +type Y = { f: (T...) -> (U...) } + +local a: Y<(number, string), (boolean)> +local b: Y<(), ()> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (boolean)>"); + CHECK_EQ(toString(requireType("b")), "Y<(), ()>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = () -> T +type Y = (T) -> U + +type A = X<(number)> +type B = Y<(number), (boolean)> +type C = Y<(number), boolean> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "() -> number"); + CHECK_EQ(toString(*lookupType("B")), "(number) -> boolean"); + CHECK_EQ(toString(*lookupType("C")), "(number) -> boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed = (T, U) -> (V...) +local b: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects at least 2 type arguments, but only 1 is specified"); + + result = check(R"( +type Packed = (T, U) -> () +type B = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 0 type pack arguments, but 1 is specified"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameters must come before type pack parameters"); + + result = check(R"( +type Packed = (T) -> U +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type arguments, but only 1 is specified"); + + result = check(R"( +type Packed = (T...) -> T... +local a: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but none are specified"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but only 1 is specified"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ae4d836b4..037144e2a 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -237,21 +237,7 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") local z = a == c )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.booleanType, *requireType("x")); - CHECK_EQ(*typeChecker.booleanType, *requireType("y")); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("(number | string)?", toString(*tm->wantedType)); - CHECK_EQ("boolean | number", toString(*tm->givenType)); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "optional_union_members") diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 98ce9f939..a679e3fd2 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tools/tracegraph.py b/tools/tracegraph.py new file mode 100644 index 000000000..a46423e7e --- /dev/null +++ b/tools/tracegraph.py @@ -0,0 +1,95 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a trace event file, this tool generates a flame graph based on the event scopes present in the file +# The result of analysis is a .svg file which can be viewed in a browser + +import sys +import svg +import json + +class Node(svg.Node): + def __init__(self): + svg.Node.__init__(self) + self.caption = "" + self.description = "" + self.ticks = 0 + + def text(self): + return self.caption + + def title(self): + return self.caption + + def details(self, root): + return "{} ({:,} usec, {:.1%}); self: {:,} usec".format(self.description, self.width, self.width / root.width, self.ticks) + +with open(sys.argv[1]) as f: + dump = f.read() + +root = Node() + +# Finish the file +if not dump.endswith("]"): + dump += "{}]" + +data = json.loads(dump) + +stacks = {} + +for l in data: + if len(l) == 0: + continue + + # Track stack of each thread, but aggregate values together + tid = l["tid"] + + if not tid in stacks: + stacks[tid] = [] + stack = stacks[tid] + + if l["ph"] == 'B': + stack.append(l) + elif l["ph"] == 'E': + node = root + + for e in stack: + caption = e["name"] + description = '' + + if "args" in e: + for arg in e["args"]: + if len(description) != 0: + description += ", " + + description += "{}: {}".format(arg, e["args"][arg]) + + child = node.child(caption + description) + child.caption = caption + child.description = description + + node = child + + begin = stack[-1] + + ticks = l["ts"] - begin["ts"] + rawticks = ticks + + # Flame graph requires ticks without children duration + if "childts" in begin: + ticks -= begin["childts"] + + node.ticks += int(ticks) + + stack.pop() + + if len(stack): + parent = stack[-1] + + if "childts" in parent: + parent["childts"] += rawticks + else: + parent["childts"] = rawticks + +svg.layout(root, lambda n: n.ticks) +svg.display(root, "Flame Graph", "hot", flip = True) From 34cf695fbc35eb435dcd9fb85c3b98234fdd266c Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 4 Nov 2021 19:42:00 -0700 Subject: [PATCH 02/14] Sync to upstream/release/503 - A series of major optimizations to type checking performance on complex programs/types (up to two orders of magnitude speedup for programs involving huge tagged unions) - Fix a few issues encountered by UBSAN (and maybe fix s390x builds) - Fix gcc-11 test builds - Fix a rare corner case where luau_load wouldn't wake inactive threads which could result in a use-after-free due to GC - Fix CLI crash when error object that's not a string escapes to top level --- Analysis/include/Luau/BuiltinDefinitions.h | 1 - Analysis/include/Luau/Quantify.h | 14 + Analysis/include/Luau/ToString.h | 2 + Analysis/include/Luau/TopoSortStatements.h | 1 + Analysis/include/Luau/TxnLog.h | 28 +- Analysis/include/Luau/TypeInfer.h | 7 +- Analysis/include/Luau/TypeVar.h | 9 +- Analysis/include/Luau/Unifier.h | 20 +- Analysis/include/Luau/UnifierSharedState.h | 44 ++ Analysis/include/Luau/VisitTypeVar.h | 59 +- Analysis/src/Autocomplete.cpp | 5 +- Analysis/src/BuiltinDefinitions.cpp | 12 - Analysis/src/Error.cpp | 21 +- Analysis/src/Frontend.cpp | 11 +- Analysis/src/Module.cpp | 18 +- Analysis/src/Quantify.cpp | 90 +++ Analysis/src/RequireTracer.cpp | 6 +- Analysis/src/ToString.cpp | 25 +- Analysis/src/TopoSortStatements.cpp | 25 + Analysis/src/TxnLog.cpp | 37 +- Analysis/src/TypeAttach.cpp | 93 ++- Analysis/src/TypeInfer.cpp | 199 ++--- Analysis/src/TypeVar.cpp | 86 ++- Analysis/src/Unifier.cpp | 349 ++++++++- Ast/include/Luau/TimeTrace.h | 16 +- Ast/src/Parser.cpp | 2 +- Ast/src/TimeTrace.cpp | 5 +- CLI/Repl.cpp | 24 +- Compiler/include/Luau/Bytecode.h | 4 +- Compiler/src/Compiler.cpp | 14 +- Makefile | 2 + Sources.cmake | 3 + VM/include/lualib.h | 2 +- VM/src/lapi.cpp | 4 +- VM/src/laux.cpp | 2 +- VM/src/lcorolib.cpp | 72 +- VM/src/ldo.cpp | 14 +- VM/src/ldo.h | 2 +- VM/src/lfunc.cpp | 2 +- VM/src/lgc.cpp | 331 ++++---- VM/src/lgc.h | 8 +- VM/src/lmem.cpp | 2 +- VM/src/lstring.cpp | 2 +- VM/src/lvmload.cpp | 4 +- bench/tests/chess.lua | 849 +++++++++++++++++++++ bench/tests/shootout/scimark.lua | 2 +- tests/Autocomplete.test.cpp | 24 +- tests/Compiler.test.cpp | 84 +- tests/Conformance.test.cpp | 13 + tests/IostreamOptional.h | 7 +- tests/Linter.test.cpp | 14 +- tests/TypeInfer.aliases.test.cpp | 50 ++ tests/TypeInfer.builtins.test.cpp | 2 +- tests/TypeInfer.classes.test.cpp | 2 - tests/TypeInfer.generics.test.cpp | 21 + tests/TypeInfer.provisional.test.cpp | 25 +- tests/TypeInfer.refinements.test.cpp | 11 +- tests/TypeInfer.tables.test.cpp | 2 +- tests/TypeInfer.test.cpp | 62 +- tests/TypeInfer.tryUnify.test.cpp | 10 +- tests/TypeInfer.typePacks.cpp | 6 +- tests/TypeInfer.unionTypes.test.cpp | 25 +- tests/TypeVar.test.cpp | 60 ++ tests/conformance/closure.lua | 2 +- tests/conformance/coroutine.lua | 2 +- tests/conformance/gc.lua | 2 +- tests/conformance/locals.lua | 2 +- tests/conformance/math.lua | 2 +- tests/conformance/pm.lua | 4 +- tests/conformance/tmerror.lua | 15 + tools/gdb-printers.py | 8 +- 71 files changed, 2300 insertions(+), 683 deletions(-) create mode 100644 Analysis/include/Luau/Quantify.h create mode 100644 Analysis/include/Luau/UnifierSharedState.h create mode 100644 Analysis/src/Quantify.cpp create mode 100644 bench/tests/chess.lua create mode 100644 tests/conformance/tmerror.lua diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 57a1907a5..07d897b2f 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -34,7 +34,6 @@ TypeId makeFunction( // Polymorphic std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); void attachMagicFunction(TypeId ty, MagicFunction fn); -void attachFunctionTag(TypeId ty, std::string constraint); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName); diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h new file mode 100644 index 000000000..f46df1460 --- /dev/null +++ b/Analysis/include/Luau/Quantify.h @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeVar.h" + +namespace Luau +{ + +struct Module; +using ModulePtr = std::shared_ptr; + +void quantify(ModulePtr module, TypeId ty, TypeLevel level); + +} // namespace Luau diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 0897ec854..e5683fc40 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -69,4 +69,6 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); void dump(TypeId ty); void dump(TypePackId ty); +std::string generateName(size_t n); + } // namespace Luau diff --git a/Analysis/include/Luau/TopoSortStatements.h b/Analysis/include/Luau/TopoSortStatements.h index 751694f02..4a4acfa3a 100644 --- a/Analysis/include/Luau/TopoSortStatements.h +++ b/Analysis/include/Luau/TopoSortStatements.h @@ -12,6 +12,7 @@ struct AstArray; class AstStat; bool containsFunctionCall(const AstStat& stat); +bool containsFunctionCallOrReturn(const AstStat& stat); bool isFunction(const AstStat& stat); void toposort(std::vector& stats); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 055441ce8..322abd198 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -3,17 +3,35 @@ #include "Luau/TypeVar.h" +LUAU_FASTFLAG(LuauShareTxnSeen); + namespace Luau { // Log of where what TypeIds we are rebinding and what they used to be struct TxnLog { - TxnLog() = default; + TxnLog() + : originalSeenSize(0) + , ownedSeen() + , sharedSeen(&ownedSeen) + { + } + + explicit TxnLog(std::vector>* sharedSeen) + : originalSeenSize(sharedSeen->size()) + , ownedSeen() + , sharedSeen(sharedSeen) + { + } - explicit TxnLog(const std::vector>& seen) - : seen(seen) + explicit TxnLog(const std::vector>& ownedSeen) + : originalSeenSize(ownedSeen.size()) + , ownedSeen(ownedSeen) + , sharedSeen(nullptr) { + // This is deprecated! + LUAU_ASSERT(!FFlag::LuauShareTxnSeen); } TxnLog(const TxnLog&) = delete; @@ -38,9 +56,11 @@ struct TxnLog std::vector> typeVarChanges; std::vector> typePackChanges; std::vector>> tableChanges; + size_t originalSeenSize; public: - std::vector> seen; // used to avoid infinite recursion when types are cyclic + std::vector> ownedSeen; // used to avoid infinite recursion when types are cyclic + std::vector>* sharedSeen; // shared with all the descendent logs }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index d701eb248..9d62fef0b 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -11,6 +11,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" +#include "Luau/UnifierSharedState.h" #include #include @@ -121,7 +122,7 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatForIn& forin); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); - void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false); + void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0, bool forwardDeclare = false); void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); @@ -336,7 +337,7 @@ struct TypeChecker // Note: `scope` must be a fresh scope. std::pair, std::vector> createGenericTypes( - const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -383,6 +384,8 @@ struct TypeChecker std::function prepareModuleScope; InternalErrorReporter* iceHandler; + UnifierSharedState unifierState; + public: const TypeId nilType; const TypeId numberType; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index d4e4e4913..9611e881f 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -405,7 +405,7 @@ const std::string* getName(TypeId type); // Checks whether a union contains all types of another union. bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub); -// Checks if a type conains generic type binders +// Checks if a type contains generic type binders bool isGeneric(const TypeId ty); // Checks if a type may be instantiated to one containing generic type binders @@ -540,4 +540,11 @@ UnionTypeVarIterator end(const UnionTypeVar* utv); using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); +void attachTag(TypeId ty, const std::string& tagName); +void attachTag(Property& prop, const std::string& tagName); + +bool hasTag(TypeId ty, const std::string& tagName); +bool hasTag(const Property& prop, const std::string& tagName); +bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work. + } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 522914b2f..56632e33c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -6,6 +6,7 @@ #include "Luau/TxnLog.h" #include "Luau/TypeInfer.h" #include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header. +#include "Luau/UnifierSharedState.h" #include @@ -41,11 +42,14 @@ struct Unifier std::shared_ptr counters_DEPRECATED; - InternalErrorReporter* iceHandler; + UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED = nullptr, + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, + UnifierCounters* counters = nullptr); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, UnifierCounters* counters = nullptr); // Test whether the two type vars unify. Never commits the result. @@ -69,7 +73,8 @@ struct Unifier void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); - TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + void cacheResult(TypeId superTy, TypeId subTy); public: void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); @@ -101,8 +106,9 @@ struct Unifier [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); - DenseHashSet tempSeenTy{nullptr}; - DenseHashSet tempSeenTp{nullptr}; + // Remove with FFlagLuauCacheUnifyTableResults + DenseHashSet tempSeenTy_DEPRECATED{nullptr}; + DenseHashSet tempSeenTp_DEPRECATED{nullptr}; }; } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h new file mode 100644 index 000000000..f252a004b --- /dev/null +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -0,0 +1,44 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ +struct InternalErrorReporter; + +struct TypeIdPairHash +{ + size_t hashOne(Luau::TypeId key) const + { + return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9); + } + + size_t operator()(const std::pair& x) const + { + return hashOne(x.first) ^ (hashOne(x.second) << 1); + } +}; + +struct UnifierSharedState +{ + UnifierSharedState(InternalErrorReporter* iceHandler) + : iceHandler(iceHandler) + { + } + + InternalErrorReporter* iceHandler; + + DenseHashSet seenAny{nullptr}; + DenseHashMap skipCacheForType{nullptr}; + DenseHashSet, TypeIdPairHash> cachedUnify{{nullptr, nullptr}}; + + DenseHashSet tempSeenTy{nullptr}; + DenseHashSet tempSeenTp{nullptr}; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index df0bd4205..a866655c9 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -1,9 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include "Luau/TypeVar.h" #include "Luau/TypePack.h" +LUAU_FASTFLAG(LuauCacheUnifyTableResults) + namespace Luau { @@ -32,17 +35,33 @@ inline bool hasSeen(std::unordered_set& seen, const void* tv) return !seen.insert(ttv).second; } +inline bool hasSeen(DenseHashSet& seen, const void* tv) +{ + void* ttv = const_cast(tv); + + if (seen.contains(ttv)) + return true; + + seen.insert(ttv); + return false; +} + inline void unsee(std::unordered_set& seen, const void* tv) { void* ttv = const_cast(tv); seen.erase(ttv); } -template -void visit(TypePackId tp, F& f, std::unordered_set& seen); +inline void unsee(DenseHashSet& seen, const void* tv) +{ + // When DenseHashSet is used for 'visitOnce', where don't forget visited elements +} + +template +void visit(TypePackId tp, F& f, Set& seen); -template -void visit(TypeId ty, F& f, std::unordered_set& seen) +template +void visit(TypeId ty, F& f, Set& seen) { if (visit_detail::hasSeen(seen, ty)) { @@ -79,15 +98,23 @@ void visit(TypeId ty, F& f, std::unordered_set& seen) else if (auto ttv = get(ty)) { + // Some visitors want to see bound tables, that's why we visit the original type if (apply(ty, *ttv, seen, f)) { - for (auto& [_name, prop] : ttv->props) - visit(prop.type, f, seen); - - if (ttv->indexer) + if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo) { - visit(ttv->indexer->indexType, f, seen); - visit(ttv->indexer->indexResultType, f, seen); + visit(*ttv->boundTo, f, seen); + } + else + { + for (auto& [_name, prop] : ttv->props) + visit(prop.type, f, seen); + + if (ttv->indexer) + { + visit(ttv->indexer->indexType, f, seen); + visit(ttv->indexer->indexResultType, f, seen); + } } } } @@ -140,8 +167,8 @@ void visit(TypeId ty, F& f, std::unordered_set& seen) visit_detail::unsee(seen, ty); } -template -void visit(TypePackId tp, F& f, std::unordered_set& seen) +template +void visit(TypePackId tp, F& f, Set& seen) { if (visit_detail::hasSeen(seen, tp)) { @@ -182,6 +209,7 @@ void visit(TypePackId tp, F& f, std::unordered_set& seen) visit_detail::unsee(seen, tp); } + } // namespace visit_detail template @@ -197,4 +225,11 @@ void visitTypeVar(TID ty, F& f) visit_detail::visit(ty, f, seen); } +template +void visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) +{ + seen.clear(); + visit_detail::visit(ty, f, seen); +} + } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 235abf36f..3c43c8086 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -196,7 +196,8 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) { InternalErrorReporter iceReporter; - Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter); + UnifierSharedState unifierState(&iceReporter); + Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); unifier.tryUnify(expectedType, actualType); @@ -1460,7 +1461,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M result.erase(std::string(stringKey->value.data, stringKey->value.size)); } - // If we know for sure that a key is being written, do not offer general epxression suggestions + // If we know for sure that a key is being written, do not offer general expression suggestions if (!key) autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position, result); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 3b0c21638..f6f2363c6 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -106,18 +106,6 @@ void attachMagicFunction(TypeId ty, MagicFunction fn) LUAU_ASSERT(!"Got a non functional type"); } -void attachFunctionTag(TypeId ty, std::string tag) -{ - if (auto ftv = getMutable(ty)) - { - ftv->tags.emplace_back(std::move(tag)); - } - else - { - LUAU_ASSERT(!"Got a non functional type"); - } -} - Property makeProperty(TypeId ty, std::optional documentationSymbol) { return { diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 92fbffc80..04d91444a 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -158,21 +158,20 @@ struct ErrorConverter std::string operator()(const Luau::CountMismatch& e) const { + const std::string expectedS = e.expected == 1 ? "" : "s"; + const std::string actualS = e.actual == 1 ? "" : "s"; + const std::string actualVerb = e.actual == 1 ? "is" : "are"; + switch (e.context) { case CountMismatch::Return: - { - const std::string expectedS = e.expected == 1 ? "" : "s"; - const std::string actualS = e.actual == 1 ? "is" : "are"; - return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + actualS + - " returned here"; - } + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + + std::to_string(e.actual) + " " + actualVerb + " returned here"; case CountMismatch::Result: - if (e.expected > e.actual) - return "Function returns " + std::to_string(e.expected) + " values but there are only " + std::to_string(e.expected) + - " values to unpack them into."; - else - return "Function only returns " + std::to_string(e.expected) + " values. " + std::to_string(e.actual) + " are required here"; + // It is alright if right hand side produces more values than the + // left hand side accepts. In this context consider only the opposite case. + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: if (FFlag::LuauTypeAliasPacks) return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b2529840b..5e7af50c5 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -23,6 +23,7 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau { @@ -248,7 +249,7 @@ struct RequireCycle // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) // However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true) std::vector getRequireCycles( - const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) + const FileResolver* resolver, const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) { std::vector result; @@ -282,9 +283,9 @@ std::vector getRequireCycles( if (top == start) { for (const SourceNode* node : path) - cycle.push_back(node->name); + cycle.push_back(resolver->getHumanReadableModuleName(node->name)); - cycle.push_back(top->name); + cycle.push_back(resolver->getHumanReadableModuleName(top->name)); break; } } @@ -404,7 +405,7 @@ CheckResult Frontend::check(const ModuleName& name) // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // all correct programs must be acyclic so this code triggers rarely if (cycleDetected) - requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck); + requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck); // This is used by the type checker to replace the resulting type of cyclic modules with any sourceModule.cyclic = !requireCycles.empty(); @@ -458,6 +459,8 @@ CheckResult Frontend::check(const ModuleName& name) module->astTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); + if (FFlag::LuauClearScopes) + module->scopes.resize(1); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index df6be767b..2fd958965 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false) namespace Luau { @@ -299,6 +300,14 @@ void TypeCloner::operator()(const FunctionTypeVar& t) void TypeCloner::operator()(const TableTypeVar& t) { + // If table is now bound to another one, we ignore the content of the original + if (FFlag::LuauCloneBoundTables && t.boundTo) + { + TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + seenTypes[typeId] = boundTo; + return; + } + TypeId result = dest.addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(result); LUAU_ASSERT(ttv != nullptr); @@ -321,8 +330,11 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)}; - if (t.boundTo) - ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + if (!FFlag::LuauCloneBoundTables) + { + if (t.boundTo) + ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + } for (TypeId& arg : ttv->instantiatedTypeParams) arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); @@ -335,7 +347,7 @@ void TypeCloner::operator()(const TableTypeVar& t) if (ttv->state == TableState::Free) { - if (!t.boundTo) + if (FFlag::LuauCloneBoundTables || !t.boundTo) { if (encounteredFreeType) *encounteredFreeType = true; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp new file mode 100644 index 000000000..bf6d81aa6 --- /dev/null +++ b/Analysis/src/Quantify.cpp @@ -0,0 +1,90 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Quantify.h" + +#include "Luau/VisitTypeVar.h" + +namespace Luau +{ + +struct Quantifier +{ + ModulePtr module; + TypeLevel level; + std::vector generics; + std::vector genericPacks; + + Quantifier(ModulePtr module, TypeLevel level) + : module(module) + , level(level) + { + } + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + if (!level.subsumes(ftv.level)) + return false; + + *asMutable(ty) = GenericTypeVar{level}; + generics.push_back(ty); + + return false; + } + + template + bool operator()(TypeId ty, const T& t) + { + return true; + } + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + TableTypeVar& ttv = *getMutable(ty); + + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) + return false; + if (!level.subsumes(ttv.level)) + return false; + + if (ttv.state == TableState::Free) + ttv.state = TableState::Generic; + else if (ttv.state == TableState::Unsealed) + ttv.state = TableState::Sealed; + + ttv.level = level; + + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + if (!level.subsumes(ftp.level)) + return false; + + *asMutable(tp) = GenericTypePack{level}; + genericPacks.push_back(tp); + return true; + } +}; + +void quantify(ModulePtr module, TypeId ty, TypeLevel level) +{ + Quantifier q{std::move(module), level}; + visitTypeVar(ty, q); + + FunctionTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; +} + +} // namespace Luau diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index ad4d5ef43..95910b562 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -171,7 +171,7 @@ struct RequireTracerOld : AstVisitor result.exprs[call] = {fileResolver->concat(*rootName, v)}; - // 'WaitForChild' can be used on modules that are not awailable at the typecheck time, but will be awailable at runtime + // 'WaitForChild' can be used on modules that are not available at the typecheck time, but will be available at runtime // If we fail to find such module, we will not report an UnknownRequire error if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") result.exprs[call].optional = true; @@ -182,7 +182,7 @@ struct RequireTracerOld : AstVisitor struct RequireTracer : AstVisitor { - RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName) + RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName) : result(result) , fileResolver(fileResolver) , currentModuleName(currentModuleName) @@ -260,7 +260,7 @@ struct RequireTracer : AstVisitor // seed worklist with require arguments work.reserve(requires.size()); - for (AstExprCall* require: requires) + for (AstExprCall* require : requires) work.push_back(require->args.data[0]); // push all dependent expressions to the work stack; note that the vector is modified during traversal diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5651af7e9..cd8180dba 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauExtraNilRecovery) LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) LUAU_FASTFLAG(LuauTypeAliasPacks) @@ -159,15 +158,6 @@ struct StringifierState seen.erase(iter); } - static std::string generateName(size_t i) - { - std::string n; - n = char('a' + i % 26); - if (i >= 26) - n += std::to_string(i / 26); - return n; - } - std::string getName(TypeId ty) { const size_t s = result.nameMap.typeVars.size(); @@ -584,8 +574,7 @@ struct TypeVarStringifier std::vector results = {}; for (auto el : &uv) { - if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) - el = follow(el); + el = follow(el); if (isNil(el)) { @@ -649,8 +638,7 @@ struct TypeVarStringifier std::vector results = {}; for (auto el : uv.parts) { - if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) - el = follow(el); + el = follow(el); std::string saved = std::move(state.result.name); @@ -1204,4 +1192,13 @@ void dump(TypePackId ty) printf("%s\n", toString(ty, opts).c_str()); } +std::string generateName(size_t i) +{ + std::string n; + n = char('a' + i % 26); + if (i >= 26) + n += std::to_string(i / 26); + return n; +} + } // namespace Luau diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 2d3563842..dba694be0 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -298,8 +298,15 @@ struct ArcCollector : public AstVisitor struct ContainsFunctionCall : public AstVisitor { + bool alsoReturn = false; bool result = false; + ContainsFunctionCall() = default; + explicit ContainsFunctionCall(bool alsoReturn) + : alsoReturn(alsoReturn) + { + } + bool visit(AstExpr*) override { return !result; // short circuit if result is true @@ -318,6 +325,17 @@ struct ContainsFunctionCall : public AstVisitor return false; } + bool visit(AstStatReturn* stat) override + { + if (alsoReturn) + { + result = true; + return false; + } + else + return AstVisitor::visit(stat); + } + bool visit(AstExprFunction*) override { return false; @@ -479,6 +497,13 @@ bool containsFunctionCall(const AstStat& stat) return cfc.result; } +bool containsFunctionCallOrReturn(const AstStat& stat) +{ + detail::ContainsFunctionCall cfc{true}; + const_cast(stat).visit(&cfc); + return cfc.result; +} + bool isFunction(const AstStat& stat) { return stat.is() || stat.is(); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 702d0ca2c..383bb050d 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false) + namespace Luau { @@ -33,6 +35,12 @@ void TxnLog::rollback() for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) std::swap(it->first->boundTo, it->second); + + if (FFlag::LuauShareTxnSeen) + { + LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); + sharedSeen->resize(originalSeenSize); + } } void TxnLog::concat(TxnLog rhs) @@ -46,27 +54,44 @@ void TxnLog::concat(TxnLog rhs) tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); rhs.tableChanges.clear(); - seen.swap(rhs.seen); - rhs.seen.clear(); + if (!FFlag::LuauShareTxnSeen) + { + ownedSeen.swap(rhs.ownedSeen); + rhs.ownedSeen.clear(); + } } bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair)); + if (FFlag::LuauShareTxnSeen) + return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); + else + return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair)); } void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - seen.push_back(sortedPair); + if (FFlag::LuauShareTxnSeen) + sharedSeen->push_back(sortedPair); + else + ownedSeen.push_back(sortedPair); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - LUAU_ASSERT(sortedPair == seen.back()); - seen.pop_back(); + if (FFlag::LuauShareTxnSeen) + { + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); + } + else + { + LUAU_ASSERT(sortedPair == ownedSeen.back()); + ownedSeen.pop_back(); + } } } // namespace Luau diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 266c19865..49f8e0cac 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -6,6 +6,7 @@ #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" +#include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -33,14 +34,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data return result; } +using SyntheticNames = std::unordered_map; + namespace Luau { + +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen) +{ + size_t s = syntheticNames->size(); + char*& n = (*syntheticNames)[&gen]; + if (!n) + { + std::string str = gen.explicitName ? gen.name : generateName(s); + n = static_cast(allocator->allocate(str.size() + 1)); + strcpy(n, str.c_str()); + } + + return n; +} + class TypeRehydrationVisitor { - mutable std::map seen; - mutable int count = 0; + std::map seen; + int count = 0; - bool hasSeen(const void* tv) const + bool hasSeen(const void* tv) { void* ttv = const_cast(tv); auto it = seen.find(ttv); @@ -52,15 +70,16 @@ class TypeRehydrationVisitor } public: - TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions()) + TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions()) : allocator(alloc) + , syntheticNames(syntheticNames) , options(options) { } - AstTypePack* rehydrate(TypePackId tp) const; + AstTypePack* rehydrate(TypePackId tp); - AstType* operator()(const PrimitiveTypeVar& ptv) const + AstType* operator()(const PrimitiveTypeVar& ptv) { switch (ptv.type) { @@ -78,11 +97,11 @@ class TypeRehydrationVisitor return nullptr; } } - AstType* operator()(const AnyTypeVar&) const + AstType* operator()(const AnyTypeVar&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); } - AstType* operator()(const TableTypeVar& ttv) const + AstType* operator()(const TableTypeVar& ttv) { RecursionCounter counter(&count); @@ -144,12 +163,12 @@ class TypeRehydrationVisitor return allocator->alloc(Location(), props, indexer); } - AstType* operator()(const MetatableTypeVar& mtv) const + AstType* operator()(const MetatableTypeVar& mtv) { return Luau::visit(*this, mtv.table->ty); } - AstType* operator()(const ClassTypeVar& ctv) const + AstType* operator()(const ClassTypeVar& ctv) { RecursionCounter counter(&count); @@ -176,7 +195,7 @@ class TypeRehydrationVisitor return allocator->alloc(Location(), props); } - AstType* operator()(const FunctionTypeVar& ftv) const + AstType* operator()(const FunctionTypeVar& ftv) { RecursionCounter counter(&count); @@ -253,10 +272,12 @@ class TypeRehydrationVisitor size_t i = 0; for (const auto& el : ftv.argNames) { + std::optional* arg = &argNames.data[i++]; + if (el) - argNames.data[i++] = {AstName(el->name.c_str()), el->location}; + new (arg) std::optional(AstArgumentName(AstName(el->name.c_str()), el->location)); else - argNames.data[i++] = {}; + new (arg) std::optional(); } AstArray returnTypes; @@ -290,23 +311,23 @@ class TypeRehydrationVisitor return allocator->alloc( Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); } - AstType* operator()(const Unifiable::Error&) const + AstType* operator()(const Unifiable::Error&) { return allocator->alloc(Location(), std::nullopt, AstName("Unifiable")); } - AstType* operator()(const GenericTypeVar& gtv) const + AstType* operator()(const GenericTypeVar& gtv) { - return allocator->alloc(Location(), std::nullopt, AstName(gtv.name.c_str())); + return allocator->alloc(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv))); } - AstType* operator()(const Unifiable::Bound& bound) const + AstType* operator()(const Unifiable::Bound& bound) { return Luau::visit(*this, bound.boundTo->ty); } - AstType* operator()(Unifiable::Free ftv) const + AstType* operator()(const FreeTypeVar& ftv) { return allocator->alloc(Location(), std::nullopt, AstName("free")); } - AstType* operator()(const UnionTypeVar& uv) const + AstType* operator()(const UnionTypeVar& uv) { AstArray unionTypes; unionTypes.size = uv.options.size(); @@ -317,7 +338,7 @@ class TypeRehydrationVisitor } return allocator->alloc(Location(), unionTypes); } - AstType* operator()(const IntersectionTypeVar& uv) const + AstType* operator()(const IntersectionTypeVar& uv) { AstArray intersectionTypes; intersectionTypes.size = uv.parts.size(); @@ -328,23 +349,28 @@ class TypeRehydrationVisitor } return allocator->alloc(Location(), intersectionTypes); } - AstType* operator()(const LazyTypeVar& ltv) const + AstType* operator()(const LazyTypeVar& ltv) { return allocator->alloc(Location(), std::nullopt, AstName("")); } private: Allocator* allocator; + SyntheticNames* syntheticNames; const TypeRehydrationOptions& options; }; class TypePackRehydrationVisitor { public: - TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor) + TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor) : allocator(allocator) + , syntheticNames(syntheticNames) , typeVisitor(typeVisitor) { + LUAU_ASSERT(allocator); + LUAU_ASSERT(syntheticNames); + LUAU_ASSERT(typeVisitor); } AstTypePack* operator()(const BoundTypePack& btp) const @@ -359,7 +385,7 @@ class TypePackRehydrationVisitor head.data = static_cast(allocator->allocate(sizeof(AstType*) * tp.head.size())); for (size_t i = 0; i < tp.head.size(); i++) - head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty); + head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty); AstTypePack* tail = nullptr; @@ -371,12 +397,12 @@ class TypePackRehydrationVisitor AstTypePack* operator()(const VariadicTypePack& vtp) const { - return allocator->alloc(Location(), Luau::visit(typeVisitor, vtp.ty->ty)); + return allocator->alloc(Location(), Luau::visit(*typeVisitor, vtp.ty->ty)); } AstTypePack* operator()(const GenericTypePack& gtp) const { - return allocator->alloc(Location(), AstName(gtp.name.c_str())); + return allocator->alloc(Location(), AstName(getName(allocator, syntheticNames, gtp))); } AstTypePack* operator()(const FreeTypePack& gtp) const @@ -391,12 +417,13 @@ class TypePackRehydrationVisitor private: Allocator* allocator; - const TypeRehydrationVisitor& typeVisitor; + SyntheticNames* syntheticNames; + TypeRehydrationVisitor* typeVisitor; }; -AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const +AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) { - TypePackRehydrationVisitor tprv(allocator, *this); + TypePackRehydrationVisitor tprv(allocator, syntheticNames, this); return Luau::visit(tprv, tp->ty); } @@ -431,7 +458,7 @@ class TypeAttacher : public AstVisitor { if (!type) return nullptr; - return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty); + return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty); } AstArray typeAstPack(TypePackId type) @@ -443,7 +470,7 @@ class TypeAttacher : public AstVisitor result.data = static_cast(allocator->allocate(sizeof(AstType*) * v.size())); for (size_t i = 0; i < v.size(); ++i) { - result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty); + result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty); } return result; } @@ -495,7 +522,7 @@ class TypeAttacher : public AstVisitor { if (FFlag::LuauTypeAliasPacks) { - variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail); + variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); } else { @@ -515,6 +542,7 @@ class TypeAttacher : public AstVisitor private: Module& module; Allocator* allocator; + SyntheticNames syntheticNames; }; void attachTypeData(SourceModule& source, Module& result) @@ -525,7 +553,8 @@ void attachTypeData(SourceModule& source, Module& result) AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options) { - return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty); + SyntheticNames syntheticNames; + return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty); } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 3a1fdfff5..38e2e5270 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" +#include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Substitution.h" @@ -33,18 +34,16 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false) -LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false) LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false) LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false) LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) -LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) +LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAG(LuauNewRequireTrace) LUAU_FASTFLAG(LuauTypeAliasPacks) @@ -215,6 +214,7 @@ static bool isMetamethod(const Name& name) TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler) : resolver(resolver) , iceHandler(iceHandler) + , unifierState(iceHandler) , nilType(singletonTypes.nilType) , numberType(singletonTypes.numberType) , stringType(singletonTypes.stringType) @@ -370,13 +370,18 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) return; } + int subLevel = 0; + std::vector sorted(block.body.data, block.body.data + block.body.size); toposort(sorted); for (const auto& stat : sorted) { if (const auto& typealias = stat->as()) - check(scope, *typealias, true); + { + check(scope, *typealias, subLevel, true); + ++subLevel; + } } auto protoIter = sorted.begin(); @@ -399,8 +404,6 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } }; - int subLevel = 0; - while (protoIter != sorted.end()) { // protoIter walks forward @@ -416,7 +419,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // ``` // These both call each other, so `f` will be ordered before `g`, so the call to `g` // is typechecked before `g` has had its body checked. For this reason, there's three - // types for each functuion: before its body is checked, during checking its body, + // types for each function: before its body is checked, during checking its body, // and after its body is checked. // // We currently treat the before-type and the during-type as the same, @@ -433,7 +436,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (containsFunctionCall(**protoIter)) + if (FFlag::LuauQuantifyInPlace2 ? containsFunctionCallOrReturn(**protoIter) : containsFunctionCall(**protoIter)) { while (checkIter != protoIter) { @@ -1080,7 +1083,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); // If in nonstrict mode and allowing redefinition of global function, restore the previous definition type - // in case this function has a differing signature. The signature discrepency will be caught in checkBlock. + // in case this function has a differing signature. The signature discrepancy will be caught in checkBlock. if (previouslyDefined) globalBindings[name] = oldBinding; else @@ -1161,7 +1164,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location}; } -void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare) +void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare) { // This function should be called at most twice for each type alias. // Once with forwardDeclare, and once without. @@ -1189,11 +1192,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { - ScopePtr aliasScope = childScope(scope, typealias.location); + ScopePtr aliasScope = + FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); if (FFlag::LuauTypeAliasPacks) { - auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks); + auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); FreeTypeVar* ftv = getMutable(ty); @@ -1418,7 +1422,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo { ScopePtr funScope = childFunctionScope(scope, global.location); - auto [generics, genericPacks] = createGenericTypes(funScope, global, global.generics, global.genericPacks); + auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks); TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); @@ -1610,25 +1614,11 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn if (std::optional ty = resolveLValue(scope, *lvalue)) return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; - if (FFlag::LuauExtraNilRecovery) - lhsType = stripFromNilAndReport(lhsType, expr.expr->location); + lhsType = stripFromNilAndReport(lhsType, expr.expr->location); if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) return {*ty}; - if (!FFlag::LuauMissingUnionPropertyError) - reportError(expr.indexLocation, UnknownProperty{lhsType, expr.index.value}); - - if (!FFlag::LuauExtraNilRecovery) - { - // Try to recover using a union without 'nil' options - if (std::optional strippedUnion = tryStripUnionFromNil(lhsType)) - { - if (std::optional ty = getIndexTypeFromType(scope, *strippedUnion, name, expr.location, false)) - return {*ty}; - } - } - return {errorType}; } @@ -1694,61 +1684,37 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (const UnionTypeVar* utv = get(type)) { - if (FFlag::LuauMissingUnionPropertyError) - { - std::vector goodOptions; - std::vector badOptions; - - for (TypeId t : utv) - { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + std::vector goodOptions; + std::vector badOptions; - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) - goodOptions.push_back(*ty); - else - badOptions.push_back(t); - } - - if (!badOptions.empty()) - { - if (addErrors) - { - if (goodOptions.empty()) - reportError(location, UnknownProperty{type, name}); - else - reportError(location, MissingUnionProperty{type, badOptions, name}); - } - return std::nullopt; - } - - std::vector result = reduceUnion(goodOptions); - - if (result.size() == 1) - return result[0]; + for (TypeId t : utv) + { + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - return addType(UnionTypeVar{std::move(result)}); + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + goodOptions.push_back(*ty); + else + badOptions.push_back(t); } - else - { - std::vector options; - for (TypeId t : utv->options) + if (!badOptions.empty()) + { + if (addErrors) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) - options.push_back(*ty); + if (goodOptions.empty()) + reportError(location, UnknownProperty{type, name}); else - return std::nullopt; + reportError(location, MissingUnionProperty{type, badOptions, name}); } + return std::nullopt; + } - std::vector result = reduceUnion(options); + std::vector result = reduceUnion(goodOptions); - if (result.size() == 1) - return result[0]; + if (result.size() == 1) + return result[0]; - return addType(UnionTypeVar{std::move(result)}); - } + return addType(UnionTypeVar{std::move(result)}); } else if (const IntersectionTypeVar* itv = get(type)) { @@ -1765,7 +1731,7 @@ std::optional TypeChecker::getIndexTypeFromType( // If no parts of the intersection had the property we looked up for, it never existed at all. if (parts.empty()) { - if (FFlag::LuauMissingUnionPropertyError && addErrors) + if (addErrors) reportError(location, UnknownProperty{type, name}); return std::nullopt; } @@ -1779,7 +1745,7 @@ std::optional TypeChecker::getIndexTypeFromType( return addType(IntersectionTypeVar{result}); } - if (FFlag::LuauMissingUnionPropertyError && addErrors) + if (addErrors) reportError(location, UnknownProperty{type, name}); return std::nullopt; @@ -2062,8 +2028,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn case AstExprUnary::Len: tablify(operandType); - if (FFlag::LuauExtraNilRecovery) - operandType = stripFromNilAndReport(operandType, expr.location); + operandType = stripFromNilAndReport(operandType, expr.location); if (get(operandType)) return {errorType}; @@ -2635,8 +2600,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope Name name = expr.index.value; - if (FFlag::LuauExtraNilRecovery) - lhs = stripFromNilAndReport(lhs, expr.expr->location); + lhs = stripFromNilAndReport(lhs, expr.expr->location); if (TableTypeVar* lhsTable = getMutableTableType(lhs)) { @@ -2710,8 +2674,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope TypeId exprType = checkExpr(scope, *expr.expr).type; tablify(exprType); - if (FFlag::LuauExtraNilRecovery) - exprType = stripFromNilAndReport(exprType, expr.expr->location); + exprType = stripFromNilAndReport(exprType, expr.expr->location); TypeId indexType = checkExpr(scope, *expr.index).type; @@ -2738,10 +2701,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { - if (FFlag::LuauExtraNilRecovery) - reportError(TypeError{expr.expr->location, NotATable{exprType}}); - else - reportError(TypeError{expr.location, NotATable{exprType}}); + reportError(TypeError{expr.expr->location, NotATable{exprType}}); return std::pair(errorType, nullptr); } @@ -2910,7 +2870,7 @@ std::pair TypeChecker::checkFunctionSignature( if (FFlag::LuauGenericFunctions) { - std::tie(generics, genericPacks) = createGenericTypes(funScope, expr, expr.generics, expr.genericPacks); + std::tie(generics, genericPacks) = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); } TypePackId retPack; @@ -3016,9 +2976,6 @@ std::pair TypeChecker::checkFunctionSignature( if (expectedArgsCurr != expectedArgsEnd) { argType = *expectedArgsCurr; - - if (!FFlag::LuauInferFunctionArgsFix) - ++expectedArgsCurr; } else if (auto expectedArgsTail = expectedArgsCurr.tail()) { @@ -3034,7 +2991,7 @@ std::pair TypeChecker::checkFunctionSignature( funScope->bindings[local] = {argType, local->location}; argTypes.push_back(argType); - if (FFlag::LuauInferFunctionArgsFix && expectedArgsCurr != expectedArgsEnd) + if (expectedArgsCurr != expectedArgsEnd) ++expectedArgsCurr; } @@ -3170,7 +3127,7 @@ void TypeChecker::checkArgumentList( const ScopePtr& scope, Unifier& state, TypePackId argPack, TypePackId paramPack, const std::vector& argLocations) { /* Important terminology refresher: - * A function requires paramaters. + * A function requires parameters. * To call a function, you supply arguments. */ TypePackIterator argIter = begin(argPack); @@ -3402,8 +3359,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A if (!FFlag::LuauRankNTypes) instantiate(scope, selfType, expr.func->location); - if (FFlag::LuauExtraNilRecovery) - selfType = stripFromNilAndReport(selfType, expr.func->location); + selfType = stripFromNilAndReport(selfType, expr.func->location); if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) { @@ -3412,34 +3368,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } else { - if (!FFlag::LuauMissingUnionPropertyError) - reportError(indexExpr->indexLocation, UnknownProperty{selfType, indexExpr->index.value}); - - if (!FFlag::LuauExtraNilRecovery) - { - // Try to recover using a union without 'nil' options - if (std::optional strippedUnion = tryStripUnionFromNil(selfType)) - { - if (std::optional propTy = getIndexTypeFromType(scope, *strippedUnion, indexExpr->index.value, expr.location, false)) - { - selfType = *strippedUnion; - - functionType = *propTy; - actualFunctionType = instantiate(scope, functionType, expr.func->location); - } - } - - if (!actualFunctionType) - { - functionType = errorType; - actualFunctionType = errorType; - } - } - else - { - functionType = errorType; - actualFunctionType = errorType; - } + functionType = errorType; + actualFunctionType = errorType; } } else @@ -3555,8 +3485,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& errors) { - if (FFlag::LuauExtraNilRecovery) - fn = stripFromNilAndReport(fn, expr.func->location); + fn = stripFromNilAndReport(fn, expr.func->location); if (get(fn)) { @@ -4283,6 +4212,12 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) return ty; + if (FFlag::LuauQuantifyInPlace2) + { + Luau::quantify(currentModule, ty, scope->level); + return ty; + } + quantification.level = scope->level; quantification.generics.clear(); quantification.genericPacks.clear(); @@ -4491,12 +4426,12 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) Unifier TypeChecker::mkUnifier(const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, iceHandler}; + return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState}; } Unifier TypeChecker::mkUnifier(const std::vector>& seen, const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, iceHandler}; + return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, unifierState}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -4753,7 +4688,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (FFlag::LuauGenericFunctions) { - std::tie(generics, genericPacks) = createGenericTypes(funcScope, annotation, func->generics, func->genericPacks); + std::tie(generics, genericPacks) = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); } // TODO: better error message CLI-39912 @@ -5041,10 +4976,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } std::pair, std::vector> TypeChecker::createGenericTypes( - const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) + const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); + const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; + std::vector generics; for (const AstName& generic : genericNames) { @@ -5063,12 +5000,12 @@ std::pair, std::vector> TypeChecker::createGener { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) - cached = addType(GenericTypeVar{scope->level, n}); + cached = addType(GenericTypeVar{level, n}); g = cached; } else { - g = addType(Unifiable::Generic{scope->level, n}); + g = addType(Unifiable::Generic{level, n}); } generics.push_back(g); @@ -5093,12 +5030,12 @@ std::pair, std::vector> TypeChecker::createGener { TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); g = cached; } else { - g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); } genericPacks.push_back(g); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index e963fc74e..e82f7519d 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -22,6 +22,7 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) namespace Luau { @@ -217,8 +218,7 @@ std::optional getMetatable(TypeId type) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) return classType->metatable; - else if (const PrimitiveTypeVar* primitiveType = get(type); - primitiveType && primitiveType->metatable) + else if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) { LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); return primitiveType->metatable; @@ -1490,4 +1490,86 @@ std::vector filterMap(TypeId type, TypeIdPredicate predicate) return {}; } +static Tags* getTags(TypeId ty) +{ + ty = follow(ty); + + if (auto ftv = getMutable(ty)) + return &ftv->tags; + else if (auto ttv = getMutable(ty)) + return &ttv->tags; + else if (auto ctv = getMutable(ty)) + return &ctv->tags; + + return nullptr; +} + +void attachTag(TypeId ty, const std::string& tagName) +{ + if (!FFlag::LuauRefactorTagging) + { + if (auto ftv = getMutable(ty)) + { + ftv->tags.emplace_back(tagName); + } + else + { + LUAU_ASSERT(!"Got a non functional type"); + } + } + else + { + if (auto tags = getTags(ty)) + tags->push_back(tagName); + else + LUAU_ASSERT(!"This TypeId does not support tags"); + } +} + +void attachTag(Property& prop, const std::string& tagName) +{ + LUAU_ASSERT(FFlag::LuauRefactorTagging); + + prop.tags.push_back(tagName); +} + +// We would ideally not expose this because it could cause a footgun. +// If the Base class has a tag and you ask if Derived has that tag, it would return false. +// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it. +bool hasTag(const Tags& tags, const std::string& tagName) +{ + LUAU_ASSERT(FFlag::LuauRefactorTagging); + return std::find(tags.begin(), tags.end(), tagName) != tags.end(); +} + +bool hasTag(TypeId ty, const std::string& tagName) +{ + ty = follow(ty); + + // We special case classes because getTags only returns a pointer to one vector of tags. + // But classes has multiple vector of tags, represented throughout the hierarchy. + if (auto ctv = get(ty)) + { + while (ctv) + { + if (hasTag(ctv->tags, tagName)) + return true; + else if (!ctv->parent) + return false; + + ctv = get(*ctv->parent); + LUAU_ASSERT(ctv); + } + } + else if (auto tags = getTags(ty)) + return hasTag(*tags, tagName); + + return false; +} + +bool hasTag(const Property& prop, const std::string& tagName) +{ + return hasTag(prop.tags, tagName); +} + } // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 117cbc289..2539650a4 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -7,6 +7,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/VisitTypeVar.h" #include @@ -22,9 +23,99 @@ LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) +LUAU_FASTFLAG(LuauShareTxnSeen); +LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) namespace Luau { +struct SkipCacheForType +{ + SkipCacheForType(const DenseHashMap& skipCacheForType) + : skipCacheForType(skipCacheForType) + { + } + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const BoundTypeVar& btv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const GenericTypeVar& btv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + TableTypeVar& ttv = *getMutable(ty); + + if (ttv.boundTo) + { + result = true; + return false; + } + + if (ttv.state != TableState::Sealed) + { + result = true; + return false; + } + + return true; + } + + template + bool operator()(TypeId ty, const T& t) + { + const bool* prev = skipCacheForType.find(ty); + + if (prev && *prev) + { + result = true; + return false; + } + + return true; + } + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + result = true; + return false; + } + + bool operator()(TypePackId tp, const BoundTypePack& ftp) + { + result = true; + return false; + } + + bool operator()(TypePackId tp, const GenericTypePack& ftp) + { + result = true; + return false; + } + + const DenseHashMap& skipCacheForType; + bool result = false; +}; static std::optional hasUnificationTooComplex(const ErrorVec& errors) { @@ -39,7 +130,7 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) , globalScope(std::move(globalScope)) @@ -47,24 +138,39 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , variance(variance) , counters(&countersData) , counters_DEPRECATED(std::make_shared()) - , iceHandler(iceHandler) + , sharedState(sharedState) +{ + LUAU_ASSERT(sharedState.iceHandler); +} + +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) + : types(types) + , mode(mode) + , globalScope(std::move(globalScope)) + , log(ownedSeen) + , location(location) + , variance(variance) + , counters(counters ? counters : &countersData) + , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) + , sharedState(sharedState) { - LUAU_ASSERT(iceHandler); + LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) : types(types) , mode(mode) , globalScope(std::move(globalScope)) - , log(seen) + , log(sharedSeen) , location(location) , variance(variance) , counters(counters ? counters : &countersData) , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) - , iceHandler(iceHandler) + , sharedState(sharedState) { - LUAU_ASSERT(iceHandler); + LUAU_ASSERT(sharedState.iceHandler); } void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) @@ -74,7 +180,7 @@ void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool i else counters_DEPRECATED->iterationCount = 0; - return tryUnify_(superTy, subTy, isFunctionCall, isIntersection); + tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) @@ -206,6 +312,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (get(subTy) || get(subTy)) return tryUnifyWithAny(subTy, superTy); + bool cacheEnabled = FFlag::LuauCacheUnifyTableResults && !isFunctionCall && !isIntersection; + auto& cache = sharedState.cachedUnify; + + // What if the types are immutable and we proved their relation before + if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) + return; + // If we have seen this pair of types before, we are currently recursing into cyclic types. // Here, we assume that the types unify. If they do not, we will find out as we roll back // the stack. @@ -257,6 +370,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (FFlag::LuauUnionHeuristic) { + bool found = false; + const std::string* subName = getName(subTy); if (subName) { @@ -264,6 +379,21 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { const std::string* optionName = getName(uv->options[i]); if (optionName && *optionName == *subName) + { + found = true; + startIndex = i; + break; + } + } + } + + if (!found && cacheEnabled) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[i]; + + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) { startIndex = i; break; @@ -311,8 +441,25 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool bool found = false; std::optional unificationTooComplex; - for (TypeId type : uv->parts) + size_t startIndex = 0; + + if (cacheEnabled) { + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[i]; + + if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) + { + startIndex = i; + break; + } + } + } + + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; Unifier innerState = makeChildUnifier(); innerState.tryUnify_(superTy, type, isFunctionCall); @@ -342,8 +489,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool tryUnifyFunctions(superTy, subTy, isFunctionCall); else if (get(superTy) && get(subTy)) + { tryUnifyTables(superTy, subTy, isIntersection); + if (cacheEnabled && errors.empty()) + cacheResult(superTy, subTy); + } + // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. else if (get(superTy)) tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false); @@ -364,6 +516,41 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool log.popSeen(superTy, subTy); } +void Unifier::cacheResult(TypeId superTy, TypeId subTy) +{ + LUAU_ASSERT(FFlag::LuauCacheUnifyTableResults); + + bool* superTyInfo = sharedState.skipCacheForType.find(superTy); + + if (superTyInfo && *superTyInfo) + return; + + bool* subTyInfo = sharedState.skipCacheForType.find(subTy); + + if (subTyInfo && *subTyInfo) + return; + + auto skipCacheFor = [this](TypeId ty) { + SkipCacheForType visitor{sharedState.skipCacheForType}; + visitTypeVarOnce(ty, visitor, sharedState.seenAny); + + sharedState.skipCacheForType[ty] = visitor.result; + + return visitor.result; + }; + + if (!superTyInfo && skipCacheFor(superTy)) + return; + + if (!subTyInfo && skipCacheFor(subTy)) + return; + + sharedState.cachedUnify.insert({superTy, subTy}); + + if (variance == Invariant) + sharedState.cachedUnify.insert({subTy, superTy}); +} + struct WeirdIter { TypePackId packId; @@ -459,7 +646,7 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall else counters_DEPRECATED->iterationCount = 0; - return tryUnify_(superTp, subTp, isFunctionCall); + tryUnify_(superTp, subTp, isFunctionCall); } /* @@ -650,11 +837,11 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal } // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking a return value, we swap these to produce - // the expected error message. + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. size_t expectedSize = size(superTp); size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result || ctx == CountMismatch::Return) + if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); @@ -797,6 +984,40 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) std::vector missingProperties; std::vector extraProperties; + // Optimization: First test that the property sets are compatible without doing any recursive unification + if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer && rt->state != TableState::Free) + { + for (const auto& [propName, superProp] : lt->props) + { + auto subIter = rt->props.find(propName); + if (subIter == rt->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) + missingProperties.push_back(propName); + } + + if (!missingProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + return; + } + } + + // And vice versa if we're invariant + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free) + { + for (const auto& [propName, subProp] : rt->props) + { + auto superIter = lt->props.find(propName); + if (superIter == lt->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) + extraProperties.push_back(propName); + } + + if (!extraProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + return; + } + } + // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. @@ -833,9 +1054,10 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) innerState.log.rollback(); } else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 - // TODO: should isOptional(anyType) be true? - {} + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + { + } else if (rt->state == TableState::Free) { log(rt); @@ -878,11 +1100,13 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) lt->props[name] = clone; } else if (variance == Covariant) - {} + { + } else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 - // TODO: should isOptional(anyType) be true? - {} + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + { + } else if (lt->state == TableState::Free) { log(lt); @@ -980,10 +1204,10 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see TableTypeVar* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) prop.type = deeplyOptional(prop.type, seen); - return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});; + return types->addType(UnionTypeVar{{singletonTypes.nilType, result}}); } else - return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }}); + return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}}); } void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) @@ -1247,7 +1471,7 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio // If the superTy/left is an immediate part of an intersection type, do not do extra-property check. // Otherwise, we would falsely generate an extra-property-error for 's' in this code: // local a: {n: number} & {s: string} = {n=1, s=""} - // When checking agaist the table '{n: number}'. + // When checking against the table '{n: number}'. if (!isIntersection && lt->state != TableState::Unsealed && !lt->indexer) { // Check for extra properties in the subTy @@ -1697,10 +1921,20 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { std::vector queue = {ty}; - tempSeenTy.clear(); - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); + + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); + } + else + { + tempSeenTy_DEPRECATED.clear(); + tempSeenTp_DEPRECATED.clear(); - Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, singletonTypes.anyType, anyTP); + } } else { @@ -1721,12 +1955,24 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { std::vector queue; - tempSeenTy.clear(); - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - queueTypePack(queue, tempSeenTp, *this, ty, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); - Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); + } + else + { + tempSeenTy_DEPRECATED.clear(); + tempSeenTp_DEPRECATED.clear(); + + queueTypePack(queue, tempSeenTp_DEPRECATED, *this, ty, any); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, anyTy, any); + } } else { @@ -1775,10 +2021,20 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack) { std::unordered_set seen_DEPRECATED; - if (FFlag::LuauTypecheckOpts) - tempSeenTy.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + if (FFlag::LuauTypecheckOpts) + sharedState.tempSeenTy.clear(); + + return occursCheck(seen_DEPRECATED, sharedState.tempSeenTy, needle, haystack); + } + else + { + if (FFlag::LuauTypecheckOpts) + tempSeenTy_DEPRECATED.clear(); - return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack); + return occursCheck(seen_DEPRECATED, tempSeenTy_DEPRECATED, needle, haystack); + } } void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) @@ -1851,10 +2107,20 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { std::unordered_set seen_DEPRECATED; - if (FFlag::LuauTypecheckOpts) - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + if (FFlag::LuauTypecheckOpts) + sharedState.tempSeenTp.clear(); - return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack); + return occursCheck(seen_DEPRECATED, sharedState.tempSeenTp, needle, haystack); + } + else + { + if (FFlag::LuauTypecheckOpts) + tempSeenTp_DEPRECATED.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTp_DEPRECATED, needle, haystack); + } } void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) @@ -1922,7 +2188,10 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters}; + if (FFlag::LuauShareTxnSeen) + return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; + else + return Unifier{types, mode, globalScope, log.ownedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; } bool Unifier::isNonstrictMode() const @@ -1940,12 +2209,12 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId void Unifier::ice(const std::string& message, const Location& location) { - iceHandler->ice(message, location); + sharedState.iceHandler->ice(message, location); } void Unifier::ice(const std::string& message) { - iceHandler->ice(message); + sharedState.iceHandler->ice(message); } } // namespace Luau diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 641dfd3c3..503eca61d 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -194,20 +194,20 @@ LUAU_NOINLINE std::pair createScopeDa } // namespace Luau // Regular scope -#define LUAU_TIMETRACE_SCOPE(name, category) \ +#define LUAU_TIMETRACE_SCOPE(name, category) \ static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) // A scope without nested scopes that may be skipped if the time it took is less than the threshold -#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) // Extra key/value data can be added to regular scopes -#define LUAU_TIMETRACE_ARGUMENT(name, value) \ - do \ - { \ - if (FFlag::DebugLuauTimeTracing) \ +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + if (FFlag::DebugLuauTimeTracing) \ lttScopeStatic.second.eventArgument(name, value); \ } while (false) @@ -216,8 +216,8 @@ LUAU_NOINLINE std::pair createScopeDa #define LUAU_TIMETRACE_SCOPE(name, category) #define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) #define LUAU_TIMETRACE_ARGUMENT(name, value) \ - do \ - { \ + do \ + { \ } while (false) #endif diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 40026d8be..846bc0ba9 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1593,7 +1593,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { Location location = lexer.current().location; - // For a missing type annoation, capture 'space' between last token and the next one + // For a missing type annotation, capture 'space' between last token and the next one location = Location(lexer.previousLocation().end, lexer.current().location.begin); return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index e6aab20e5..ded50e53e 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -77,7 +77,10 @@ struct GlobalContext // Ideally we would want all ThreadContext destructors to run // But in VS, not all thread_local object instances are destroyed for (ThreadContext* context : threads) - context->flushEvents(); + { + if (!context->events.empty()) + context->flushEvents(); + } if (traceFile) fclose(traceFile); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 6baa21ea6..4968d0800 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -169,7 +169,17 @@ static std::string runCode(lua_State* L, const std::string& source) } else { - std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(T, -1); + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(T, -1)) + { + error = str; + } + error += "\nstack backtrace:\n"; error += lua_debugtrace(T); @@ -322,7 +332,17 @@ static bool runFile(const char* name, lua_State* GL) } else { - std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(L, -1); + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(L, -1)) + { + error = str; + } + error += "\nstacktrace:\n"; error += lua_debugtrace(L); diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 07be2e749..4b03ed1c7 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -208,14 +208,14 @@ enum LuauOpcode LOP_MODK, LOP_POWK, - // AND, OR: perform `and` or `or` operation (selecting first or second register based on whether the first one is truthful) and put the result into target register + // AND, OR: perform `and` or `or` operation (selecting first or second register based on whether the first one is truthy) and put the result into target register // A: target register // B: source register 1 // C: source register 2 LOP_AND, LOP_OR, - // ANDK, ORK: perform `and` or `or` operation (selecting source register or constant based on whether the source register is truthful) and put the result into target register + // ANDK, ORK: perform `and` or `or` operation (selecting source register or constant based on whether the source register is truthy) and put the result into target register // A: target register // B: source register // C: constant table index (0..255) diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 797ee20dc..7750a1d9f 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -718,9 +718,9 @@ struct Compiler } // compile expr to target temp register - // if the expr (or not expr if onlyTruth is false) is truthful, jump via skipJump - // if the expr (or not expr if onlyTruth is false) is falseful, fall through (target isn't guaranteed to be updated in this case) - // if target is omitted, then the jump behavior is the same - skipJump or fallthrough depending on the truthfulness of the expression + // if the expr (or not expr if onlyTruth is false) is truthy, jump via skipJump + // if the expr (or not expr if onlyTruth is false) is falsy, fall through (target isn't guaranteed to be updated in this case) + // if target is omitted, then the jump behavior is the same - skipJump or fallthrough depending on the truthiness of the expression void compileConditionValue(AstExpr* node, const uint8_t* target, std::vector& skipJump, bool onlyTruth) { // Optimization: we don't need to compute constant values @@ -728,7 +728,7 @@ struct Compiler if (cv && cv->type != Constant::Type_Unknown) { - // note that we only need to compute the value if it's truthful; otherwise we cal fall through + // note that we only need to compute the value if it's truthy; otherwise we cal fall through if (cv->isTruthful() == onlyTruth) { if (target) @@ -747,7 +747,7 @@ struct Compiler case AstExprBinary::And: case AstExprBinary::Or: { - // disambiguation: there's 4 cases (we only need truthful or falseful results based on onlyTruth) + // disambiguation: there's 4 cases (we only need truthy or falsy results based on onlyTruth) // onlyTruth = 1: a and b transforms to a ? b : dontcare // onlyTruth = 1: a or b transforms to a ? a : a // onlyTruth = 0: a and b transforms to !a ? a : b @@ -791,8 +791,8 @@ struct Compiler if (target) { // since target is a temp register, we'll initialize it to 1, and then jump if the comparison is true - // if the comparison is false, we'll fallthrough and target will still be 1 but target has unspecified value for falseful results - // when we only care about falseful values instead of truthful values, the process is the same but with flipped conditionals + // if the comparison is false, we'll fallthrough and target will still be 1 but target has unspecified value for falsy results + // when we only care about falsy values instead of truthy values, the process is the same but with flipped conditionals bytecode.emitABC(LOP_LOADB, *target, onlyTruth ? 1 : 0, 0); } diff --git a/Makefile b/Makefile index 0056870b6..7788251d8 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +.SUFFIXES: MAKEFLAGS+=-r -j8 COMMA=, @@ -107,6 +108,7 @@ coverage: $(TESTS_TARGET) rm default.profraw default-flags.profraw llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests llvm-cov report -ignore-filename-regex=\(tests\|extern\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests + llvm-cov export -format lcov --instr-profile default.profdata build/coverage/luau-tests >coverage.info format: find . -name '*.h' -or -name '*.cpp' | xargs clang-format -i diff --git a/Sources.cmake b/Sources.cmake index 83ed52301..c30cf77d9 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -46,6 +46,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Module.h Analysis/include/Luau/ModuleResolver.h Analysis/include/Luau/Predicate.h + Analysis/include/Luau/Quantify.h Analysis/include/Luau/RecursionCounter.h Analysis/include/Luau/RequireTracer.h Analysis/include/Luau/Scope.h @@ -63,6 +64,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeVar.h Analysis/include/Luau/Unifiable.h Analysis/include/Luau/Unifier.h + Analysis/include/Luau/UnifierSharedState.h Analysis/include/Luau/Variant.h Analysis/include/Luau/VisitTypeVar.h @@ -77,6 +79,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Linter.cpp Analysis/src/Module.cpp Analysis/src/Predicate.cpp + Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp Analysis/src/Substitution.cpp diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 7a09ae9f8..30cffaff8 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -76,7 +76,7 @@ struct luaL_Buffer char buffer[LUA_BUFFERSIZE]; }; -// when internal buffer storage is exhaused, a mutable string value 'storage' will be placed on the stack +// when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack // in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) // with the exception of luaL_addvalue that expects the value at the top and string buffer further away (top-2) // functions that accept a 'boxloc' support string buffer placement at any location in the stack diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 013153605..f2e97c669 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -1153,7 +1151,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) luaC_checkGC(L); luaC_checkthreadsleep(L); Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); - memcpy(u->data + sz, &dtor, sizeof(dtor)); + memcpy(&u->data + sz, &dtor, sizeof(dtor)); setuvalue(L, L->top, u); api_incr_top(L); return u->data; diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index e37618f7b..2a684ee4e 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -313,7 +313,7 @@ static size_t getnextbuffersize(lua_State* L, size_t currentsize, size_t desired { size_t newsize = currentsize + currentsize / 2; - // check for size oveflow + // check for size overflow if (SIZE_MAX - desiredsize < currentsize) luaL_error(L, "buffer too large"); diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index c0b50b969..9724c0e72 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,8 +5,6 @@ #include "lstate.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false) - #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -17,7 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false) static const char* const statnames[] = {"running", "suspended", "normal", "dead"}; -static int costatus(lua_State* L, lua_State* co) +static int auxstatus(lua_State* L, lua_State* co) { if (co == L) return CO_RUN; @@ -25,7 +23,7 @@ static int costatus(lua_State* L, lua_State* co) return CO_SUS; if (co->status == LUA_BREAK) return CO_NOR; - if (co->status != 0) /* some error occured */ + if (co->status != 0) /* some error occurred */ return CO_DEAD; if (co->ci != co->base_ci) /* does it have frames? */ return CO_NOR; @@ -34,11 +32,11 @@ static int costatus(lua_State* L, lua_State* co) return CO_SUS; /* initial state */ } -static int luaB_costatus(lua_State* L) +static int costatus(lua_State* L) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); - lua_pushstring(L, statnames[costatus(L, co)]); + lua_pushstring(L, statnames[auxstatus(L, co)]); return 1; } @@ -47,7 +45,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg) // error handling for edge cases if (co->status != LUA_YIELD) { - int status = costatus(L, co); + int status = auxstatus(L, co); if (status != CO_SUS) { lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]); @@ -115,7 +113,7 @@ static int auxresumecont(lua_State* L, lua_State* co) } } -static int luaB_coresumefinish(lua_State* L, int r) +static int coresumefinish(lua_State* L, int r) { if (r < 0) { @@ -131,7 +129,7 @@ static int luaB_coresumefinish(lua_State* L, int r) } } -static int luaB_coresumey(lua_State* L) +static int coresumey(lua_State* L) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); @@ -141,10 +139,10 @@ static int luaB_coresumey(lua_State* L) if (r == CO_STATUS_BREAK) return interruptThread(L, co); - return luaB_coresumefinish(L, r); + return coresumefinish(L, r); } -static int luaB_coresumecont(lua_State* L, int status) +static int coresumecont(lua_State* L, int status) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); @@ -155,10 +153,10 @@ static int luaB_coresumecont(lua_State* L, int status) int r = auxresumecont(L, co); - return luaB_coresumefinish(L, r); + return coresumefinish(L, r); } -static int luaB_auxwrapfinish(lua_State* L, int r) +static int auxwrapfinish(lua_State* L, int r) { if (r < 0) { @@ -173,7 +171,7 @@ static int luaB_auxwrapfinish(lua_State* L, int r) return r; } -static int luaB_auxwrapy(lua_State* L) +static int auxwrapy(lua_State* L) { lua_State* co = lua_tothread(L, lua_upvalueindex(1)); int narg = cast_int(L->top - L->base); @@ -182,10 +180,10 @@ static int luaB_auxwrapy(lua_State* L) if (r == CO_STATUS_BREAK) return interruptThread(L, co); - return luaB_auxwrapfinish(L, r); + return auxwrapfinish(L, r); } -static int luaB_auxwrapcont(lua_State* L, int status) +static int auxwrapcont(lua_State* L, int status) { lua_State* co = lua_tothread(L, lua_upvalueindex(1)); @@ -195,62 +193,52 @@ static int luaB_auxwrapcont(lua_State* L, int status) int r = auxresumecont(L, co); - return luaB_auxwrapfinish(L, r); + return auxwrapfinish(L, r); } -static int luaB_cocreate(lua_State* L) +static int cocreate(lua_State* L) { luaL_checktype(L, 1, LUA_TFUNCTION); lua_State* NL = lua_newthread(L); - - if (FFlag::LuauPreferXpush) - { - lua_xpush(L, NL, 1); // push function on top of NL - } - else - { - lua_pushvalue(L, 1); /* move function to top */ - lua_xmove(L, NL, 1); /* move function from L to NL */ - } - + lua_xpush(L, NL, 1); // push function on top of NL return 1; } -static int luaB_cowrap(lua_State* L) +static int cowrap(lua_State* L) { - luaB_cocreate(L); + cocreate(L); - lua_pushcfunction(L, luaB_auxwrapy, NULL, 1, luaB_auxwrapcont); + lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont); return 1; } -static int luaB_yield(lua_State* L) +static int coyield(lua_State* L) { int nres = cast_int(L->top - L->base); return lua_yield(L, nres); } -static int luaB_corunning(lua_State* L) +static int corunning(lua_State* L) { if (lua_pushthread(L)) lua_pushnil(L); /* main thread is not a coroutine */ return 1; } -static int luaB_yieldable(lua_State* L) +static int coyieldable(lua_State* L) { lua_pushboolean(L, lua_isyieldable(L)); return 1; } static const luaL_Reg co_funcs[] = { - {"create", luaB_cocreate}, - {"running", luaB_corunning}, - {"status", luaB_costatus}, - {"wrap", luaB_cowrap}, - {"yield", luaB_yield}, - {"isyieldable", luaB_yieldable}, + {"create", cocreate}, + {"running", corunning}, + {"status", costatus}, + {"wrap", cowrap}, + {"yield", coyield}, + {"isyieldable", coyieldable}, {NULL, NULL}, }; @@ -258,7 +246,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); - lua_pushcfunction(L, luaB_coresumey, "resume", 0, luaB_coresumecont); + lua_pushcfunction(L, coresumey, "resume", 0, coresumecont); lua_setfield(L, -2, "resume"); return 1; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 39a615978..328b47e69 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -18,6 +18,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) +LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) /* ** {====================================================== @@ -536,7 +537,13 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e status = LUA_ERRERR; } - // an error occured, check if we have a protected error callback + if (FFlag::LuauCcallRestoreFix) + { + // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + L->nCcalls = oldnCcalls; + } + + // an error occurred, check if we have a protected error callback if (L->global->cb.debugprotectederror) { L->global->cb.debugprotectederror(L); @@ -549,7 +556,10 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); /* close eventual pending closures */ seterrorobj(L, status, oldtop); - L->nCcalls = oldnCcalls; + if (!FFlag::LuauCcallRestoreFix) + { + L->nCcalls = oldnCcalls; + } L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 4fe1c3418..72807f0f3 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -37,7 +37,7 @@ /* results from luaD_precall */ #define PCRLUA 0 /* initiated a call to a Lua function */ #define PCRC 1 /* did a call to a C function */ -#define PCRYIELD 2 /* C funtion yielded */ +#define PCRYIELD 2 /* C function yielded */ /* type of protected functions, to be ran by `runprotected' */ typedef void (*Pfunc)(lua_State* L, void* ud); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 0b5430269..648785697 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -76,7 +76,7 @@ UpVal* luaF_findupval(lua_State* L, StkId level) if (p->v == level) { /* found a corresponding upvalue? */ if (isdead(g, obj2gco(p))) /* is it dead? */ - changewhite(obj2gco(p)); /* ressurect it */ + changewhite(obj2gco(p)); /* resurrect it */ return p; } pp = &p->next; diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 510a9f548..6553009ff 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -12,11 +12,9 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false) LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) -LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false) -LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false) LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) +LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) LUAU_FASTFLAG(LuauArrayBoundary) @@ -66,13 +64,18 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, g->gcstats.currcycle.marktime += seconds; // atomic step had to be performed during the switch and it's tracked separately - if (g->gcstate == GCSsweepstring) + if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring) g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime; break; + case GCSatomic: + g->gcstats.currcycle.atomictime += seconds; + break; case GCSsweepstring: case GCSsweep: g->gcstats.currcycle.sweeptime += seconds; break; + default: + LUAU_ASSERT(!"Unexpected GC state"); } if (assist) @@ -183,33 +186,15 @@ static int traversetable(global_State* g, Table* h) if (h->metatable) markobject(g, cast_to(Table*, h->metatable)); - if (FFlag::LuauShrinkWeakTables) - { - /* is there a weak mode? */ - if (const char* modev = gettablemode(g, h)) - { - weakkey = (strchr(modev, 'k') != NULL); - weakvalue = (strchr(modev, 'v') != NULL); - if (weakkey || weakvalue) - { /* is really weak? */ - h->gclist = g->weak; /* must be cleared after GC, ... */ - g->weak = obj2gco(h); /* ... so put in the appropriate list */ - } - } - } - else + /* is there a weak mode? */ + if (const char* modev = gettablemode(g, h)) { - const TValue* mode = gfasttm(g, h->metatable, TM_MODE); - if (mode && ttisstring(mode)) - { /* is there a weak mode? */ - const char* modev = svalue(mode); - weakkey = (strchr(modev, 'k') != NULL); - weakvalue = (strchr(modev, 'v') != NULL); - if (weakkey || weakvalue) - { /* is really weak? */ - h->gclist = g->weak; /* must be cleared after GC, ... */ - g->weak = obj2gco(h); /* ... so put in the appropriate list */ - } + weakkey = (strchr(modev, 'k') != NULL); + weakvalue = (strchr(modev, 'v') != NULL); + if (weakkey || weakvalue) + { /* is really weak? */ + h->gclist = g->weak; /* must be cleared after GC, ... */ + g->weak = obj2gco(h); /* ... so put in the appropriate list */ } } @@ -297,7 +282,7 @@ static void traversestack(global_State* g, lua_State* l, bool clearstack) for (StkId o = l->stack; o < l->top; o++) markvalue(g, o); /* final traversal? */ - if (g->gcstate == GCSatomic || (FFlag::LuauGcFullSkipInactiveThreads && clearstack)) + if (g->gcstate == GCSatomic || clearstack) { StkId stack_end = l->stack + l->stacksize; for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */ @@ -336,28 +321,16 @@ static size_t propagatemark(global_State* g) lua_State* th = gco2th(o); g->gray = th->gclist; - if (FFlag::LuauGcFullSkipInactiveThreads) - { - LUAU_ASSERT(!luaC_threadsleeping(th)); - - // threads that are executing and the main thread are not deactivated - bool active = luaC_threadactive(th) || th == th->global->mainthread; + LUAU_ASSERT(!luaC_threadsleeping(th)); - if (!active && g->gcstate == GCSpropagate) - { - traversestack(g, th, /* clearstack= */ true); + // threads that are executing and the main thread are not deactivated + bool active = luaC_threadactive(th) || th == th->global->mainthread; - l_setbit(th->stackstate, THREAD_SLEEPINGBIT); - } - else - { - th->gclist = g->grayagain; - g->grayagain = o; - - black2gray(o); + if (!active && g->gcstate == GCSpropagate) + { + traversestack(g, th, /* clearstack= */ true); - traversestack(g, th, /* clearstack= */ false); - } + l_setbit(th->stackstate, THREAD_SLEEPINGBIT); } else { @@ -385,12 +358,14 @@ static size_t propagatemark(global_State* g) } } -static void propagateall(global_State* g) +static size_t propagateall(global_State* g) { + size_t work = 0; while (g->gray) { - propagatemark(g); + work += propagatemark(g); } + return work; } /* @@ -415,11 +390,14 @@ static int isobjcleared(GCObject* o) /* ** clear collected entries from weaktables */ -static void cleartable(lua_State* L, GCObject* l) +static size_t cleartable(lua_State* L, GCObject* l) { + size_t work = 0; while (l) { Table* h = gco2h(l); + work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); + int i = h->sizearray; while (i--) { @@ -433,50 +411,36 @@ static void cleartable(lua_State* L, GCObject* l) { LuaNode* n = gnode(h, i); - if (FFlag::LuauShrinkWeakTables) + // non-empty entry? + if (!ttisnil(gval(n))) { - // non-empty entry? - if (!ttisnil(gval(n))) - { - // can we clear key or value? - if (iscleared(gkey(n)) || iscleared(gval(n))) - { - setnilvalue(gval(n)); /* remove value ... */ - removeentry(n); /* remove entry from table */ - } - else - { - activevalues++; - } - } - } - else - { - if (!ttisnil(gval(n)) && /* non-empty entry? */ - (iscleared(gkey(n)) || iscleared(gval(n)))) + // can we clear key or value? + if (iscleared(gkey(n)) || iscleared(gval(n))) { setnilvalue(gval(n)); /* remove value ... */ removeentry(n); /* remove entry from table */ } + else + { + activevalues++; + } } } - if (FFlag::LuauShrinkWeakTables) + if (const char* modev = gettablemode(L->global, h)) { - if (const char* modev = gettablemode(L->global, h)) + // are we allowed to shrink this weak table? + if (strchr(modev, 's')) { - // are we allowed to shrink this weak table? - if (strchr(modev, 's')) - { - // shrink at 37.5% occupancy - if (activevalues < sizenode(h) * 3 / 8) - luaH_resizehash(L, h, activevalues); - } + // shrink at 37.5% occupancy + if (activevalues < sizenode(h) * 3 / 8) + luaH_resizehash(L, h, activevalues); } } l = h->gclist; } + return work; } static void shrinkstack(lua_State* L) @@ -655,37 +619,49 @@ static void markroot(lua_State* L) g->gcstate = GCSpropagate; } -static void remarkupvals(global_State* g) +static size_t remarkupvals(global_State* g) { - UpVal* uv; - for (uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + size_t work = 0; + for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) { + work += sizeof(UpVal); LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); if (isgray(obj2gco(uv))) markvalue(g, uv->v); } + return work; } -static void atomic(lua_State* L) +static size_t atomic(lua_State* L) { global_State* g = L->global; - g->gcstate = GCSatomic; + size_t work = 0; + + if (FFlag::LuauSeparateAtomic) + { + LUAU_ASSERT(g->gcstate == GCSatomic); + } + else + { + g->gcstate = GCSatomic; + } + /* remark occasional upvalues of (maybe) dead threads */ - remarkupvals(g); + work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ - propagateall(g); + work += propagateall(g); /* remark weak tables */ g->gray = g->weak; g->weak = NULL; LUAU_ASSERT(!iswhite(obj2gco(g->mainthread))); markobject(g, L); /* mark running thread */ markmt(g); /* mark basic metatables (again) */ - propagateall(g); + work += propagateall(g); /* remark gray again */ g->gray = g->grayagain; g->grayagain = NULL; - propagateall(g); - cleartable(L, g->weak); /* remove collected objects from weak tables */ + work += propagateall(g); + work += cleartable(L, g->weak); /* remove collected objects from weak tables */ g->weak = NULL; /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); @@ -693,7 +669,12 @@ static void atomic(lua_State* L) g->sweepgc = &g->rootgc; g->gcstate = GCSsweepstring; - GC_INTERRUPT(GCSatomic); + if (!FFlag::LuauSeparateAtomic) + { + GC_INTERRUPT(GCSatomic); + } + + return work; } static size_t singlestep(lua_State* L) @@ -705,46 +686,24 @@ static size_t singlestep(lua_State* L) case GCSpause: { markroot(L); /* start a new collection */ + LUAU_ASSERT(g->gcstate == GCSpropagate); break; } case GCSpropagate: { - if (FFlag::LuauRescanGrayAgain) + if (g->gray) { - if (g->gray) - { - g->gcstats.currcycle.markitems++; - - cost = propagatemark(g); - } - else - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; + g->gcstats.currcycle.markitems++; - g->gcstate = GCSpropagateagain; - } + cost = propagatemark(g); } else { - if (g->gray) - { - g->gcstats.currcycle.markitems++; - - cost = propagatemark(g); - } - else /* no more `gray' objects */ - { - double starttimestamp = lua_clock(); + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSpropagateagain; } break; } @@ -758,17 +717,34 @@ static size_t singlestep(lua_State* L) } else /* no more `gray' objects */ { - double starttimestamp = lua_clock(); + if (FFlag::LuauSeparateAtomic) + { + g->gcstate = GCSatomic; + } + else + { + double starttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - atomic(L); /* finish mark phase */ + atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } } break; } + case GCSatomic: + { + g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + cost = atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); + break; + } case GCSsweepstring: { size_t traversedcount = 0; @@ -806,7 +782,7 @@ static size_t singlestep(lua_State* L) break; } default: - LUAU_ASSERT(0); + LUAU_ASSERT(!"Unexpected GC state"); } return cost; @@ -821,48 +797,25 @@ static size_t gcstep(lua_State* L, size_t limit) case GCSpause: { markroot(L); /* start a new collection */ + LUAU_ASSERT(g->gcstate == GCSpropagate); break; } case GCSpropagate: { - if (FFlag::LuauRescanGrayAgain) + while (g->gray && cost < limit) { - while (g->gray && cost < limit) - { - g->gcstats.currcycle.markitems++; - - cost += propagatemark(g); - } - - if (!g->gray) - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; + g->gcstats.currcycle.markitems++; - g->gcstate = GCSpropagateagain; - } + cost += propagatemark(g); } - else - { - while (g->gray && cost < limit) - { - g->gcstats.currcycle.markitems++; - - cost += propagatemark(g); - } - - if (!g->gray) /* no more `gray' objects */ - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - atomic(L); /* finish mark phase */ + if (!g->gray) + { + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSpropagateagain; } break; } @@ -877,17 +830,34 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { - double starttimestamp = lua_clock(); + if (FFlag::LuauSeparateAtomic) + { + g->gcstate = GCSatomic; + } + else + { + double starttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - atomic(L); /* finish mark phase */ + atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } } break; } + case GCSatomic: + { + g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + cost = atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); + break; + } case GCSsweepstring: { while (g->sweepstrgc < g->strt.size && cost < limit) @@ -934,7 +904,7 @@ static size_t gcstep(lua_State* L, size_t limit) break; } default: - LUAU_ASSERT(0); + LUAU_ASSERT(!"Unexpected GC state"); } return cost; } @@ -1008,7 +978,7 @@ void luaC_step(lua_State* L, bool assist) startGcCycleStats(g); int lastgcstate = g->gcstate; - double lastttimestamp = lua_clock(); + double lasttimestamp = lua_clock(); if (FFlag::LuauConsolidatedStep) { @@ -1034,15 +1004,15 @@ void luaC_step(lua_State* L, bool assist) double now = lua_clock(); - recordGcStateTime(g, lastgcstate, now - lastttimestamp, assist); + recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist); - lastttimestamp = now; + lasttimestamp = now; lastgcstate = g->gcstate; } } while (lim > 0 && g->gcstate != GCSpause); } - recordGcStateTime(g, lastgcstate, lua_clock() - lastttimestamp, assist); + recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); // at the end of the last cycle if (g->gcstate == GCSpause) @@ -1084,7 +1054,7 @@ void luaC_fullgc(lua_State* L) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (g->gcstate <= GCSpropagateagain) + if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) { /* reset sweep marks to sweep all elements (returning them to white) */ g->sweepstrgc = 0; @@ -1095,7 +1065,7 @@ void luaC_fullgc(lua_State* L) g->weak = NULL; g->gcstate = GCSsweepstring; } - LUAU_ASSERT(g->gcstate != GCSpause && g->gcstate != GCSpropagate && g->gcstate != GCSpropagateagain); + LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); /* finish any pending sweep phase */ while (g->gcstate != GCSpause) { @@ -1143,14 +1113,11 @@ void luaC_fullgc(lua_State* L) void luaC_barrierupval(lua_State* L, GCObject* v) { - if (FFlag::LuauGcFullSkipInactiveThreads) - { - global_State* g = L->global; - LUAU_ASSERT(iswhite(v) && !isdead(g, v)); + global_State* g = L->global; + LUAU_ASSERT(iswhite(v) && !isdead(g, v)); - if (keepinvariant(g)) - reallymarkobject(g, v); - } + if (keepinvariant(g)) + reallymarkobject(g, v); } void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v) @@ -1778,7 +1745,7 @@ int64_t luaC_allocationrate(lua_State* L) global_State* g = L->global; const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms - if (g->gcstate <= GCSpropagateagain) + if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) { double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; diff --git a/VM/src/lgc.h b/VM/src/lgc.h index dc780bba5..f434e5064 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -6,8 +6,6 @@ #include "lobject.h" #include "lstate.h" -LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) - /* ** Possible states of the Garbage Collector */ @@ -25,10 +23,10 @@ LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) ** still-black objects. The invariant is restored when sweep ends and ** all objects are white again. */ -#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain) +#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic) /* -** some userful bit tricks +** some useful bit tricks */ #define resetbits(x, m) ((x) &= cast_to(uint8_t, ~(m))) #define setbits(x, m) ((x) |= (m)) @@ -147,4 +145,4 @@ LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); LUAI_FUNC int64_t luaC_allocationrate(lua_State* L); LUAI_FUNC void luaC_wakethread(lua_State* L); -LUAI_FUNC const char* luaC_statename(int state); \ No newline at end of file +LUAI_FUNC const char* luaC_statename(int state); diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 2759f3b80..d8b265cba 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -199,7 +199,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass) if (page->freeNext >= 0) { - block = page->data + page->freeNext; + block = &page->data + page->freeNext; ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); page->freeNext -= page->blockSize; diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index d77e17c9b..18ee1cda5 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -226,7 +226,7 @@ void luaS_freeudata(lua_State* L, Udata* u) void (*dtor)(void*) = nullptr; if (u->tag == UTAG_IDTOR) - memcpy(&dtor, u->data + u->len - sizeof(dtor), sizeof(dtor)); + memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); else if (u->tag) dtor = L->global->udatagc[u->tag]; diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index b932a85b3..a168b6522 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,7 @@ #include // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens -template +template struct TempBuffer { lua_State* L; @@ -346,6 +346,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size uint32_t mainid = readVarInt(data, size, offset); Proto* main = protos[mainid]; + luaC_checkthreadsleep(L); + Closure* cl = luaF_newLclosure(L, 0, envt, main); setclvalue(L, L->top, cl); incr_top(L); diff --git a/bench/tests/chess.lua b/bench/tests/chess.lua new file mode 100644 index 000000000..87b9abfd4 --- /dev/null +++ b/bench/tests/chess.lua @@ -0,0 +1,849 @@ + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +local RANKS = "12345678" +local FILES = "abcdefgh" +local PieceSymbols = "PpRrNnBbQqKk" +local UnicodePieces = {"♙", "♟", "♖", "♜", "♘", "♞", "♗", "♝", "♕", "♛", "♔", "♚"} +local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1" + +-- +-- Lua 5.2 Compat +-- + +if not table.create then + function table.create(n, v) + local result = {} + for i=1,n do result[i] = v end + return result + end +end + +if not table.move then + function table.move(a, from, to, start, target) + local dx = start - from + for i=from,to do + target[i+dx] = a[i] + end + end +end + + +-- +-- Utils +-- + +local function square(s) + return RANKS:find(s:sub(2,2)) * 8 + FILES:find(s:sub(1,1)) - 9 +end + +local function squareName(n) + local file = n % 8 + local rank = (n-file)/8 + return FILES:sub(file+1,file+1) .. RANKS:sub(rank+1,rank+1) +end + +local function moveName(v ) + local from = bit32.extract(v, 6, 6) + local to = bit32.extract(v, 0, 6) + local piece = bit32.extract(v, 20, 4) + local captured = bit32.extract(v, 25, 4) + + local move = PieceSymbols:sub(piece,piece) .. ' ' .. squareName(from) .. (captured ~= 0 and 'x' or '-') .. squareName(to) + + if bit32.extract(v,14) == 1 then + if to > from then + return "O-O" + else + return "O-O-O" + end + end + + local promote = bit32.extract(v,15,4) + if promote ~= 0 then + move = move .. "=" .. PieceSymbols:sub(promote,promote) + end + return move +end + +local function ucimove(m) + local mm = squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6)) + local promote = bit32.extract(m,15,4) + if promote > 0 then + mm = mm .. PieceSymbols:sub(promote,promote):lower() + end + return mm +end + +local _utils = {squareName, moveName} + +-- +-- Bitboards +-- + +local Bitboard = {} + + +function Bitboard:toString() + local out = {} + + local src = self.h + for x=7,0,-1 do + table.insert(out, RANKS:sub(x+1,x+1)) + table.insert(out, " ") + local bit = bit32.lshift(1,(x%4) * 8) + for x=0,7 do + if bit32.band(src, bit) ~= 0 then + table.insert(out, "x ") + else + table.insert(out, "- ") + end + bit = bit32.lshift(bit, 1) + end + if x == 4 then + src = self.l + end + table.insert(out, "\n") + end + table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n') + table.insert(out, '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h) + return table.concat(out) +end + + +function Bitboard.from(l ,h ) + return setmetatable({l=l, h=h}, Bitboard) +end + +Bitboard.zero = Bitboard.from(0,0) +Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF) + +local Rank1 = Bitboard.from(0x000000FF, 0) +local Rank3 = Bitboard.from(0x00FF0000, 0) +local Rank6 = Bitboard.from(0, 0x0000FF00) +local Rank8 = Bitboard.from(0, 0xFF000000) +local FileA = Bitboard.from(0x01010101, 0x01010101) +local FileB = Bitboard.from(0x02020202, 0x02020202) +local FileC = Bitboard.from(0x04040404, 0x04040404) +local FileD = Bitboard.from(0x08080808, 0x08080808) +local FileE = Bitboard.from(0x10101010, 0x10101010) +local FileF = Bitboard.from(0x20202020, 0x20202020) +local FileG = Bitboard.from(0x40404040, 0x40404040) +local FileH = Bitboard.from(0x80808080, 0x80808080) + +local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH} + +-- These masks are filled out below for all files +local RightMasks = {FileH} +local LeftMasks = {FileA} + + + +local function popcnt32(i) + i = i - bit32.band(bit32.rshift(i,1), 0x55555555) + i = bit32.band(i, 0x33333333) + bit32.band(bit32.rshift(i,2), 0x33333333) + return bit32.rshift(bit32.band(i + bit32.rshift(i,4), 0x0F0F0F0F) * 0x01010101, 24) +end + +function Bitboard:up() + return self:lshift(8) +end + +function Bitboard:down() + return self:rshift(8) +end + +function Bitboard:right() + return self:band(FileH:inverse()):lshift(1) +end + +function Bitboard:left() + return self:band(FileA:inverse()):rshift(1) +end + +function Bitboard:move(x,y) + local out = self + + if x < 0 then out = out:bandnot(RightMasks[-x]):lshift(-x) end + if x > 0 then out = out:bandnot(LeftMasks[x]):rshift(x) end + + if y < 0 then out = out:rshift(-8 * y) end + if y > 0 then out = out:lshift(8 * y) end + return out +end + + +function Bitboard:popcnt() + return popcnt32(self.l) + popcnt32(self.h) +end + +function Bitboard:band(other ) + return Bitboard.from(bit32.band(self.l,other.l), bit32.band(self.h, other.h)) +end + +function Bitboard:bandnot(other ) + return Bitboard.from(bit32.band(self.l,bit32.bnot(other.l)), bit32.band(self.h, bit32.bnot(other.h))) +end + +function Bitboard:bandempty(other ) + return bit32.band(self.l,other.l) == 0 and bit32.band(self.h, other.h) == 0 +end + +function Bitboard:bor(other ) + return Bitboard.from(bit32.bor(self.l,other.l), bit32.bor(self.h, other.h)) +end + +function Bitboard:bxor(other ) + return Bitboard.from(bit32.bxor(self.l,other.l), bit32.bxor(self.h, other.h)) +end + +function Bitboard:inverse() + return Bitboard.from(bit32.bxor(self.l,0xFFFFFFFF), bit32.bxor(self.h, 0xFFFFFFFF)) +end + +function Bitboard:empty() + return self.h == 0 and self.l == 0 +end + +function Bitboard:ctz() + local target = self.l + local offset = 0 + local result = 0 + + if target == 0 then + target = self.h + result = 32 + end + + if target == 0 then + return 64 + end + + while bit32.extract(target, offset) == 0 do + offset = offset + 1 + end + + return result + offset +end + +function Bitboard:ctzafter(start) + start = start + 1 + if start < 32 then + for i=start,31 do + if bit32.extract(self.l, i) == 1 then return i end + end + end + for i=math.max(32,start),63 do + if bit32.extract(self.h, i-32) == 1 then return i end + end + return 64 +end + + +function Bitboard:lshift(amt) + assert(amt >= 0) + if amt == 0 then return self end + + if amt > 31 then + return Bitboard.from(0, bit32.lshift(self.l, amt-31)) + end + + local l = bit32.lshift(self.l, amt) + local h = bit32.bor( + bit32.lshift(self.h, amt), + bit32.extract(self.l, 32-amt, amt) + ) + return Bitboard.from(l, h) +end + +function Bitboard:rshift(amt) + assert(amt >= 0) + if amt == 0 then return self end + local h = bit32.rshift(self.h, amt) + local l = bit32.bor( + bit32.rshift(self.l, amt), + bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt) + ) + return Bitboard.from(l, h) +end + +function Bitboard:index(i) + if i > 31 then + return bit32.extract(self.h, i - 32) + else + return bit32.extract(self.l, i) + end +end + +function Bitboard:set(i , v) + if i > 31 then + return Bitboard.from(self.l, bit32.replace(self.h, v, i - 32)) + else + return Bitboard.from(bit32.replace(self.l, v, i), self.h) + end +end + +function Bitboard:isolate(i) + return self:band(Bitboard.some(i)) +end + +function Bitboard.some(idx ) + return Bitboard.zero:set(idx, 1) +end + +Bitboard.__index = Bitboard +Bitboard.__tostring = Bitboard.toString + +for i=2,8 do + RightMasks[i] = RightMasks[i-1]:rshift(1):bor(FileH) + LeftMasks[i] = LeftMasks[i-1]:lshift(1):bor(FileA) +end +-- +-- Board +-- + +local Board = {} + + +function Board.new() + local boards = table.create(12, Bitboard.zero) + boards.ocupied = Bitboard.zero + boards.white = Bitboard.zero + boards.black = Bitboard.zero + boards.unocupied = Bitboard.full + boards.ep = Bitboard.zero + boards.castle = Bitboard.zero + boards.toMove = 1 + boards.hm = 0 + boards.moves = 0 + boards.material = 0 + + return setmetatable(boards, Board) +end + +function Board.fromFen(fen ) + local b = Board.new() + local i = 0 + local rank = 7 + local file = 0 + + while true do + i = i + 1 + local p = fen:sub(i,i) + if p == '/' then + rank = rank - 1 + file = 0 + elseif tonumber(p) ~= nil then + file = file + tonumber(p) + else + local pidx = PieceSymbols:find(p) + if pidx == nil then break end + b[pidx] = b[pidx]:set(rank*8+file, 1) + file = file + 1 + end + end + + + local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i) + if move == nil then print(fen:sub(i)) end + b.toMove = move == 'w' and 1 or 2 + + if ep ~= "-" then + b.ep = Bitboard.some(square(ep)) + end + + if castle ~= "-" then + local oo = Bitboard.zero + if castle:find("K") then + oo = oo:set(7, 1) + end + if castle:find("Q") then + oo = oo:set(0, 1) + end + if castle:find("k") then + oo = oo:set(63, 1) + end + if castle:find("q") then + oo = oo:set(56, 1) + end + + b.castle = oo + end + + b.hm = hm + b.moves = m + + b:updateCache() + return b + +end + +function Board:index(idx ) + if self.white:index(idx) == 1 then + for p=1,12,2 do + if self[p]:index(idx) == 1 then + return p + end + end + else + for p=2,12,2 do + if self[p]:index(idx) == 1 then + return p + end + end + end + + return 0 +end + +function Board:updateCache() + for i=1,11,2 do + self.white = self.white:bor(self[i]) + self.black = self.black:bor(self[i+1]) + end + + self.ocupied = self.black:bor(self.white) + self.unocupied = self.ocupied:inverse() + self.material = + 100*self[1]:popcnt() - 100*self[2]:popcnt() + + 500*self[3]:popcnt() - 500*self[4]:popcnt() + + 300*self[5]:popcnt() - 300*self[6]:popcnt() + + 300*self[7]:popcnt() - 300*self[8]:popcnt() + + 900*self[9]:popcnt() - 900*self[10]:popcnt() + +end + +function Board:fen() + local out = {} + local s = 0 + local idx = 56 + for i=0,63 do + if i % 8 == 0 and i > 0 then + idx = idx - 16 + if s > 0 then + table.insert(out, '' .. s) + s = 0 + end + table.insert(out, '/') + end + local p = self:index(idx) + if p == 0 then + s = s + 1 + else + if s > 0 then + table.insert(out, '' .. s) + s = 0 + end + table.insert(out, PieceSymbols:sub(p,p)) + end + + idx = idx + 1 + end + if s > 0 then + table.insert(out, '' .. s) + end + + table.insert(out, self.toMove == 1 and ' w ' or ' b ') + if self.castle:empty() then + table.insert(out, '-') + else + if self.castle:index(7) == 1 then table.insert(out, 'K') end + if self.castle:index(0) == 1 then table.insert(out, 'Q') end + if self.castle:index(63) == 1 then table.insert(out, 'k') end + if self.castle:index(56) == 1 then table.insert(out, 'q') end + end + + table.insert(out, ' ') + if self.ep:empty() then + table.insert(out, '-') + else + table.insert(out, squareName(self.ep:ctz())) + end + + table.insert(out, ' ' .. self.hm) + table.insert(out, ' ' .. self.moves) + + return table.concat(out) +end + +function Board:pmoves(idx) + return self:generate(idx) +end + +function Board:pcaptures(idx) + return self:generate(idx):band(self.ocupied) +end + +local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}} +local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}} +local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}} +local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}} + +function Board:generate(idx) + local piece = self:index(idx) + local r = Bitboard.some(idx) + local out = Bitboard.zero + local type = bit32.rshift(piece - 1, 1) + local cancapture = piece % 2 == 1 and self.black or self.white + + if piece == 0 then return Bitboard.zero end + + if type == 0 then + -- Pawn + local d = -(piece*2 - 3) + local movetwo = piece == 1 and Rank3 or Rank6 + + out = out:bor(r:move(0,d):band(self.unocupied)) + out = out:bor(out:band(movetwo):move(0,d):band(self.unocupied)) + + local captures = r:move(0,d) + captures = captures:right():bor(captures:left()) + + if not captures:bandempty(self.ep) then + out = out:bor(self.ep) + end + + captures = captures:band(cancapture) + out = out:bor(captures) + + return out + elseif type == 5 then + -- King + for x=-1,1,1 do + for y = -1,1,1 do + local w = r:move(x,y) + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + end + end + end + elseif type == 2 then + -- Knight + for _,j in ipairs(KNIGHT_MOVES) do + local w = r:move(j[1],j[2]) + + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + end + end + else + -- Sliders (Rook, Bishop, Queen) + local slides + if type == 1 then + slides = ROOK_SLIDES + elseif type == 3 then + slides = BISHOP_SLIDES + else + slides = QUEEN_SLIDES + end + + for _, op in ipairs(slides) do + local w = r + for i=1,7 do + w = w:move(op[1], op[2]) + if w:empty() then break end + + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + break + end + end + end + end + + + return out +end + +-- 0-5 - From Square +-- 6-11 - To Square +-- 12 - is Check +-- 13 - Is EnPassent +-- 14 - Is Castle +-- 15-19 - Promotion Piece +-- 20-24 - Moved Pice +-- 25-29 - Captured Piece + + +function Board:toString(mark ) + local out = {} + for x=8,1,-1 do + table.insert(out, RANKS:sub(x,x) .. " ") + + for y=1,8 do + local n = 8*x+y-9 + local i = self:index(n) + if i == 0 then + table.insert(out, '-') + else + -- out = out .. PieceSymbols:sub(i,i) + table.insert(out, UnicodePieces[i]) + end + if mark ~= nil and mark:index(n) ~= 0 then + table.insert(out, ')') + elseif mark ~= nil and n < 63 and y < 8 and mark:index(n+1) ~= 0 then + table.insert(out, '(') + else + table.insert(out, ' ') + end + end + + table.insert(out, "\n") + end + table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n') + table.insert(out, (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material/100) .. "\n") + return table.concat(out) +end + +function Board:moveList() + local tm = self.toMove == 1 and self.white or self.black + local castle_rank = self.toMove == 1 and Rank1 or Rank8 + local out = {} + local function emit(id) + if not self:applyMove(id):illegalyChecked() then + table.insert(out, id) + end + end + + local cr = tm:band(self.castle):band(castle_rank) + if not cr:empty() then + local p = self.toMove == 1 and 11 or 12 + local tcolor = self.toMove == 1 and self.black or self.white + local kidx = self[p]:ctz() + + + local castle = bit32.replace(0, p, 20, 4) + castle = bit32.replace(castle, kidx, 6, 6) + castle = bit32.replace(castle, 1, 14) + + + local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank) + local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p]) + if + not cr:bandempty(FileA) and + mustbeemptyl:bandempty(self.ocupied) and + not self:isSquareThreatened(cantbethreatened, tcolor) + then + emit(bit32.replace(castle, kidx - 2, 0, 6)) + end + + + local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank) + if + not cr:bandempty(FileH) and + mustbeemptyr:bandempty(self.ocupied) and + not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor) + then + emit(bit32.replace(castle, kidx + 2, 0, 6)) + end + end + + local sq = tm:ctz() + repeat + local p = self:index(sq) + local moves = self:pmoves(sq) + + while not moves:empty() do + local m = moves:ctz() + moves = moves:set(m, 0) + local id = bit32.replace(m, sq, 6, 6) + id = bit32.replace(id, p, 20, 4) + local mbb = Bitboard.some(m) + if not self.ocupied:bandempty(mbb) then + id = bit32.replace(id, self:index(m), 25, 4) + end + + -- Check if pawn needs to be promoted + if p == 1 and m >= 8*7 then + for i=3,9,2 do + emit(bit32.replace(id, i, 15, 4)) + end + elseif p == 2 and m < 8 then + for i=4,10,2 do + emit(bit32.replace(id, i, 15, 4)) + end + else + emit(id) + end + end + sq = tm:ctzafter(sq) + until sq == 64 + return out +end + +function Board:illegalyChecked() + local target = self.toMove == 1 and self[PieceSymbols:find("k")] or self[PieceSymbols:find("K")] + return self:isSquareThreatened(target, self.toMove == 1 and self.white or self.black) +end + +function Board:isSquareThreatened(target , color ) + local tm = color + local sq = tm:ctz() + repeat + local moves = self:pmoves(sq) + if not moves:bandempty(target) then + return true + end + sq = color:ctzafter(sq) + until sq == 64 + return false +end + +function Board:perft(depth ) + if depth == 0 then return 1 end + if depth == 1 then + return #self:moveList() + end + local result = 0 + for k,m in ipairs(self:moveList()) do + local c = self:applyMove(m):perft(depth - 1) + if c == 0 then + -- Perft only counts leaf nodes at target depth + -- result = result + 1 + else + result = result + c + end + end + return result +end + + +function Board:applyMove(move ) + local out = Board.new() + table.move(self, 1, 12, 1, out) + local from = bit32.extract(move, 6, 6) + local to = bit32.extract(move, 0, 6) + local promote = bit32.extract(move, 15, 4) + local piece = self:index(from) + local captured = self:index(to) + local tom = Bitboard.some(to) + local isCastle = bit32.extract(move, 14) + + if piece % 2 == 0 then + out.moves = self.moves + 1 + end + + if captured == 1 or piece < 3 then + out.hm = 0 + else + out.hm = self.hm + 1 + end + out.castle = self.castle + out.toMove = self.toMove == 1 and 2 or 1 + + if isCastle == 1 then + local rank = piece == 11 and Rank1 or Rank8 + local colorOffset = piece - 11 + + out[3 + colorOffset] = out[3 + colorOffset]:bandnot(from < to and FileH or FileA) + out[3 + colorOffset] = out[3 + colorOffset]:bor((from < to and FileF or FileD):band(rank)) + + out[piece] = (from < to and FileG or FileC):band(rank) + out.castle = out.castle:bandnot(rank) + out:updateCache() + return out + end + + if piece < 3 then + local dist = math.abs(to - from) + -- Pawn moved two squares, set ep square + if dist == 16 then + out.ep = Bitboard.some((from + to) / 2) + end + + -- Remove enpasent capture + if not tom:bandempty(self.ep) then + if piece == 1 then + out[2] = out[2]:bandnot(self.ep:down()) + end + if piece == 2 then + out[1] = out[1]:bandnot(self.ep:up()) + end + end + end + + if piece == 3 or piece == 4 then + out.castle = out.castle:set(from, 0) + end + + if piece > 10 then + local rank = piece == 11 and Rank1 or Rank8 + out.castle = out.castle:bandnot(rank) + end + + out[piece] = out[piece]:set(from, 0) + if promote == 0 then + out[piece] = out[piece]:set(to, 1) + else + out[promote] = out[promote]:set(to, 1) + end + if captured ~= 0 then + out[captured] = out[captured]:set(to, 0) + end + + out:updateCache() + return out +end + +Board.__index = Board +Board.__tostring = Board.toString +-- +-- Main +-- + +local failures = 0 +local function test(fen, ply, target) + local b = Board.fromFen(fen) + if b:fen() ~= fen then + print("FEN MISMATCH", fen, b:fen()) + failures = failures + 1 + return + end + + local found = b:perft(ply) + if found ~= target then + print(fen, "Found", found, "target", target) + failures = failures + 1 + for k,v in pairs(b:moveList()) do + print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply-1) or '1')) + end + --error("Test Failure") + else + print("OK", found, fen) + end +end + +-- From https://www.chessprogramming.org/Perft_Results +-- If interpreter, computers, or algorithm gets too fast +-- feel free to go deeper + +local testCases = {} +local function addTest(...) table.insert(testCases, {...}) end + +addTest(StartingFen, 3, 8902) +addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039) +addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812) +addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467) +addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486) +addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079) + + +local function chess() + for k,v in ipairs(testCases) do + test(v[1],v[2],v[3]) + end +end + +bench.runCode(chess, "chess") diff --git a/bench/tests/shootout/scimark.lua b/bench/tests/shootout/scimark.lua index 41d97bb8b..ad0557b1d 100644 --- a/bench/tests/shootout/scimark.lua +++ b/bench/tests/shootout/scimark.lua @@ -30,7 +30,7 @@ ------------------------------------------------------------------------------ ------------------------------------------------------------------------------ --- Modificatin to be compatible with Lua 5.3 +-- Modification to be compatible with Lua 5.3 ------------------------------------------------------------------------------ local bench = script and require(script.Parent.bench_support) or require("bench_support") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 07910a0ac..44b8362df 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1596,7 +1596,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion") local function target(a: number, b: string) return a + #b end local function d(a: n@1, b) - return target(a, b) + return target(a, b) end )"); @@ -1609,7 +1609,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a, b: s@1) - return target(a, b) + return target(a, b) end )"); @@ -1622,7 +1622,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a:@1 @2, b) - return target(a, b) + return target(a, b) end )"); @@ -1640,7 +1640,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a, b: @1)@2: number - return target(a, b) + return target(a, b) end )"); @@ -1682,7 +1682,7 @@ local x = target(function(a: n@1 local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end local x = target(function(a: n@1, b: @2) - return a + #b + return a + #b end) )"); @@ -1700,7 +1700,7 @@ end) local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(a: n@1) - return a + return a end )"); @@ -1716,7 +1716,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(...:n@1) - return a + return a end )"); @@ -1729,7 +1729,7 @@ end local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(a:number, b:number, ...:@1) - return a + b + return a + b end )"); @@ -1745,7 +1745,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion") local function target(callback: () -> number) return callback() end local x = target(function(): n@1 - return 1 + return 1 end )"); @@ -1758,7 +1758,7 @@ end local function target(callback: () -> (number, number)) return callback() end local x = target(function(): (number, n@1 - return 1, 2 + return 1, 2 end )"); @@ -1774,7 +1774,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion" local function target(callback: () -> ...number) return callback() end local x = target(function(): ...n@1 - return 1, 2, 3 + return 1, 2, 3 end )"); @@ -1787,7 +1787,7 @@ end local function target(callback: () -> ...number) return callback() end local x = target(function(): (number, number, ...n@1 - return 1, 2, 3 + return 1, 2, 3 end )"); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 54a31a68a..bbac3302c 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -768,11 +768,11 @@ TEST_CASE("CaptureSelf") local MaterialsListClass = {} function MaterialsListClass:_MakeToolTip(guiElement, text) - local function updateTooltipPosition() - self._tweakingTooltipFrame = 5 - end + local function updateTooltipPosition() + self._tweakingTooltipFrame = 5 + end - updateTooltipPosition() + updateTooltipPosition() end return MaterialsListClass @@ -2001,14 +2001,14 @@ TEST_CASE("UpvaluesLoopsBytecode") { CHECK_EQ("\n" + compileFunction(R"( function test() - for i=1,10 do + for i=1,10 do i = i - foo(function() return i end) - if bar then - break - end - end - return 0 + foo(function() return i end) + if bar then + break + end + end + return 0 end )", 1), @@ -2035,14 +2035,14 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction(R"( function test() - for i in ipairs(data) do + for i in ipairs(data) do i = i - foo(function() return i end) - if bar then - break - end - end - return 0 + foo(function() return i end) + if bar then + break + end + end + return 0 end )", 1), @@ -2068,17 +2068,17 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction(R"( function test() - local i = 0 - while i < 5 do - local j + local i = 0 + while i < 5 do + local j j = i - foo(function() return j end) - i = i + 1 - if bar then - break - end - end - return 0 + foo(function() return j end) + i = i + 1 + if bar then + break + end + end + return 0 end )", 1), @@ -2105,17 +2105,17 @@ RETURN R1 1 CHECK_EQ("\n" + compileFunction(R"( function test() - local i = 0 - repeat - local j + local i = 0 + repeat + local j j = i - foo(function() return j end) - i = i + 1 - if bar then - break - end - until i < 5 - return 0 + foo(function() return j end) + i = i + 1 + if bar then + break + end + until i < 5 + return 0 end )", 1), @@ -2304,10 +2304,10 @@ local Value1, Value2, Value3 = ... local Table = {} Table.SubTable["Key"] = { - Key1 = Value1, - Key2 = Value2, - Key3 = Value3, - Key4 = true, + Key1 = Value1, + Key2 = Value2, + Key3 = Value3, + Key4 = true, } )"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 5a697a49e..06b3c5237 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -801,4 +801,17 @@ TEST_CASE("IfElseExpression") runConformance("ifelseexpr.lua"); } +TEST_CASE("TagMethodError") +{ + ScopedFastFlag sff{"LuauCcallRestoreFix", true}; + + runConformance("tmerror.lua", [](lua_State* L) { + auto* cb = lua_callbacks(L); + + cb->debugprotectederror = [](lua_State* L) { + CHECK(lua_isyieldable(L)); + }; + }); +} + TEST_SUITE_END(); diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index 9f8748993..e55b5b0c3 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -2,6 +2,9 @@ #pragma once #include +#include + +namespace std { inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) { @@ -9,10 +12,12 @@ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) } template -std::ostream& operator<<(std::ostream& lhs, const std::optional& t) +auto operator<<(std::ostream& lhs, const std::optional& t) -> decltype(lhs << *t) // SFINAE to only instantiate << for supported types { if (t) return lhs << *t; else return lhs << "none"; } + +} // namespace std diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index a9ed139f1..37f1b60b8 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -791,13 +791,13 @@ TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings") { LintResult result = lint(R"(--!strict type InputData = { - id: number, - inputType: EnumItem, - inputState: EnumItem, - updated: number, - position: Vector3, - keyCode: EnumItem, - name: string + id: number, + inputType: EnumItem, + inputState: EnumItem, + updated: number, + position: Vector3, + keyCode: EnumItem, + name: string } )"); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 045f0230d..f580604ca 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -554,4 +554,54 @@ TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") CHECK_EQ("{number | string}", toString(requireType("p"), {true})); } +/* + * We had a problem where all type aliases would be prototyped into a child scope that happened + * to have the same level. This caused a problem where, if a sibling function referred to that + * type alias in its type signature, it would erroneously be quantified away, even though it doesn't + * actually belong to the function. + * + * We solved this by ascribing a unique subLevel to each prototyped alias. + */ +TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases") +{ + CheckResult result = check(R"( + --!strict + + local KeyPool = {} + + local function newkey(pool: KeyPool, index) + return {} + end + + function newKeyPool() + local pool = { + available = {} :: {Key}, + } + + return setmetatable(pool, KeyPool) + end + + export type KeyPool = typeof(newKeyPool()) + export type Key = typeof(newkey(newKeyPool(), 1)) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * We keep a cache of type alias onto TypeVar to prevent infinite types from + * being constructed via recursive or corecursive aliases. We have to adjust + * the TypeLevels of those generic TypeVars so that the unifier doesn't think + * they have improperly leaked out of their scope. + */ +TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_scope_if_they_are_reused_in_multiple_aliases") +{ + CheckResult result = check(R"( + type Array = {T} + type Exclude = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index e89747762..17e32e9f9 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -359,7 +359,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_o CHECK_EQ(typeChecker.stringType, requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "table_insert_corrrectly_infers_type_of_array_3_args_overload") +TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_3_args_overload") { CheckResult result = check(R"( local t = {} diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 6da33a08f..eabf7e65d 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -437,8 +437,6 @@ TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_orde TEST_CASE_FIXTURE(ClassFixture, "optional_class_field_access_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local b: Vector2? = nil local a = b.X + b.Z diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 3a04a18f5..581375a12 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -695,4 +695,25 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") CHECK(requireType("y1") == requireType("y2")); } +TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") +{ + ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true}; + ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true}; + + CheckResult result = check(R"( +local exports = {} +local nested = {} + +nested.name = function(t, k) + local a = t.x.y + return rawget(t, k) +end + +exports.nested = nested +return exports + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 8bcb02424..419da8ad1 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -9,12 +9,13 @@ #include LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; TEST_SUITE_BEGIN("ProvisionalTests"); -// These tests check for behavior that differes from the final behavior we'd +// These tests check for behavior that differs from the final behavior we'd // like to have. They serve to document the current state of the typechecker. // When making future improvements, its very likely these tests will break and // will need to be replaced. @@ -42,7 +43,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expected = R"( + const std::string old_expected = R"( function f(a:{fn:()->(free,free...)}): () if type(a) == 'boolean'then local a1:boolean=a @@ -51,7 +52,21 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end end )"; - CHECK_EQ(expected, decorateWithTypes(code)); + + const std::string expected = R"( + function f(a:{fn:()->(a,b...)}): () + if type(a) == 'boolean'then + local a1:boolean=a + elseif a.fn()then + local a2:{fn:()->(a,b...)}=a + end + end + )"; + + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ(expected, decorateWithTypes(code)); + else + CHECK_EQ(old_expected, decorateWithTypes(code)); } TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") @@ -263,8 +278,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { - ScopedFastInt sffi{"LuauTarjanChildLimit", 50}; - ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 50}; + ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; + ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1}; CheckResult result = check(R"LUA( local Result diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 31739cdc7..36dcaa959 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,6 +8,7 @@ LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauOrPredicate) +LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -698,10 +699,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); + else + CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + else + CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" } TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index b7f0dc7b0..f1451a815 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -617,7 +617,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") REQUIRE_EQ(indexer.indexType, typeChecker.numberType); - REQUIRE(nullptr != get(indexer.indexResultType)); + REQUIRE(nullptr != get(follow(indexer.indexResultType))); } TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index b75878b7e..453817574 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -180,7 +180,12 @@ TEST_CASE_FIXTURE(Fixture, "expr_statement") TEST_CASE_FIXTURE(Fixture, "generic_function") { - CheckResult result = check("function id(x) return x end local a = id(55) local b = id(nil)"); + CheckResult result = check(R"( + function id(x) return x end + local a = id(55) + local b = id(nil) + )"); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*typeChecker.numberType, *requireType("a")); @@ -406,7 +411,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right for p in primes2() do print(p) end -- mismatch in argument types, prime_iter takes {}, number, we are given {}, string - for p in primes3() do print(p) end -- no errror + for p in primes3() do print(p) end -- no error )"); LUAU_REQUIRE_ERROR_COUNT(2, result); @@ -1889,7 +1894,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") REQUIRE_EQ(2, argVec.size()); - const FunctionTypeVar* fType = get(argVec[0]); + const FunctionTypeVar* fType = get(follow(argVec[0])); REQUIRE(fType != nullptr); std::vector fArgs = flatten(fType->argTypes).first; @@ -1926,7 +1931,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") REQUIRE_EQ(6, argVec.size()); - const FunctionTypeVar* fType = get(argVec[0]); + const FunctionTypeVar* fType = get(follow(argVec[0])); REQUIRE(fType != nullptr); } @@ -2549,7 +2554,7 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") --!strict local x = nil function f() g() end - -- make sure print(x) doen't get toposorted here, breaking the mutual block + -- make sure print(x) doesn't get toposorted here, breaking the mutual block function g() x = f end print(x) )"); @@ -2987,7 +2992,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") CHECK_EQ(us->name, "a"); } -TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indeces") +TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indices") { CheckResult result = check(R"( local key @@ -3176,7 +3181,24 @@ TEST_CASE_FIXTURE(Fixture, "too_many_return_values") CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); - CHECK(acm->context == CountMismatch::Result); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 2); +} + +TEST_CASE_FIXTURE(Fixture, "ignored_return_values") +{ + CheckResult result = check(R"( + --!strict + + function f() + return 55, "" + end + + local a = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); } TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") @@ -3194,6 +3216,8 @@ TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); CHECK_EQ(acm->context, CountMismatch::Return); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 1); } TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") @@ -3823,10 +3847,10 @@ local T: any T = {} T.__index = T function T.new(...) - local self = {} - setmetatable(self, T) - self:construct(...) - return self + local self = {} + setmetatable(self, T) + self:construct(...) + return self end function T:construct(index) end @@ -4049,11 +4073,11 @@ function n:Clone() end local m = {} function m.a(x) - x:Clone() + x:Clone() end function m.b() - m.a(n) + m.a(n) end return m @@ -4374,8 +4398,6 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") { - ScopedFastFlag luauInferFunctionArgsFix("LuauInferFunctionArgsFix", true); - // Simple direct arg to arg propagation CheckResult result = check(R"( type Table = { x: number, y: number } @@ -4385,7 +4407,7 @@ f(function(a) return a.x + a.y end) LUAU_REQUIRE_NO_ERRORS(result); - // An optional funciton is accepted, but since we already provide a function, nil can be ignored + // An optional function is accepted, but since we already provide a function, nil can be ignored result = check(R"( type Table = { x: number, y: number } local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end @@ -4413,7 +4435,7 @@ f(function(a: number, b, c) return c and a + b or b - a end) LUAU_REQUIRE_NO_ERRORS(result); - // Anonymous function has a varyadic pack + // Anonymous function has a variadic pack result = check(R"( type Table = { x: number, y: number } local function f(a: (Table) -> number) return a({x = 1, y = 2}) end @@ -4432,7 +4454,7 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); CHECK_EQ("Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'", toString(result.errors[0])); - // Infer from varyadic packs into elements + // Infer from variadic packs into elements result = check(R"( function f(a: (...number) -> number) return a(1, 2) end f(function(a, b) return a + b end) @@ -4440,7 +4462,7 @@ f(function(a, b) return a + b end) LUAU_REQUIRE_NO_ERRORS(result); - // Infer from varyadic packs into varyadic packs + // Infer from variadic packs into variadic packs result = check(R"( type Table = { x: number, y: number } function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end @@ -4662,7 +4684,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") { ScopedFastFlag sffs[] = { {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - {"LuauExtraNilRecovery", true}, }; CheckResult result = check(R"( @@ -4679,7 +4700,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") { ScopedFastFlag sffs[] = { {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - {"LuauExtraNilRecovery", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 1f4b63ef2..1192a8ac0 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -8,6 +8,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauQuantifyInPlace2); + using namespace Luau; struct TryUnifyFixture : Fixture @@ -15,7 +17,8 @@ struct TryUnifyFixture : Fixture TypeArena arena; ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; - Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, &iceHandler}; + UnifierSharedState unifierState{&iceHandler}; + Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, unifierState}; }; TEST_SUITE_BEGIN("TryUnifyTests"); @@ -139,7 +142,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails" )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("(number) -> boolean", toString(requireType("f"))); + else + CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 3e1dedd47..8dab2605b 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -98,10 +98,10 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function") std::vector applyArgs = flatten(applyType->argTypes).first; REQUIRE_EQ(3, applyArgs.size()); - const FunctionTypeVar* fType = get(applyArgs[0]); + const FunctionTypeVar* fType = get(follow(applyArgs[0])); REQUIRE(fType != nullptr); - const FunctionTypeVar* gType = get(applyArgs[1]); + const FunctionTypeVar* gType = get(follow(applyArgs[1])); REQUIRE(gType != nullptr); std::vector gArgs = flatten(gType->argTypes).first; @@ -285,7 +285,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") { CheckResult result = check(R"( local _ = function():((...any)->(...any),()->()) - return function() end, function() end + return function() end, function() end end for y in _() do end diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 037144e2a..34c25a9fe 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -52,7 +52,7 @@ TEST_CASE_FIXTURE(Fixture, "allow_more_specific_assign") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign") +TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign") { CheckResult result = check(R"( local a:number = 10 @@ -63,7 +63,7 @@ TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign2") +TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign2") { CheckResult result = check(R"( local a:number? = 10 @@ -181,8 +181,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property") TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") { - ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); - CheckResult result = check(R"( type A = {x: number} type B = {} @@ -242,8 +240,6 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") TEST_CASE_FIXTURE(Fixture, "optional_union_members") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = { a = { x = 1, y = 2 }, b = 3 } type A = typeof(a) @@ -259,8 +255,6 @@ local c = bf.a.y TEST_CASE_FIXTURE(Fixture, "optional_union_functions") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = {} function a.foo(x:number, y:number) return x + y end @@ -276,8 +270,6 @@ local c = b.foo(1, 2) TEST_CASE_FIXTURE(Fixture, "optional_union_methods") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = {} function a:foo(x:number, y:number) return x + y end @@ -310,8 +302,6 @@ return f() TEST_CASE_FIXTURE(Fixture, "optional_field_access_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = { x: number } local b: A? = { x = 2 } @@ -327,8 +317,6 @@ local d = b.y TEST_CASE_FIXTURE(Fixture, "optional_index_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = {number} local a: A? = {1, 2, 3} @@ -341,8 +329,6 @@ local b = a[1] TEST_CASE_FIXTURE(Fixture, "optional_call_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = (number) -> number local a: A? = function(a) return -a end @@ -355,8 +341,6 @@ local b = a(4) TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = { x: number } local a: A? = { x = 2 } @@ -378,8 +362,6 @@ a.x = 2 TEST_CASE_FIXTURE(Fixture, "optional_length_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = {number} local a: A? = {1, 2, 3} @@ -392,9 +374,6 @@ local b = #a TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); - CheckResult result = check(R"( type A = { x: number, y: number } type B = { x: number, y: number } diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index a679e3fd2..930c1a39b 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -265,4 +265,64 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result)); } +TEST_CASE("tagging_tables") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar ttv{TableTypeVar{}}; + CHECK(!Luau::hasTag(&ttv, "foo")); + Luau::attachTag(&ttv, "foo"); + CHECK(Luau::hasTag(&ttv, "foo")); +} + +TEST_CASE("tagging_classes") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + CHECK(!Luau::hasTag(&base, "foo")); + Luau::attachTag(&base, "foo"); + CHECK(Luau::hasTag(&base, "foo")); +} + +TEST_CASE("tagging_subclasses") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}}; + + CHECK(!Luau::hasTag(&base, "foo")); + CHECK(!Luau::hasTag(&derived, "foo")); + + Luau::attachTag(&base, "foo"); + CHECK(Luau::hasTag(&base, "foo")); + CHECK(Luau::hasTag(&derived, "foo")); + + Luau::attachTag(&derived, "bar"); + CHECK(!Luau::hasTag(&base, "bar")); + CHECK(Luau::hasTag(&derived, "bar")); +} + +TEST_CASE("tagging_functions") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypePackVar empty{TypePack{}}; + TypeVar ftv{FunctionTypeVar{&empty, &empty}}; + CHECK(!Luau::hasTag(&ftv, "foo")); + Luau::attachTag(&ftv, "foo"); + CHECK(Luau::hasTag(&ftv, "foo")); +} + +TEST_CASE("tagging_props") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + Property prop{}; + CHECK(!Luau::hasTag(prop, "foo")); + Luau::attachTag(prop, "foo"); + CHECK(Luau::hasTag(prop, "foo")); +} + TEST_SUITE_END(); diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index 79f8d9c2c..aac42c56f 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -319,7 +319,7 @@ end assert(a == 5^4) --- access to locals of collected corroutines +-- access to locals of collected coroutines local C = {}; setmetatable(C, {__mode = "kv"}) local x = coroutine.wrap (function () local a = 10 diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 73c3833da..753296422 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -185,7 +185,7 @@ end assert(a == 5^4) --- access to locals of collected corroutines +-- access to locals of collected coroutines local C = {}; setmetatable(C, {__mode = "kv"}) local x = coroutine.wrap (function () local a = 10 diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index fd4b4de1b..4263dfda7 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -277,7 +277,7 @@ do assert(getmetatable(o) == tt) -- create new objects during GC local a = 'xuxu'..(10+3)..'joao', {} - ___Glob = o -- ressurect object! + ___Glob = o -- resurrect object! newproxy(o) -- creates a new one with same metatable print(">>> closing state " .. "<<<\n") end diff --git a/tests/conformance/locals.lua b/tests/conformance/locals.lua index cbe5f92d0..2d8d004b8 100644 --- a/tests/conformance/locals.lua +++ b/tests/conformance/locals.lua @@ -117,7 +117,7 @@ if rawget(_G, "querytab") then local t = querytab(a) for k,_ in pairs(a) do a[k] = nil end - collectgarbage() -- restore GC and collect dead fiels in `a' + collectgarbage() -- restore GC and collect dead fields in `a' for i=0,t-1 do local k = querytab(a, i) assert(k == nil or type(k) == 'number' or k == 'alo') diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 5e8b9398c..d5bca44f0 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -172,7 +172,7 @@ end a = nil --- testing implicit convertions +-- testing implicit conversions local a,b = '10', '20' assert(a*b == 200 and a+b == 30 and a-b == -10 and a/b == 0.5 and -b == -20) diff --git a/tests/conformance/pm.lua b/tests/conformance/pm.lua index 9a113964a..263759ac0 100644 --- a/tests/conformance/pm.lua +++ b/tests/conformance/pm.lua @@ -21,9 +21,9 @@ a,b = string.find('alo', '') assert(a == 1 and b == 0) a,b = string.find('a\0o a\0o a\0o', 'a', 1) -- first position assert(a == 1 and b == 1) -a,b = string.find('a\0o a\0o a\0o', 'a\0o', 2) -- starts in the midle +a,b = string.find('a\0o a\0o a\0o', 'a\0o', 2) -- starts in the middle assert(a == 5 and b == 7) -a,b = string.find('a\0o a\0o a\0o', 'a\0o', 9) -- starts in the midle +a,b = string.find('a\0o a\0o a\0o', 'a\0o', 9) -- starts in the middle assert(a == 9 and b == 11) a,b = string.find('a\0a\0a\0a\0\0ab', '\0ab', 2); -- finds at the end assert(a == 9 and b == 11); diff --git a/tests/conformance/tmerror.lua b/tests/conformance/tmerror.lua new file mode 100644 index 000000000..1ad4dd16f --- /dev/null +++ b/tests/conformance/tmerror.lua @@ -0,0 +1,15 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes + +-- Generate an error (i.e. throw an exception) inside a tag method which is indirectly +-- called via pcall. +-- This test is meant to detect a regression in handling errors inside a tag method + +local testtable = {} +setmetatable(testtable, { __index = function() error("Error") end }) + +pcall(function() + testtable.missingmethod() +end) + +return('OK') diff --git a/tools/gdb-printers.py b/tools/gdb-printers.py index c711c5e2e..017b9f95e 100644 --- a/tools/gdb-printers.py +++ b/tools/gdb-printers.py @@ -11,9 +11,9 @@ def to_string(self): return type.name + " [" + str(value) + "]" def match_printer(val): - type = val.type.strip_typedefs() - if type.name and type.name.startswith('Luau::Variant<'): - return VariantPrinter(val) - return None + type = val.type.strip_typedefs() + if type.name and type.name.startswith('Luau::Variant<'): + return VariantPrinter(val) + return None gdb.pretty_printers.append(match_printer) From 82d74e6f73fa8d9a81c4c932c668543c08c27597 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 18:12:39 -0800 Subject: [PATCH 03/14] Sync to upstream/release/504 --- .gitignore | 1 + Analysis/include/Luau/Error.h | 13 +- Analysis/include/Luau/FileResolver.h | 2 +- Analysis/include/Luau/Transpiler.h | 3 +- Analysis/include/Luau/TypeInfer.h | 8 +- Analysis/include/Luau/TypePack.h | 20 +- Analysis/include/Luau/TypeVar.h | 45 +- Analysis/include/Luau/Unifiable.h | 3 - Analysis/include/Luau/Unifier.h | 2 +- Analysis/src/Autocomplete.cpp | 46 +- Analysis/src/BuiltinDefinitions.cpp | 244 +------- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 22 +- Analysis/src/Error.cpp | 291 ++++++---- Analysis/src/Frontend.cpp | 9 +- Analysis/src/Linter.cpp | 10 +- Analysis/src/Module.cpp | 17 +- Analysis/src/Predicate.cpp | 4 - Analysis/src/RequireTracer.cpp | 8 +- Analysis/src/Substitution.cpp | 13 +- Analysis/src/ToString.cpp | 114 +--- Analysis/src/Transpiler.cpp | 117 +++- Analysis/src/TypeAttach.cpp | 41 +- Analysis/src/TypeInfer.cpp | 586 ++++++-------------- Analysis/src/TypePack.cpp | 8 +- Analysis/src/TypeVar.cpp | 70 ++- Analysis/src/Unifiable.cpp | 10 - Analysis/src/Unifier.cpp | 345 ++++++------ Ast/include/Luau/Parser.h | 1 - Ast/src/Parser.cpp | 51 +- CLI/Analyze.cpp | 27 +- CLI/Repl.cpp | 93 +++- CMakeLists.txt | 38 +- Compiler/include/Luau/Bytecode.h | 4 + Compiler/include/Luau/Compiler.h | 3 + Compiler/src/Compiler.cpp | 47 +- Makefile | 10 +- VM/include/lua.h | 4 + VM/src/lapi.cpp | 11 + VM/src/lbitlib.cpp | 42 ++ VM/src/lbuiltins.cpp | 54 +- VM/src/lgc.cpp | 164 +----- VM/src/lstrlib.cpp | 20 +- VM/src/ltable.cpp | 1 + VM/src/ltable.h | 1 - VM/src/ltablib.cpp | 8 - bench/tests/chess.lua | 76 +-- fuzz/luau.proto | 7 + fuzz/proto.cpp | 7 + fuzz/protoprint.cpp | 10 + tests/AstQuery.test.cpp | 1 - tests/Autocomplete.test.cpp | 3 - tests/Compiler.test.cpp | 126 +++++ tests/Conformance.test.cpp | 6 +- tests/Linter.test.cpp | 6 - tests/Parser.test.cpp | 43 +- tests/Predicate.test.cpp | 6 - tests/ToString.test.cpp | 9 - tests/Transpiler.test.cpp | 252 ++++++++- tests/TypeInfer.aliases.test.cpp | 3 - tests/TypeInfer.builtins.test.cpp | 8 - tests/TypeInfer.classes.test.cpp | 25 +- tests/TypeInfer.definitions.test.cpp | 3 - tests/TypeInfer.generics.test.cpp | 89 --- tests/TypeInfer.intersectionTypes.test.cpp | 39 ++ tests/TypeInfer.provisional.test.cpp | 7 +- tests/TypeInfer.refinements.test.cpp | 44 +- tests/TypeInfer.tables.test.cpp | 72 +++ tests/TypeInfer.test.cpp | 34 +- tests/TypeInfer.tryUnify.test.cpp | 5 - tests/TypeInfer.typePacks.cpp | 12 - tests/TypeInfer.unionTypes.test.cpp | 41 +- tests/TypeVar.test.cpp | 2 - tests/conformance/bitwise.lua | 16 + 73 files changed, 1737 insertions(+), 1846 deletions(-) diff --git a/.gitignore b/.gitignore index 0b2422ced..fa11b45b5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ ^default.prof* ^fuzz-* ^luau$ +/.vs diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index ac6f13e96..9ee750043 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -8,11 +8,20 @@ namespace Luau { +struct TypeError; struct TypeMismatch { - TypeId wantedType; - TypeId givenType; + TypeMismatch() = default; + TypeMismatch(TypeId wantedType, TypeId givenType); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error); + + TypeId wantedType = nullptr; + TypeId givenType = nullptr; + + std::string reason; + std::shared_ptr error; bool operator==(const TypeMismatch& rhs) const; }; diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index a05ec5e91..9b74fc12d 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -53,7 +53,7 @@ struct FileResolver } // DEPRECATED APIS - // These are going to be removed with LuauNewRequireTracer + // These are going to be removed with LuauNewRequireTrace2 virtual bool moduleExists(const ModuleName& name) const = 0; virtual std::optional fromAstFragment(AstExpr* expr) const = 0; virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; diff --git a/Analysis/include/Luau/Transpiler.h b/Analysis/include/Luau/Transpiler.h index 817459fed..df01008ca 100644 --- a/Analysis/include/Luau/Transpiler.h +++ b/Analysis/include/Luau/Transpiler.h @@ -18,6 +18,7 @@ struct TranspileResult std::string parseError; // Nonempty if the transpile failed }; +std::string toString(AstNode* node); void dump(AstNode* node); // Never fails on a well-formed AST @@ -25,6 +26,6 @@ std::string transpile(AstStatBlock& ast); std::string transpileWithTypes(AstStatBlock& block); // Only fails when parsing fails -TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}); +TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}, bool withTypes = false); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 9d62fef0b..306ac77d8 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -263,8 +263,6 @@ struct TypeChecker * */ TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location); - // Removed by FFlag::LuauRankNTypes - TypePackId DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location); // Replace any free types or type packs by `any`. // This is used when exporting types from modules, to make sure free types don't leak. @@ -298,8 +296,6 @@ struct TypeChecker // Produce a new free type var. TypeId freshType(const ScopePtr& scope); TypeId freshType(TypeLevel level); - TypeId DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric = false); - TypeId DEPRECATED_freshType(TypeLevel level, bool canBeGeneric = false); // Returns nullopt if the predicate filters down the TypeId to 0 options. std::optional filterMap(TypeId type, TypeIdPredicate predicate); @@ -326,10 +322,8 @@ struct TypeChecker TypePackId addTypePack(std::initializer_list&& ty); TypePackId freshTypePack(const ScopePtr& scope); TypePackId freshTypePack(TypeLevel level); - TypePackId DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric = false); - TypePackId DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric = false); - TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false); + TypeId resolveType(const ScopePtr& scope, const AstType& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index d987d46ca..e72808da7 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAG(LuauAddMissingFollow) - namespace Luau { @@ -128,13 +126,10 @@ TypePack* asMutable(const TypePack* tp); template const T* get(TypePackId tp) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tp); + LUAU_ASSERT(tp); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tp->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tp->ty) == nullptr); return get_if(&(tp->ty)); } @@ -142,13 +137,10 @@ const T* get(TypePackId tp) template T* getMutable(TypePackId tp) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tp); + LUAU_ASSERT(tp); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tp->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tp->ty) == nullptr); return get_if(&(asMutable(tp)->ty)); } diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 9611e881f..6bd7932db 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -18,7 +18,6 @@ LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) -LUAU_FASTFLAG(LuauAddMissingFollow) namespace Luau { @@ -413,13 +412,17 @@ bool maybeGeneric(const TypeId ty); struct SingletonTypes { - const TypeId nilType = &nilType_; - const TypeId numberType = &numberType_; - const TypeId stringType = &stringType_; - const TypeId booleanType = &booleanType_; - const TypeId threadType = &threadType_; - const TypeId anyType = &anyType_; - const TypeId errorType = &errorType_; + const TypeId nilType; + const TypeId numberType; + const TypeId stringType; + const TypeId booleanType; + const TypeId threadType; + const TypeId anyType; + const TypeId errorType; + const TypeId optionalNumberType; + + const TypePackId anyTypePack; + const TypePackId errorTypePack; SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; @@ -427,14 +430,6 @@ struct SingletonTypes private: std::unique_ptr arena; - TypeVar nilType_; - TypeVar numberType_; - TypeVar stringType_; - TypeVar booleanType_; - TypeVar threadType_; - TypeVar anyType_; - TypeVar errorType_; - TypeId makeStringMetatable(); }; @@ -472,13 +467,10 @@ TypeVar* asMutable(TypeId ty); template const T* get(TypeId tv) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tv); + LUAU_ASSERT(tv); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tv->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); return get_if(&tv->ty); } @@ -486,13 +478,10 @@ const T* get(TypeId tv) template T* getMutable(TypeId tv) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tv); + LUAU_ASSERT(tv); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tv->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); return get_if(&asMutable(tv)->ty); } diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 10dbf3335..c2e07e466 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -63,12 +63,9 @@ using Name = std::string; struct Free { explicit Free(TypeLevel level); - Free(TypeLevel level, bool DEPRECATED_canBeGeneric); int index; TypeLevel level; - // Removed by FFlag::LuauRankNTypes - bool DEPRECATED_canBeGeneric = false; // True if this free type variable is part of a mutually // recursive type alias whose definitions haven't been // resolved yet. diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 56632e33c..be0aadd05 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -87,7 +87,6 @@ struct Unifier void tryUnifyWithAny(TypePackId any, TypePackId ty); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); - std::optional findMetatableEntry(TypeId type, std::string entry); public: // Report an "infinite type error" if the type "needle" already occurs within "haystack" @@ -102,6 +101,7 @@ struct Unifier bool isNonstrictMode() const; void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType); + void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType); [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 3c43c8086..1c94bb684 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,7 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) @@ -369,20 +368,10 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId while (iter != endIter) { - if (FFlag::LuauAddMissingFollow) - { - if (isNil(*iter)) - ++iter; - else - break; - } + if (isNil(*iter)) + ++iter; else - { - if (auto primTy = Luau::get(*iter); primTy && primTy->type == PrimitiveTypeVar::NilType) - ++iter; - else - break; - } + break; } if (iter == endIter) @@ -397,21 +386,10 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen = seen; - if (FFlag::LuauAddMissingFollow) + if (isNil(*iter)) { - if (isNil(*iter)) - { - ++iter; - continue; - } - } - else - { - if (auto innerPrimTy = Luau::get(*iter); innerPrimTy && innerPrimTy->type == PrimitiveTypeVar::NilType) - { - ++iter; - continue; - } + ++iter; + continue; } autocompleteProps(module, typeArena, *iter, indexType, nodes, inner, innerSeen); @@ -496,7 +474,7 @@ static bool canSuggestInferredType(ScopePtr scope, TypeId ty) return false; // No syntax for unnamed tables with a metatable - if (const MetatableTypeVar* mtv = get(ty)) + if (get(ty)) return false; if (const TableTypeVar* ttv = get(ty)) @@ -688,7 +666,7 @@ static std::optional functionIsExpectedAt(const Module& module, AstNode* n TypeId expectedType = follow(*it); - if (const FunctionTypeVar* ftv = get(expectedType)) + if (get(expectedType)) return true; if (const IntersectionTypeVar* itv = get(expectedType)) @@ -1519,10 +1497,10 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName return {}; TypeChecker& typeChecker = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); if (!module) return {}; @@ -1550,7 +1528,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->commentLocations = std::move(result.commentLocations); TypeChecker& typeChecker = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index f6f2363c6..62a06a3cd 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,10 +8,7 @@ #include -LUAU_FASTFLAG(LuauParseGenericFunctions) -LUAU_FASTFLAG(LuauGenericFunctions) -LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauNewRequireTrace2) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -185,25 +182,11 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId numberType = typeChecker.numberType; TypeId booleanType = typeChecker.booleanType; TypeId nilType = typeChecker.nilType; - TypeId stringType = typeChecker.stringType; - TypeId threadType = typeChecker.threadType; - TypeId anyType = typeChecker.anyType; TypeArena& arena = typeChecker.globalTypes; - TypeId optionalNumber = makeOption(typeChecker, arena, numberType); - TypeId optionalString = makeOption(typeChecker, arena, stringType); - TypeId optionalBoolean = makeOption(typeChecker, arena, booleanType); - - TypeId stringOrNumber = makeUnion(arena, {stringType, numberType}); - - TypePackId emptyPack = arena.addTypePack({}); TypePackId oneNumberPack = arena.addTypePack({numberType}); - TypePackId oneStringPack = arena.addTypePack({stringType}); TypePackId oneBooleanPack = arena.addTypePack({booleanType}); - TypePackId oneAnyPack = arena.addTypePack({anyType}); - - TypePackId anyTypePack = typeChecker.anyTypePack; TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); @@ -215,8 +198,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack}); - TypeId stringToAnyMap = arena.addType(TableTypeVar{{}, TableIndexer(stringType, anyType), typeChecker.globalScope->level}); - LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); @@ -236,8 +217,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest"); } - TypeId anyFunction = arena.addType(FunctionTypeVar{anyTypePack, anyTypePack}); - TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); @@ -252,222 +231,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - if (!FFlag::LuauParseGenericFunctions || !FFlag::LuauGenericFunctions) - { - TableTypeVar::Props debugLib{ - {"info", {makeIntersection(arena, - { - arena.addType(FunctionTypeVar{arena.addTypePack({typeChecker.threadType, numberType, stringType}), anyTypePack}), - arena.addType(FunctionTypeVar{arena.addTypePack({numberType, stringType}), anyTypePack}), - arena.addType(FunctionTypeVar{arena.addTypePack({anyFunction, stringType}), anyTypePack}), - })}}, - {"traceback", {makeIntersection(arena, - { - makeFunction(arena, std::nullopt, {optionalString, optionalNumber}, {stringType}), - makeFunction(arena, std::nullopt, {typeChecker.threadType, optionalString, optionalNumber}, {stringType}), - })}}, - }; - - assignPropDocumentationSymbols(debugLib, "@luau/global/debug"); - addGlobalBinding(typeChecker, "debug", - arena.addType(TableTypeVar{debugLib, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}), "@luau"); - - TableTypeVar::Props utf8Lib = { - {"char", {arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneStringPack})}}, // FIXME - {"charpattern", {stringType}}, - {"codes", {makeFunction(arena, std::nullopt, {stringType}, - {makeFunction(arena, std::nullopt, {stringType, numberType}, {numberType, numberType}), stringType, numberType})}}, - {"codepoint", - {arena.addType(FunctionTypeVar{arena.addTypePack({stringType, optionalNumber, optionalNumber}), listOfAtLeastOneNumber})}}, // FIXME - {"len", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {optionalNumber, numberType})}}, - {"offset", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {numberType})}}, - {"nfdnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}}, - {"graphemes", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, - {makeFunction(arena, std::nullopt, {}, {numberType, numberType})})}}, - {"nfcnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}}, - }; - - assignPropDocumentationSymbols(utf8Lib, "@luau/global/utf8"); - addGlobalBinding( - typeChecker, "utf8", arena.addType(TableTypeVar{utf8Lib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TypeId optionalV = makeOption(typeChecker, arena, genericV); - - TypeId arrayOfV = arena.addType(TableTypeVar{{}, TableIndexer(numberType, genericV), typeChecker.globalScope->level}); - - TypePackId unpackArgsPack = arena.addTypePack(TypePack{{arrayOfV, optionalNumber, optionalNumber}}); - TypePackId unpackReturnPack = arena.addTypePack(TypePack{{}, anyTypePack}); - TypeId unpackFunc = arena.addType(FunctionTypeVar{{genericV}, {}, unpackArgsPack, unpackReturnPack}); - - TypeId packResult = arena.addType(TableTypeVar{ - TableTypeVar::Props{{"n", {numberType}}}, TableIndexer{numberType, numberType}, typeChecker.globalScope->level, TableState::Sealed}); - TypePackId packArgsPack = arena.addTypePack(TypePack{{}, anyTypePack}); - TypePackId packReturnPack = arena.addTypePack(TypePack{{packResult}}); - - TypeId comparator = makeFunction(arena, std::nullopt, {genericV, genericV}, {booleanType}); - TypeId optionalComparator = makeOption(typeChecker, arena, comparator); - - TypeId packFn = arena.addType(FunctionTypeVar(packArgsPack, packReturnPack)); - - TableTypeVar::Props tableLib = { - {"concat", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalString, optionalNumber, optionalNumber}, {stringType})}}, - {"insert", {makeIntersection(arena, {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV}, {}), - makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, genericV}, {})})}}, - {"maxn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}}, - {"remove", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalNumber}, {optionalV})}}, - {"sort", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalComparator}, {})}}, - {"create", {makeFunction(arena, std::nullopt, {genericV}, {}, {numberType, optionalV}, {arrayOfV})}}, - {"find", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV, optionalNumber}, {optionalNumber})}}, - - {"unpack", {unpackFunc}}, // FIXME - {"pack", {packFn}}, - - // Lua 5.0 compat - {"getn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}}, - {"foreach", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, - {mapOfKtoV, makeFunction(arena, std::nullopt, {genericK, genericV}, {})}, {})}}, - {"foreachi", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, makeFunction(arena, std::nullopt, {genericV}, {})}, {})}}, - - // backported from Lua 5.3 - {"move", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, numberType, numberType, arrayOfV}, {})}}, - - // added in Luau (borrowed from LuaJIT) - {"clear", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {})}}, - - {"freeze", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {mapOfKtoV})}}, - {"isfrozen", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {booleanType})}}, - }; - - assignPropDocumentationSymbols(tableLib, "@luau/global/table"); - addGlobalBinding( - typeChecker, "table", arena.addType(TableTypeVar{tableLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TableTypeVar::Props coroutineLib = { - {"create", {makeFunction(arena, std::nullopt, {anyFunction}, {threadType})}}, - {"resume", {arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{threadType}, anyTypePack}), anyTypePack})}}, - {"running", {makeFunction(arena, std::nullopt, {}, {threadType})}}, - {"status", {makeFunction(arena, std::nullopt, {threadType}, {stringType})}}, - {"wrap", {makeFunction( - arena, std::nullopt, {anyFunction}, {anyType})}}, // FIXME this technically returns a function, but we can't represent this - // atm since it can be called with different arg types at different times - {"yield", {arena.addType(FunctionTypeVar{anyTypePack, anyTypePack})}}, - {"isyieldable", {makeFunction(arena, std::nullopt, {}, {booleanType})}}, - }; - - assignPropDocumentationSymbols(coroutineLib, "@luau/global/coroutine"); - addGlobalBinding(typeChecker, "coroutine", - arena.addType(TableTypeVar{coroutineLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TypeId genericT = arena.addType(GenericTypeVar{"T"}); - TypeId genericR = arena.addType(GenericTypeVar{"R"}); - - // assert returns all arguments - TypePackId assertArgs = arena.addTypePack({genericT, optionalString}); - TypePackId assertRets = arena.addTypePack({genericT}); - addGlobalBinding(typeChecker, "assert", arena.addType(FunctionTypeVar{assertArgs, assertRets}), "@luau"); - - addGlobalBinding(typeChecker, "print", arena.addType(FunctionTypeVar{anyTypePack, emptyPack}), "@luau"); - - addGlobalBinding(typeChecker, "type", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - addGlobalBinding(typeChecker, "typeof", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - - addGlobalBinding(typeChecker, "error", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {}), "@luau"); - - addGlobalBinding(typeChecker, "tostring", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - addGlobalBinding( - typeChecker, "tonumber", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {numberType}), "@luau"); - - addGlobalBinding( - typeChecker, "rawequal", makeFunction(arena, std::nullopt, {genericT, genericR}, {}, {genericT, genericR}, {booleanType}), "@luau"); - addGlobalBinding( - typeChecker, "rawget", makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK}, {genericV}), "@luau"); - addGlobalBinding(typeChecker, "rawset", - makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK, genericV}, {mapOfKtoV}), "@luau"); - - TypePackId genericTPack = arena.addTypePack({genericT}); - TypePackId genericRPack = arena.addTypePack({genericR}); - TypeId genericArgsToReturnFunction = arena.addType( - FunctionTypeVar{{genericT, genericR}, {}, arena.addTypePack(TypePack{{}, genericTPack}), arena.addTypePack(TypePack{{}, genericRPack})}); - - TypeId setfenvArgType = makeUnion(arena, {numberType, genericArgsToReturnFunction}); - TypeId setfenvReturnType = makeOption(typeChecker, arena, genericArgsToReturnFunction); - addGlobalBinding(typeChecker, "setfenv", makeFunction(arena, std::nullopt, {setfenvArgType, stringToAnyMap}, {setfenvReturnType}), "@luau"); - - TypePackId ipairsArgsTypePack = arena.addTypePack({arrayOfV}); - - TypeId ipairsNextFunctionType = arena.addType( - FunctionTypeVar{{genericK, genericV}, {}, arena.addTypePack({arrayOfV, numberType}), arena.addTypePack({numberType, genericV})}); - - // ipairs returns 'next, Array, 0' so we would need type-level primitives and change to - // again, we have a direct reference to 'next' because ipairs returns it - // ipairs(t: Array) -> ((Array) -> (number, V), Array, 0) - TypePackId ipairsReturnTypePack = arena.addTypePack(TypePack{{ipairsNextFunctionType, arrayOfV, numberType}}); - - // ipairs(t: Array) -> ((Array) -> (number, V), Array, number) - addGlobalBinding(typeChecker, "ipairs", arena.addType(FunctionTypeVar{{genericV}, {}, ipairsArgsTypePack, ipairsReturnTypePack}), "@luau"); - - TypePackId pcallArg0FnArgs = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}}); - TypePackId pcallArg0FnRet = arena.addTypePack(TypePackVar{GenericTypeVar{"R"}}); - TypeId pcallArg0 = arena.addType(FunctionTypeVar{pcallArg0FnArgs, pcallArg0FnRet}); - TypePackId pcallArgsTypePack = arena.addTypePack(TypePack{{pcallArg0}, pcallArg0FnArgs}); - - TypePackId pcallReturnTypePack = arena.addTypePack(TypePack{{booleanType}, pcallArg0FnRet}); - - // pcall(f: (A...) -> R..., args: A...) -> boolean, R... - addGlobalBinding(typeChecker, "pcall", - arena.addType(FunctionTypeVar{{}, {pcallArg0FnArgs, pcallArg0FnRet}, pcallArgsTypePack, pcallReturnTypePack}), "@luau"); - - // errors thrown by the function 'f' are propagated onto the function 'err' that accepts it. - // and either 'f' or 'err' are valid results of this xpcall - // if 'err' did throw an error, then it returns: false, "error in error handling" - // TODO: the above is not represented (nor representable) in the type annotation below. - // - // The real type of xpcall is as such: (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, - // R2...) - TypePackId genericAPack = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}}); - TypePackId genericR1Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R1"}}); - TypePackId genericR2Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R2"}}); - - TypeId genericE = arena.addType(GenericTypeVar{"E"}); - - TypeId xpcallFArg = arena.addType(FunctionTypeVar{genericAPack, genericR1Pack}); - TypeId xpcallErrArg = arena.addType(FunctionTypeVar{arena.addTypePack({genericE}), genericR2Pack}); - - TypePackId xpcallArgsPack = arena.addTypePack({{xpcallFArg, xpcallErrArg}, genericAPack}); - TypePackId xpcallRetPack = arena.addTypePack({{booleanType}, genericR1Pack}); // FIXME - - addGlobalBinding(typeChecker, "xpcall", - arena.addType(FunctionTypeVar{{genericE}, {genericAPack, genericR1Pack, genericR2Pack}, xpcallArgsPack, xpcallRetPack}), "@luau"); - - addGlobalBinding(typeChecker, "unpack", unpackFunc, "@luau"); - - TypePackId selectArgsTypePack = arena.addTypePack(TypePack{ - {stringOrNumber}, - anyTypePack // FIXME? select() is tricky. - }); - - addGlobalBinding(typeChecker, "select", arena.addType(FunctionTypeVar{selectArgsTypePack, anyTypePack}), "@luau"); - - // TODO: not completely correct. loadstring's return type should be a function or (nil, string) - TypeId loadstringFunc = arena.addType(FunctionTypeVar{anyTypePack, oneAnyPack}); - - addGlobalBinding(typeChecker, "loadstring", - makeFunction(arena, std::nullopt, {stringType, optionalString}, - { - makeOption(typeChecker, arena, loadstringFunc), - makeOption(typeChecker, arena, stringType), - }), - "@luau"); - - // a userdata object is "roughly" the same as a sealed empty table - // except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too. - // another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT - // setmetatable. - // TODO: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`. - TypeId sealedTable = arena.addType(TableTypeVar(TableState::Sealed, typeChecker.globalScope->level)); - addGlobalBinding(typeChecker, "newproxy", makeFunction(arena, std::nullopt, {optionalBoolean}, {sealedTable}), "@luau"); - } - // next(t: Table, i: K | nil) -> (K, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); addGlobalBinding(typeChecker, "next", @@ -475,8 +238,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = (FFlag::LuauRankNTypes ? arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}) - : getGlobalBinding(typeChecker, "next")); + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); // NOTE we are missing 'i: K | nil' argument in the first return types' argument. @@ -711,7 +473,7 @@ static std::optional> magicFunctionRequire( if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; - const AstExpr* require = FFlag::LuauNewRequireTrace ? &expr : expr.args.data[0]; + const AstExpr* require = FFlag::LuauNewRequireTrace2 ? &expr : expr.args.data[0]; if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 1e91561a6..96703ef16 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,9 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(LuauParseGenericFunctions) -LUAU_FASTFLAG(LuauGenericFunctions) - namespace Luau { @@ -19,6 +16,8 @@ declare bit32: { bnot: (number) -> number, extract: (number, number, number?) -> number, replace: (number, number, number, number?) -> number, + countlz: (number) -> number, + countrz: (number) -> number, } declare math: { @@ -103,15 +102,6 @@ declare _VERSION: string declare function gcinfo(): number -)BUILTIN_SRC"; - -std::string getBuiltinDefinitionSource() -{ - std::string src = kBuiltinDefinitionLuaSrc; - - if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) - { - src += R"( declare function print(...: T...) declare function type(value: T): string @@ -208,10 +198,12 @@ std::string getBuiltinDefinitionSource() -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. declare function unpack(tab: {V}, i: number?, j: number?): ...V - )"; - } - return src; +)BUILTIN_SRC"; + +std::string getBuiltinDefinitionSource() +{ + return kBuiltinDefinitionLuaSrc; } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 04d91444a..46ff2c72a 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -94,8 +94,23 @@ struct ErrorConverter { std::string operator()(const Luau::TypeMismatch& tm) const { - ToStringOptions opts; - return "Type '" + Luau::toString(tm.givenType, opts) + "' could not be converted into '" + Luau::toString(tm.wantedType, opts) + "'"; + std::string result = "Type '" + Luau::toString(tm.givenType) + "' could not be converted into '" + Luau::toString(tm.wantedType) + "'"; + + if (tm.error) + { + result += "\ncaused by:\n "; + + if (!tm.reason.empty()) + result += tm.reason + ". "; + + result += Luau::toString(*tm.error); + } + else if (!tm.reason.empty()) + { + result += "; " + tm.reason; + } + + return result; } std::string operator()(const Luau::UnknownSymbol& e) const @@ -478,9 +493,36 @@ struct InvalidNameChecker } }; +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType) + : wantedType(wantedType) + , givenType(givenType) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason) + : wantedType(wantedType) + , givenType(givenType) + , reason(reason) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error) + : wantedType(wantedType) + , givenType(givenType) + , reason(reason) + , error(std::make_shared(std::move(error))) +{ +} + bool TypeMismatch::operator==(const TypeMismatch& rhs) const { - return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType; + if (!!error != !!rhs.error) + return false; + + if (error && !(*error == *rhs.error)) + return false; + + return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType && reason == rhs.reason; } bool UnknownSymbol::operator==(const UnknownSymbol& rhs) const @@ -690,130 +732,141 @@ bool containsParseErrorName(const TypeError& error) return Luau::visit(InvalidNameChecker{}, error.data); } -void copyErrors(ErrorVec& errors, struct TypeArena& destArena) +template +void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) { - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - auto clone = [&](auto&& ty) { return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks); }; auto visitErrorData = [&](auto&& e) { - using T = std::decay_t; + copyError(e, destArena, seenTypes, seenTypePacks); + }; - if constexpr (false) - { - } - else if constexpr (std::is_same_v) - { - e.wantedType = clone(e.wantedType); - e.givenType = clone(e.givenType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.table = clone(e.table); - } - else if constexpr (std::is_same_v) - { - e.ty = clone(e.ty); - } - else if constexpr (std::is_same_v) - { - e.tableType = clone(e.tableType); - } - else if constexpr (std::is_same_v) - { - e.tableType = clone(e.tableType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.typeFun = clone(e.typeFun); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.table = clone(e.table); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.ty = clone(e.ty); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.expectedReturnType = clone(e.expectedReturnType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.superType = clone(e.superType); - e.subType = clone(e.subType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.optional = clone(e.optional); - } - else if constexpr (std::is_same_v) - { - e.type = clone(e.type); + if constexpr (false) + { + } + else if constexpr (std::is_same_v) + { + e.wantedType = clone(e.wantedType); + e.givenType = clone(e.givenType); - for (auto& ty : e.missing) - ty = clone(ty); - } - else - static_assert(always_false_v, "Non-exhaustive type switch"); + if (e.error) + visit(visitErrorData, e.error->data); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.table = clone(e.table); + } + else if constexpr (std::is_same_v) + { + e.ty = clone(e.ty); + } + else if constexpr (std::is_same_v) + { + e.tableType = clone(e.tableType); + } + else if constexpr (std::is_same_v) + { + e.tableType = clone(e.tableType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.typeFun = clone(e.typeFun); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.table = clone(e.table); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.ty = clone(e.ty); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.expectedReturnType = clone(e.expectedReturnType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.superType = clone(e.superType); + e.subType = clone(e.subType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.optional = clone(e.optional); + } + else if constexpr (std::is_same_v) + { + e.type = clone(e.type); + + for (auto& ty : e.missing) + ty = clone(ty); + } + else + static_assert(always_false_v, "Non-exhaustive type switch"); +} + +void copyErrors(ErrorVec& errors, TypeArena& destArena) +{ + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + auto visitErrorData = [&](auto&& e) { + copyError(e, destArena, seenTypes, seenTypePacks); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 5e7af50c5..2f411274b 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,11 +18,10 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false) LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) -LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau @@ -415,7 +414,7 @@ CheckResult Frontend::check(const ModuleName& name) // If we're typechecking twice, we do so. // The second typecheck is always in strict mode with DM awareness // to provide better typen information for IDE features. - if (options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel) + if (options.typecheckTwice) { ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; @@ -897,7 +896,7 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module const auto& exprs = it->second.exprs; const ModuleInfo* info = exprs.find(&pathExpr); - if (!info || (!FFlag::LuauNewRequireTrace && info->name.empty())) + if (!info || (!FFlag::LuauNewRequireTrace2 && info->name.empty())) return std::nullopt; return *info; @@ -914,7 +913,7 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - if (FFlag::LuauNewRequireTrace) + if (FFlag::LuauNewRequireTrace2) return frontend->sourceNodes.count(moduleName) != 0; else return frontend->fileResolver->moduleExists(moduleName); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index bff947a56..1a5b24fe2 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -12,9 +12,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false) -LUAU_FASTFLAGVARIABLE(LuauLinterTableMoveZero, false) - namespace Luau { @@ -1110,10 +1107,7 @@ class LintUnknownType : AstVisitor if (g && g->name == "type") { - if (FFlag::LuauLinterUnknownTypeVectorAware) - validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type"); - else - validateType(arg, {Kind_Primitive}, "primitive type"); + validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type"); } else if (g && g->name == "typeof") { @@ -2146,7 +2140,7 @@ class LintTableOperations : AstVisitor "wrap it in parentheses to silence"); } - if (FFlag::LuauLinterTableMoveZero && func->index == "move" && node->args.size >= 4) + if (func->index == "move" && node->args.size >= 4) { // table.move(t, 0, _, _) if (isConstant(args[1], 0.0)) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 2fd958965..880ffd2e5 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -12,7 +12,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false) @@ -290,9 +289,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) for (TypePackId genericPack : t.genericPacks) ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, encounteredFreeType)); - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ftv->tags = t.tags; - + ftv->tags = t.tags; ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, encounteredFreeType); ftv->argNames = t.argNames; ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, encounteredFreeType); @@ -319,12 +316,7 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->level = TypeLevel{0, 0}; for (const auto& [name, prop] : t.props) - { - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; - else - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; - } + ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; if (t.indexer) ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), @@ -379,10 +371,7 @@ void TypeCloner::operator()(const ClassTypeVar& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; - else - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; + ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; if (t.parent) ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, encounteredFreeType); diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/Predicate.cpp index 25e63bffe..848627cf8 100644 --- a/Analysis/src/Predicate.cpp +++ b/Analysis/src/Predicate.cpp @@ -3,8 +3,6 @@ #include "Luau/Ast.h" -LUAU_FASTFLAG(LuauOrPredicate) - namespace Luau { @@ -60,8 +58,6 @@ std::string toString(const LValue& lvalue) void merge(RefinementMap& l, const RefinementMap& r, std::function f) { - LUAU_ASSERT(FFlag::LuauOrPredicate); - auto itL = l.begin(); auto itR = r.begin(); while (itL != l.end() && itR != r.end()) diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 95910b562..b72f53f99 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -5,7 +5,7 @@ #include "Luau/Module.h" LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) -LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace, false) +LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace2, false) namespace Luau { @@ -19,7 +19,7 @@ struct RequireTracerOld : AstVisitor : fileResolver(fileResolver) , currentModuleName(currentModuleName) { - LUAU_ASSERT(!FFlag::LuauNewRequireTrace); + LUAU_ASSERT(!FFlag::LuauNewRequireTrace2); } FileResolver* const fileResolver; @@ -188,7 +188,7 @@ struct RequireTracer : AstVisitor , currentModuleName(currentModuleName) , locals(nullptr) { - LUAU_ASSERT(FFlag::LuauNewRequireTrace); + LUAU_ASSERT(FFlag::LuauNewRequireTrace2); } bool visit(AstExprTypeAssertion* expr) override @@ -332,7 +332,7 @@ struct RequireTracer : AstVisitor RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - if (FFlag::LuauNewRequireTrace) + if (FFlag::LuauNewRequireTrace2) { RequireTraceResult result; RequireTracer tracer{result, fileResolver, currentModuleName}; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index d861eb3da..ca2b30f52 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,8 +8,6 @@ LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) -LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau @@ -19,7 +17,7 @@ void Tarjan::visitChildren(TypeId ty, int index) { ty = follow(ty); - if (FFlag::LuauRankNTypes && ignoreChildren(ty)) + if (ignoreChildren(ty)) return; if (const FunctionTypeVar* ftv = get(ty)) @@ -68,7 +66,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) { tp = follow(tp); - if (FFlag::LuauRankNTypes && ignoreChildren(tp)) + if (ignoreChildren(tp)) return; if (const TypePack* tpp = get(tp)) @@ -399,8 +397,7 @@ TypeId Substitution::clone(TypeId ty) if (FFlag::LuauTypeAliasPacks) clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - clone.tags = ttv->tags; + clone.tags = ttv->tags; result = addType(std::move(clone)); } else if (const MetatableTypeVar* mtv = get(ty)) @@ -486,7 +483,7 @@ void Substitution::replaceChildren(TypeId ty) { ty = follow(ty); - if (FFlag::LuauRankNTypes && ignoreChildren(ty)) + if (ignoreChildren(ty)) return; if (FunctionTypeVar* ftv = getMutable(ty)) @@ -535,7 +532,7 @@ void Substitution::replaceChildren(TypePackId tp) { tp = follow(tp); - if (FFlag::LuauRankNTypes && ignoreChildren(tp)) + if (ignoreChildren(tp)) return; if (TypePack* tpp = getMutable(tp)) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index cd8180dba..885fd489b 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) -LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau @@ -237,15 +236,6 @@ struct TypeVarStringifier return; } - if (!FFlag::LuauAddMissingFollow) - { - if (get(tv)) - { - state.emit(state.getName(tv)); - return; - } - } - Luau::visit( [this, tv](auto&& t) { return (*this)(tv, t); @@ -316,11 +306,7 @@ struct TypeVarStringifier void operator()(TypeId ty, const Unifiable::Free& ftv) { state.result.invalid = true; - - if (FFlag::LuauAddMissingFollow) - state.emit(state.getName(ty)); - else - state.emit(""); + state.emit(state.getName(ty)); } void operator()(TypeId, const BoundTypeVar& btv) @@ -724,16 +710,6 @@ struct TypePackStringifier return; } - if (!FFlag::LuauAddMissingFollow) - { - if (get(tp)) - { - state.emit(state.getName(tp)); - state.emit("..."); - return; - } - } - auto it = state.cycleTpNames.find(tp); if (it != state.cycleTpNames.end()) { @@ -821,16 +797,8 @@ struct TypePackStringifier void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; - - if (FFlag::LuauAddMissingFollow) - { - state.emit(state.getName(tp)); - state.emit("..."); - } - else - { - state.emit(""); - } + state.emit(state.getName(tp)); + state.emit("..."); } void operator()(TypePackId, const BoundTypePack& btv) @@ -864,23 +832,15 @@ static void assignCycleNames(const std::unordered_set& cycles, const std std::string name; // TODO: use the stringified type list if there are no cycles - if (FFlag::LuauInstantiatedTypeParamRecursion) + if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) { - if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) - { - // If we have a cycle type in type parameters, assign a cycle name for this named table - if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { - return cycles.count(follow(el)); - }) != ttv->instantiatedTypeParams.end()) - cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; + // If we have a cycle type in type parameters, assign a cycle name for this named table + if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { + return cycles.count(follow(el)); + }) != ttv->instantiatedTypeParams.end()) + cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; - continue; - } - } - else - { - if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) - continue; + continue; } name = "t" + std::to_string(nextIndex); @@ -912,58 +872,6 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) ToStringResult result; - if (!FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive) - { - if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) - { - if (ttv->syntheticName) - result.invalid = true; - - // If scope if provided, add module name and check visibility - if (ttv->name && opts.scope) - { - auto [success, moduleName] = canUseTypeNameInScope(opts.scope, *ttv->name); - - if (!success) - result.invalid = true; - - if (moduleName) - result.name = format("%s.", moduleName->c_str()); - } - - result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - - if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) - return result; - - std::vector params; - for (TypeId tp : ttv->instantiatedTypeParams) - params.push_back(toString(tp)); - - if (FFlag::LuauTypeAliasPacks) - { - // Doesn't preserve grouping of multiple type packs - // But this is under a parent block of code that is being removed later - for (TypePackId tp : ttv->instantiatedTypePackParams) - { - std::string content = toString(tp); - - if (!content.empty()) - params.push_back(std::move(content)); - } - } - - result.name += "<" + join(params, ", ") + ">"; - return result; - } - else if (auto mtv = get(ty); mtv && mtv->syntheticName) - { - result.invalid = true; - result.name = *mtv->syntheticName; - return result; - } - } - StringifierState state{opts, result, opts.nameMap}; std::unordered_set cycles; @@ -975,7 +883,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) TypeVarStringifier tvs{state}; - if (FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive) + if (!opts.exhaustive) { if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) { diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 1b83ccdc2..7d880af49 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace @@ -97,9 +96,6 @@ struct Writer { virtual ~Writer() {} - virtual void begin() {} - virtual void end() {} - virtual void advance(const Position&) = 0; virtual void newline() = 0; virtual void space() = 0; @@ -131,6 +127,7 @@ struct StringWriter : Writer if (pos.column < newPos.column) write(std::string(newPos.column - pos.column, ' ')); } + void maybeSpace(const Position& newPos, int reserve) override { if (pos.column + reserve < newPos.column) @@ -279,11 +276,14 @@ struct Printer writer.identifier(func->index.value); } - void visualizeTypePackAnnotation(const AstTypePack& annotation) + void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) { + advance(annotation.location.begin); if (const AstTypePackVariadic* variadicTp = annotation.as()) { - writer.symbol("..."); + if (!forVarArg) + writer.symbol("..."); + visualizeTypeAnnotation(*variadicTp->variadicType); } else if (const AstTypePackGeneric* genericTp = annotation.as()) @@ -293,6 +293,7 @@ struct Printer } else if (const AstTypePackExplicit* explicitTp = annotation.as()) { + LUAU_ASSERT(!forVarArg); visualizeTypeList(explicitTp->typeList, true); } else @@ -317,7 +318,7 @@ struct Printer // Only variadic tail if (list.types.size == 0) { - visualizeTypePackAnnotation(*list.tailType); + visualizeTypePackAnnotation(*list.tailType, false); } else { @@ -345,7 +346,7 @@ struct Printer if (list.tailType) { writer.symbol(","); - visualizeTypePackAnnotation(*list.tailType); + visualizeTypePackAnnotation(*list.tailType, false); } writer.symbol(")"); @@ -542,6 +543,7 @@ struct Printer case AstExprBinary::CompareLt: case AstExprBinary::CompareGt: writer.maybeSpace(a->right->location.begin, 2); + writer.symbol(toString(a->op)); break; case AstExprBinary::Concat: case AstExprBinary::CompareNe: @@ -550,19 +552,35 @@ struct Printer case AstExprBinary::CompareGe: case AstExprBinary::Or: writer.maybeSpace(a->right->location.begin, 3); + writer.keyword(toString(a->op)); break; case AstExprBinary::And: writer.maybeSpace(a->right->location.begin, 4); + writer.keyword(toString(a->op)); break; } - writer.symbol(toString(a->op)); - visualize(*a->right); } else if (const auto& a = expr.as()) { visualize(*a->expr); + + if (writeTypes) + { + writer.maybeSpace(a->annotation->location.begin, 2); + writer.symbol("::"); + visualizeTypeAnnotation(*a->annotation); + } + } + else if (const auto& a = expr.as()) + { + writer.keyword("if"); + visualize(*a->condition); + writer.keyword("then"); + visualize(*a->trueExpr); + writer.keyword("else"); + visualize(*a->falseExpr); } else if (const auto& a = expr.as()) { @@ -769,24 +787,31 @@ struct Printer switch (a->op) { case AstExprBinary::Add: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("+="); break; case AstExprBinary::Sub: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("-="); break; case AstExprBinary::Mul: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("*="); break; case AstExprBinary::Div: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("/="); break; case AstExprBinary::Mod: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("%="); break; case AstExprBinary::Pow: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("^="); break; case AstExprBinary::Concat: + writer.maybeSpace(a->value->location.begin, 3); writer.symbol("..="); break; default: @@ -874,7 +899,7 @@ struct Printer void visualizeFunctionBody(AstExprFunction& func) { - if (FFlag::LuauGenericFunctions && (func.generics.size > 0 || func.genericPacks.size > 0)) + if (func.generics.size > 0 || func.genericPacks.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); @@ -913,12 +938,13 @@ struct Printer if (func.vararg) { comma(); + advance(func.varargLocation.begin); writer.symbol("..."); if (func.varargAnnotation) { writer.symbol(":"); - visualizeTypePackAnnotation(*func.varargAnnotation); + visualizeTypePackAnnotation(*func.varargAnnotation, true); } } @@ -980,8 +1006,14 @@ struct Printer advance(typeAnnotation.location.begin); if (const auto& a = typeAnnotation.as()) { + if (a->hasPrefix) + { + writer.write(a->prefix.value); + writer.symbol("."); + } + writer.write(a->name.value); - if (a->parameters.size > 0) + if (a->parameters.size > 0 || a->hasParameterList) { CommaSeparatorInserter comma(writer); writer.symbol("<"); @@ -992,7 +1024,7 @@ struct Printer if (o.type) visualizeTypeAnnotation(*o.type); else - visualizeTypePackAnnotation(*o.typePack); + visualizeTypePackAnnotation(*o.typePack, false); } writer.symbol(">"); @@ -1000,7 +1032,7 @@ struct Printer } else if (const auto& a = typeAnnotation.as()) { - if (FFlag::LuauGenericFunctions && (a->generics.size > 0 || a->genericPacks.size > 0)) + if (a->generics.size > 0 || a->genericPacks.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); @@ -1075,7 +1107,16 @@ struct Printer auto rta = r->as(); if (rta && rta->name == "nil") { + bool wrap = l->as() || l->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*l); + + if (wrap) + writer.symbol(")"); + writer.symbol("?"); return; } @@ -1089,7 +1130,15 @@ struct Printer writer.symbol("|"); } + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); } } else if (const auto& a = typeAnnotation.as()) @@ -1102,7 +1151,15 @@ struct Printer writer.symbol("&"); } + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); } } else if (typeAnnotation.is()) @@ -1116,31 +1173,27 @@ struct Printer } }; -void dump(AstNode* node) +std::string toString(AstNode* node) { StringWriter writer; + writer.pos = node->location.begin; + Printer printer(writer); printer.writeTypes = true; if (auto statNode = dynamic_cast(node)) - { printer.visualize(*statNode); - printf("%s\n", writer.str().c_str()); - } else if (auto exprNode = dynamic_cast(node)) - { printer.visualize(*exprNode); - printf("%s\n", writer.str().c_str()); - } else if (auto typeNode = dynamic_cast(node)) - { printer.visualizeTypeAnnotation(*typeNode); - printf("%s\n", writer.str().c_str()); - } - else - { - printf("Can't dump this node\n"); - } + + return writer.str(); +} + +void dump(AstNode* node) +{ + printf("%s\n", toString(node).c_str()); } std::string transpile(AstStatBlock& block) @@ -1149,6 +1202,7 @@ std::string transpile(AstStatBlock& block) Printer(writer).visualizeBlock(block); return writer.str(); } + std::string transpileWithTypes(AstStatBlock& block) { StringWriter writer; @@ -1158,7 +1212,7 @@ std::string transpileWithTypes(AstStatBlock& block) return writer.str(); } -TranspileResult transpile(std::string_view source, ParseOptions options) +TranspileResult transpile(std::string_view source, ParseOptions options, bool withTypes) { auto allocator = Allocator{}; auto names = AstNameTable{allocator}; @@ -1176,6 +1230,9 @@ TranspileResult transpile(std::string_view source, ParseOptions options) if (!parseResult.root) return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"}; + if (withTypes) + return TranspileResult{transpileWithTypes(*parseResult.root)}; + return TranspileResult{transpile(*parseResult.root)}; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 49f8e0cac..11aa7b394 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -13,7 +13,6 @@ #include -LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauTypeAliasPacks) static char* allocateString(Luau::Allocator& allocator, std::string_view contents) @@ -203,39 +202,23 @@ class TypeRehydrationVisitor return allocator->alloc(Location(), std::nullopt, AstName("")); AstArray generics; - if (FFlag::LuauGenericFunctions) + generics.size = ftv.generics.size(); + generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); + size_t numGenerics = 0; + for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { - generics.size = ftv.generics.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); - size_t i = 0; - for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) - { - if (auto gtv = get(*it)) - generics.data[i++] = AstName(gtv->name.c_str()); - } - } - else - { - generics.size = 0; - generics.data = nullptr; + if (auto gtv = get(*it)) + generics.data[numGenerics++] = AstName(gtv->name.c_str()); } AstArray genericPacks; - if (FFlag::LuauGenericFunctions) - { - genericPacks.size = ftv.genericPacks.size(); - genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); - size_t i = 0; - for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) - { - if (auto gtv = get(*it)) - genericPacks.data[i++] = AstName(gtv->name.c_str()); - } - } - else + genericPacks.size = ftv.genericPacks.size(); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); + size_t numGenericPacks = 0; + for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { - generics.size = 0; - generics.data = nullptr; + if (auto gtv = get(*it)) + genericPacks.data[numGenericPacks++] = AstName(gtv->name.c_str()); } AstArray argTypes; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 38e2e5270..8fad1af91 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -22,29 +22,19 @@ LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauGenericVariadicsUnification, false) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) -LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) -LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false) -LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false) -LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) -LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) -LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau @@ -222,9 +212,9 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) , errorType(singletonTypes.errorType) - , optionalNumberType(globalTypes.addType(UnionTypeVar{{numberType, nilType}})) - , anyTypePack(globalTypes.addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}, true})) - , errorTypePack(globalTypes.addTypePack(TypePackVar{Unifiable::Error{}})) + , optionalNumberType(singletonTypes.optionalNumberType) + , anyTypePack(singletonTypes.anyTypePack) + , errorTypePack(singletonTypes.errorTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -251,10 +241,8 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona if (module.cyclic) moduleScope->returnType = addTypePack(TypePack{{anyType}, std::nullopt}); - else if (FFlag::LuauRankNTypes) - moduleScope->returnType = freshTypePack(moduleScope); else - moduleScope->returnType = DEPRECATED_freshTypePack(moduleScope, true); + moduleScope->returnType = freshTypePack(moduleScope); moduleScope->varargPack = anyTypePack; @@ -268,7 +256,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona checkBlock(moduleScope, *module.root); - if (get(FFlag::LuauAddMissingFollow ? follow(moduleScope->returnType) : moduleScope->returnType)) + if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); else moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); @@ -326,7 +314,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) check(scope, *typealias); else if (auto global = program.as()) { - TypeId globalType = (FFlag::LuauRankNTypes ? resolveType(scope, *global->type) : resolveType(scope, *global->type, true)); + TypeId globalType = resolveType(scope, *global->type); Name globalName(global->name.value); currentModule->declaredGlobals[globalName] = globalType; @@ -494,7 +482,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std Name name = typealias->name.value; TypeId type = bindings[name].type; - if (get(FFlag::LuauAddMissingFollow ? follow(type) : type)) + if (get(follow(type))) { *asMutable(type) = ErrorTypeVar{}; reportError(TypeError{typealias->location, OccursCheckFailed{}}); @@ -607,26 +595,22 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; + expectedTypes.reserve(return_.list.size); - if (FFlag::LuauInferReturnAssertAssign) - { - expectedTypes.reserve(return_.list.size); - - TypePackIterator expectedRetCurr = begin(scope->returnType); - TypePackIterator expectedRetEnd = end(scope->returnType); + TypePackIterator expectedRetCurr = begin(scope->returnType); + TypePackIterator expectedRetEnd = end(scope->returnType); - for (size_t i = 0; i < return_.list.size; ++i) + for (size_t i = 0; i < return_.list.size; ++i) + { + if (expectedRetCurr != expectedRetEnd) { - if (expectedRetCurr != expectedRetEnd) - { - expectedTypes.push_back(*expectedRetCurr); - ++expectedRetCurr; - } - else if (auto expectedArgsTail = expectedRetCurr.tail()) - { - if (const VariadicTypePack* vtp = get(follow(*expectedArgsTail))) - expectedTypes.push_back(vtp->ty); - } + expectedTypes.push_back(*expectedRetCurr); + ++expectedRetCurr; + } + else if (auto expectedArgsTail = expectedRetCurr.tail()) + { + if (const VariadicTypePack* vtp = get(follow(*expectedArgsTail))) + expectedTypes.push_back(vtp->ty); } } @@ -672,34 +656,30 @@ ErrorVec TypeChecker::tryUnify_(Id left, Id right, const Location& location) void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { std::vector> expectedTypes; + expectedTypes.reserve(assign.vars.size); - if (FFlag::LuauInferReturnAssertAssign) - { - expectedTypes.reserve(assign.vars.size); + ScopePtr moduleScope = currentModule->getModuleScope(); - ScopePtr moduleScope = currentModule->getModuleScope(); + for (size_t i = 0; i < assign.vars.size; ++i) + { + AstExpr* dest = assign.vars.data[i]; - for (size_t i = 0; i < assign.vars.size; ++i) + if (auto a = dest->as()) { - AstExpr* dest = assign.vars.data[i]; - - if (auto a = dest->as()) - { - // AstExprLocal l-values will have to be checked again because their type might have been mutated during checkExprList later - expectedTypes.push_back(scope->lookup(a->local)); - } - else if (auto a = dest->as()) - { - // AstExprGlobal l-values lookup is inlined here to avoid creating a global binding before checkExprList - if (auto it = moduleScope->bindings.find(a->name); it != moduleScope->bindings.end()) - expectedTypes.push_back(it->second.typeId); - else - expectedTypes.push_back(std::nullopt); - } + // AstExprLocal l-values will have to be checked again because their type might have been mutated during checkExprList later + expectedTypes.push_back(scope->lookup(a->local)); + } + else if (auto a = dest->as()) + { + // AstExprGlobal l-values lookup is inlined here to avoid creating a global binding before checkExprList + if (auto it = moduleScope->bindings.find(a->name); it != moduleScope->bindings.end()) + expectedTypes.push_back(it->second.typeId); else - { - expectedTypes.push_back(checkLValue(scope, *dest)); - } + expectedTypes.push_back(std::nullopt); + } + else + { + expectedTypes.push_back(checkLValue(scope, *dest)); } } @@ -715,7 +695,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) AstExpr* dest = assign.vars.data[i]; TypeId left = nullptr; - if (!FFlag::LuauInferReturnAssertAssign || dest->is() || dest->is()) + if (dest->is() || dest->is()) left = checkLValue(scope, *dest); else left = *expectedTypes[i]; @@ -751,11 +731,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) if (right) { - if (FFlag::LuauGenericFunctions && !maybeGeneric(left) && isGeneric(right)) - right = instantiate(scope, right, loc); - - if (!FFlag::LuauGenericFunctions && get(FFlag::LuauAddMissingFollow ? follow(left) : left) && - get(FFlag::LuauAddMissingFollow ? follow(right) : right)) + if (!maybeGeneric(left) && isGeneric(right)) right = instantiate(scope, right, loc); // Setting a table entry to nil doesn't mean nil is the type of the indexer, it is just deleting the entry @@ -766,7 +742,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) if (!destTableTypeReceivingNil || !destTableTypeReceivingNil->indexer) { // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. - if (isNonstrictMode() && get(FFlag::LuauAddMissingFollow ? follow(left) : left) && !get(follow(right))) + if (isNonstrictMode() && get(follow(left)) && !get(follow(right))) unify(left, anyType, loc); else unify(left, right, loc); @@ -815,7 +791,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) if (annotation) { - ty = (FFlag::LuauRankNTypes ? resolveType(scope, *annotation) : resolveType(scope, *annotation, true)); + ty = resolveType(scope, *annotation); // If the annotation type has an error, treat it as if there was no annotation if (get(follow(ty))) @@ -823,23 +799,19 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) } if (!ty) - ty = rhsIsTable ? (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)) - : isNonstrictMode() ? anyType : (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + ty = rhsIsTable ? freshType(scope) : isNonstrictMode() ? anyType : freshType(scope); varBindings.emplace_back(vars[i], Binding{ty, vars[i]->location}); variableTypes.push_back(ty); expectedTypes.push_back(ty); - if (FFlag::LuauGenericFunctions) - instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); - else - instantiateGenerics.push_back(annotation != nullptr && get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)); + instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); } if (local.values.size > 0) { - TypePackId variablePack = addTypePack(variableTypes, FFlag::LuauRankNTypes ? freshTypePack(scope) : DEPRECATED_freshTypePack(scope, true)); + TypePackId variablePack = addTypePack(variableTypes, freshTypePack(scope)); TypePackId valuePack = checkExprList(scope, local.location, local.values, /* substituteFreeForNil= */ true, instantiateGenerics, expectedTypes).type; @@ -979,8 +951,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) { AstExprCall* exprCall = firstValue->as(); callRetPack = checkExprPack(scope, *exprCall).type; - if (!FFlag::LuauRankNTypes) - callRetPack = DEPRECATED_instantiate(scope, callRetPack, exprCall->location); callRetPack = follow(callRetPack); if (get(callRetPack)) @@ -998,8 +968,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) else { iterTy = *first(callRetPack); - if (FFlag::LuauRankNTypes) - iterTy = instantiate(scope, iterTy, exprCall->location); + iterTy = instantiate(scope, iterTy, exprCall->location); } } else @@ -1158,10 +1127,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (FFlag::LuauGenericFunctions) - scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; - else - scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location}; + scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; } void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare) @@ -1199,7 +1165,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; @@ -1234,7 +1200,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; } - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; @@ -1266,7 +1232,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } } - TypeId ty = (FFlag::LuauRankNTypes ? resolveType(aliasScope, *typealias.type) : resolveType(aliasScope, *typealias.type, true)); + TypeId ty = resolveType(aliasScope, *typealias.type); if (auto ttv = getMutable(follow(ty))) { // If the table is already named and we want to rename the type function, we have to bind new alias to a copy @@ -1325,25 +1291,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0)); superTy = lookupType->type; - if (FFlag::LuauAddMissingFollow) + if (!get(follow(*superTy))) { - if (!get(follow(*superTy))) - { - reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", - superName.c_str(), declaredClass.name.value)}); + reportError(declaredClass.location, + GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); - return; - } - } - else - { - if (const ClassTypeVar* superCtv = get(*superTy); !superCtv) - { - reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", - superName.c_str(), declaredClass.name.value)}); - - return; - } + return; } } @@ -1558,8 +1511,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa } else if (auto ftp = get(varargPack)) { - TypeId head = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, ftp->DEPRECATED_canBeGeneric)); - TypePackId tail = (FFlag::LuauRankNTypes ? freshTypePack(scope) : DEPRECATED_freshTypePack(scope, ftp->DEPRECATED_canBeGeneric)); + TypeId head = freshType(scope); + TypePackId tail = freshTypePack(scope); *asMutable(varargPack) = TypePack{{head}, tail}; return {head}; } @@ -1567,7 +1520,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa return {errorType}; else if (auto vtp = get(varargPack)) return {vtp->ty}; - else if (FFlag::LuauGenericVariadicsUnification && get(varargPack)) + else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); @@ -1588,7 +1541,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa } else if (auto ftp = get(retPack)) { - TypeId head = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, ftp->DEPRECATED_canBeGeneric)); + TypeId head = freshType(scope); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); unify(retPack, pack, expr.location); return {head, std::move(result.predicates)}; @@ -1667,7 +1620,7 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (tableType->state == TableState::Free) { - TypeId result = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId result = freshType(scope); tableType->props[name] = {result}; return result; } @@ -2129,7 +2082,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!isNonstrictMode() && !isOrOp) return ty; - if (auto i = get(ty)) + if (get(ty)) { std::optional cleaned = tryStripUnionFromNil(ty); @@ -2158,16 +2111,9 @@ TypeId TypeChecker::checkRelationalOperation( { if (expr.op == AstExprBinary::Or && subexp->op == AstExprBinary::And) { - if (FFlag::LuauSlightlyMoreFlexibleBinaryPredicates) - { - ScopePtr subScope = childScope(scope, subexp->location); - reportErrors(resolve(predicates, subScope, true)); - return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); - } - else - { - return unionOfTypes(rhsType, checkExpr(scope, *subexp->right).type, expr.location); - } + ScopePtr subScope = childScope(scope, subexp->location); + reportErrors(resolve(predicates, subScope, true)); + return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); } } @@ -2217,10 +2163,8 @@ TypeId TypeChecker::checkRelationalOperation( std::string metamethodName = opToMetaTableEntry(expr.op); - std::optional leftMetatable = - isString(lhsType) ? std::nullopt : getMetatable(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType); - std::optional rightMetatable = - isString(rhsType) ? std::nullopt : getMetatable(FFlag::LuauAddMissingFollow ? follow(rhsType) : rhsType); + std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); + std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) { @@ -2266,7 +2210,7 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && !isEquality) + if (get(follow(lhsType)) && !isEquality) { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); @@ -2417,12 +2361,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi resolve(lhs.predicates, innerScope, true); ExprResult rhs = checkExpr(innerScope, *expr.right); - if (!FFlag::LuauSlightlyMoreFlexibleBinaryPredicates) - resolve(rhs.predicates, innerScope, true); return {checkBinaryOperation(innerScope, expr, lhs.type, rhs.type), {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } - else if (FFlag::LuauOrPredicate && expr.op == AstExprBinary::Or) + else if (expr.op == AstExprBinary::Or) { ExprResult lhs = checkExpr(scope, *expr.left); @@ -2468,19 +2410,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) { - ExprResult result; - TypeId annotationType; - - if (FFlag::LuauInferReturnAssertAssign) - { - annotationType = (FFlag::LuauRankNTypes ? resolveType(scope, *expr.annotation) : resolveType(scope, *expr.annotation, true)); - result = checkExpr(scope, *expr.expr, annotationType); - } - else - { - result = checkExpr(scope, *expr.expr); - annotationType = (FFlag::LuauRankNTypes ? resolveType(scope, *expr.annotation) : resolveType(scope, *expr.annotation, true)); - } + TypeId annotationType = resolveType(scope, *expr.annotation); + ExprResult result = checkExpr(scope, *expr.expr, annotationType); ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); if (!errorVec.empty()) @@ -2570,23 +2501,16 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (it != moduleScope->bindings.end()) return std::pair(it->second.typeId, &it->second.typeId); - if (isNonstrictMode() || FFlag::LuauSecondTypecheckKnowsTheDataModel) - { - TypeId result = (FFlag::LuauGenericFunctions && FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(moduleScope, true)); - - Binding& binding = moduleScope->bindings[expr.name]; - binding = {result, expr.location}; + TypeId result = freshType(scope); + Binding& binding = moduleScope->bindings[expr.name]; + binding = {result, expr.location}; - // If we're in strict mode, we want to report defining a global as an error, - // but still add it to the bindings, so that autocomplete includes it in completions. - if (!isNonstrictMode()) - reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); + // If we're in strict mode, we want to report defining a global as an error, + // but still add it to the bindings, so that autocomplete includes it in completions. + if (!isNonstrictMode()) + reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - return std::pair(result, &binding.typeId); - } - - reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - return std::pair(errorType, nullptr); + return std::pair(result, &binding.typeId); } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) @@ -2611,7 +2535,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope } else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { - TypeId theType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; property.type = theType; property.location = expr.indexLocation; @@ -2683,7 +2607,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope AstExprConstantString* value = expr.index->as(); - if (value && FFlag::LuauClassPropertyAccessAsString) + if (value) { if (const ClassTypeVar* exprClass = get(exprType)) { @@ -2714,7 +2638,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; property.type = resultType; property.location = expr.index->location; @@ -2730,7 +2654,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId resultType = freshType(scope); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return std::pair(resultType, nullptr); } @@ -2758,7 +2682,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) } else { - TypeId ty = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshType(scope); globalScope->bindings[name] = {ty, funName.location}; return ty; } @@ -2768,7 +2692,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Symbol name = localName->local; Binding& binding = scope->bindings[name]; if (binding.typeId == nullptr) - binding = {(FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)), funName.location}; + binding = {freshType(scope), funName.location}; return binding.typeId; } @@ -2798,7 +2722,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Property& property = ttv->props[name]; - property.type = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + property.type = freshType(scope); property.location = indexName->indexLocation; ttv->methodDefinitionLocations[name] = funName.location; return property.type; @@ -2865,22 +2789,11 @@ std::pair TypeChecker::checkFunctionSignature( expectedFunctionType = nullptr; } - std::vector generics; - std::vector genericPacks; - - if (FFlag::LuauGenericFunctions) - { - std::tie(generics, genericPacks) = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); - } + auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); TypePackId retPack; if (expr.hasReturnAnnotation) - { - if (FFlag::LuauGenericFunctions) - retPack = resolveTypePack(funScope, expr.returnAnnotation); - else - retPack = resolveTypePack(scope, expr.returnAnnotation); - } + retPack = resolveTypePack(funScope, expr.returnAnnotation); else if (isNonstrictMode()) retPack = anyTypePack; else if (expectedFunctionType) @@ -2889,24 +2802,17 @@ std::pair TypeChecker::checkFunctionSignature( // Do not infer 'nil' as function return type if (!tail && head.size() == 1 && isNil(head[0])) - retPack = FFlag::LuauGenericFunctions ? freshTypePack(funScope) : freshTypePack(scope); + retPack = freshTypePack(funScope); else retPack = addTypePack(head, tail); } - else if (FFlag::LuauGenericFunctions) - retPack = freshTypePack(funScope); else - retPack = freshTypePack(scope); + retPack = freshTypePack(funScope); if (expr.vararg) { if (expr.varargAnnotation) - { - if (FFlag::LuauGenericFunctions) - funScope->varargPack = resolveTypePack(funScope, *expr.varargAnnotation); - else - funScope->varargPack = resolveTypePack(scope, *expr.varargAnnotation); - } + funScope->varargPack = resolveTypePack(funScope, *expr.varargAnnotation); else { if (expectedFunctionType && !isNonstrictMode()) @@ -2963,7 +2869,7 @@ std::pair TypeChecker::checkFunctionSignature( if (local->annotation) { - argType = resolveType((FFlag::LuauGenericFunctions ? funScope : scope), *local->annotation); + argType = resolveType(funScope, *local->annotation); // If the annotation type has an error, treat it as if there was no annotation if (get(follow(argType))) @@ -3022,7 +2928,7 @@ static bool allowsNoReturnValues(const TypePackId tp) { for (TypeId ty : tp) { - if (!get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) + if (!get(follow(ty))) { return false; } @@ -3058,7 +2964,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE check(scope, *function.body); // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (FFlag::LuauAddMissingFollow ? get_if(&funTy->retType->ty) : get(funTy->retType)) + if (get_if(&funTy->retType->ty)) *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; @@ -3287,7 +3193,7 @@ void TypeChecker::checkArgumentList( return; } - else if (FFlag::LuauGenericVariadicsUnification && get(tail)) + else if (get(tail)) { // Create a type pack out of the remaining argument types // and unify it with the tail. @@ -3310,7 +3216,7 @@ void TypeChecker::checkArgumentList( return; } - else if (FFlag::LuauRankNTypes && get(tail)) + else if (get(tail)) { // For this case, we want the error span to cover every errant extra parameter Location location = state.location; @@ -3323,10 +3229,7 @@ void TypeChecker::checkArgumentList( } else { - if (FFlag::LuauRankNTypes) - unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); - else - state.tryUnify(*paramIter, *argIter, /*isFunctionCall*/ false); + unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); ++argIter; ++paramIter; } @@ -3356,9 +3259,6 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A ice("method call expression has no 'self'"); selfType = checkExpr(scope, *indexExpr->expr).type; - if (!FFlag::LuauRankNTypes) - instantiate(scope, selfType, expr.func->location); - selfType = stripFromNilAndReport(selfType, expr.func->location); if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) @@ -3393,8 +3293,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); - TypePackId argList = argListResult.type; - TypePackId argPack = (FFlag::LuauRankNTypes ? argList : DEPRECATED_instantiate(scope, argList, expr.location)); + TypePackId argPack = argListResult.type; if (get(argPack)) return ExprResult{errorTypePack}; @@ -3526,8 +3425,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); TypeId fn = *ty; - if (FFlag::LuauRankNTypes) - fn = instantiate(scope, fn, expr.func->location); + fn = instantiate(scope, fn, expr.func->location); return checkCallOverload( scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); @@ -3800,7 +3698,7 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L TypeId actualType = substituteFreeForNil && expr->is() ? freshType(scope) : type; - if (instantiateGenerics.size() > i && instantiateGenerics[i] && (FFlag::LuauGenericFunctions || get(actualType))) + if (instantiateGenerics.size() > i && instantiateGenerics[i]) actualType = instantiate(scope, actualType, expr->location); if (expectedType) @@ -3837,7 +3735,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); - if (FFlag::LuauNewRequireTrace && moduleInfo.name.empty()) + if (FFlag::LuauNewRequireTrace2 && moduleInfo.name.empty()) { if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { @@ -3922,7 +3820,6 @@ bool TypeChecker::unify(TypePackId left, TypePackId right, const Location& locat bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location) { - LUAU_ASSERT(FFlag::LuauRankNTypes); Unifier state = mkUnifier(location); unifyWithInstantiationIfNeeded(scope, left, right, state); @@ -3933,7 +3830,6 @@ bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId l void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (!maybeGeneric(right)) // Quick check to see if we definitely can't instantiate state.tryUnify(left, right, /*isFunctionCall*/ false); @@ -3973,19 +3869,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId l bool Instantiation::isDirty(TypeId ty) { - if (FFlag::LuauRankNTypes) - { - if (get(ty)) - return true; - else - return false; - } - - if (const FunctionTypeVar* ftv = get(ty)) - return !ftv->generics.empty() || !ftv->genericPacks.empty(); - else if (const TableTypeVar* ttv = get(ty)) - return ttv->state == TableState::Generic; - else if (get(ty)) + if (get(ty)) return true; else return false; @@ -3993,18 +3877,11 @@ bool Instantiation::isDirty(TypeId ty) bool Instantiation::isDirty(TypePackId tp) { - if (FFlag::LuauRankNTypes) - return false; - - if (get(tp)) - return true; - else - return false; + return false; } bool Instantiation::ignoreChildren(TypeId ty) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (get(ty)) return true; else @@ -4013,63 +3890,38 @@ bool Instantiation::ignoreChildren(TypeId ty) TypeId Instantiation::clean(TypeId ty) { - LUAU_ASSERT(isDirty(ty)); - - if (const FunctionTypeVar* ftv = get(ty)) - { - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; - clone.magicFunction = ftv->magicFunction; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - TypeId result = addType(std::move(clone)); - - if (FFlag::LuauRankNTypes) - { - // Annoyingly, we have to do this even if there are no generics, - // to replace any generic tables. - replaceGenerics.level = level; - replaceGenerics.currentModule = currentModule; - replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); - replaceGenerics.genericPacks.assign(ftv->genericPacks.begin(), ftv->genericPacks.end()); - - // TODO: What to do if this returns nullopt? - // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); - } - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } - else if (const TableTypeVar* ttv = get(ty)) - { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - TypeId result = addType(std::move(clone)); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } - else - { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - TypeId result = addType(FreeTypeVar{level}); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } + const FunctionTypeVar* ftv = get(ty); + LUAU_ASSERT(ftv); + + FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + clone.magicFunction = ftv->magicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + TypeId result = addType(std::move(clone)); + + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + replaceGenerics.level = level; + replaceGenerics.currentModule = currentModule; + replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); + replaceGenerics.genericPacks.assign(ftv->genericPacks.begin(), ftv->genericPacks.end()); + + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; } TypePackId Instantiation::clean(TypePackId tp) { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - return addTypePack(TypePackVar(FreeTypePack{level})); + LUAU_ASSERT(false); + return tp; } bool ReplaceGenerics::ignoreChildren(TypeId ty) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (const FunctionTypeVar* ftv = get(ty)) // We aren't recursing in the case of a generic function which // binds the same generics. This can happen if, for example, there's recursive types. @@ -4083,7 +3935,6 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) bool ReplaceGenerics::isDirty(TypeId ty) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (const TableTypeVar* ttv = get(ty)) return ttv->state == TableState::Generic; else if (get(ty)) @@ -4094,7 +3945,6 @@ bool ReplaceGenerics::isDirty(TypeId ty) bool ReplaceGenerics::isDirty(TypePackId tp) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (get(tp)) return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); else @@ -4255,21 +4105,6 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat } } -TypePackId TypeChecker::DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location) -{ - LUAU_ASSERT(!FFlag::LuauRankNTypes); - instantiation.level = scope->level; - instantiation.currentModule = currentModule; - std::optional instantiated = instantiation.substitute(ty); - if (instantiated.has_value()) - return *instantiated; - else - { - reportError(location, UnificationTooComplex{}); - return errorTypePack; - } -} - TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { anyification.anyType = anyType; @@ -4444,16 +4279,6 @@ TypeId TypeChecker::freshType(TypeLevel level) return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } -TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric) -{ - return DEPRECATED_freshType(scope->level, canBeGeneric); -} - -TypeId TypeChecker::DEPRECATED_freshType(TypeLevel level, bool canBeGeneric) -{ - return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level, canBeGeneric))); -} - std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); @@ -4509,21 +4334,8 @@ TypePackId TypeChecker::freshTypePack(TypeLevel level) return addTypePack(TypePackVar(FreeTypePack(level))); } -TypePackId TypeChecker::DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric) -{ - return DEPRECATED_freshTypePack(scope->level, canBeGeneric); -} - -TypePackId TypeChecker::DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric) -{ - return addTypePack(TypePackVar(FreeTypePack(level, canBeGeneric))); -} - -TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation, bool DEPRECATED_canBeGeneric) +TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation) { - if (DEPRECATED_canBeGeneric) - LUAU_ASSERT(!FFlag::LuauRankNTypes); - if (const auto& lit = annotation.as()) { std::optional tf; @@ -4668,11 +4480,11 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation std::optional tableIndexer; for (const auto& prop : table->props) - props[prop.name.value] = {resolveType(scope, *prop.type, DEPRECATED_canBeGeneric)}; + props[prop.name.value] = {resolveType(scope, *prop.type)}; if (const auto& indexer = table->indexer) tableIndexer = TableIndexer( - resolveType(scope, *indexer->indexType, DEPRECATED_canBeGeneric), resolveType(scope, *indexer->resultType, DEPRECATED_canBeGeneric)); + resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); return addType(TableTypeVar{ props, tableIndexer, scope->level, @@ -4683,17 +4495,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { ScopePtr funcScope = childScope(scope, func->location); - std::vector generics; - std::vector genericPacks; - - if (FFlag::LuauGenericFunctions) - { - std::tie(generics, genericPacks) = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); - } - - // TODO: better error message CLI-39912 - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !DEPRECATED_canBeGeneric && (generics.size() > 0 || genericPacks.size() > 0)) - reportError(TypeError{annotation.location, GenericError{"generic function where only monotypes are allowed"}}); + auto [generics, genericPacks] = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); TypePackId argTypes = resolveTypePack(funcScope, func->argTypes); TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes); @@ -4716,16 +4518,13 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (auto typeOf = annotation.as()) { TypeId ty = checkExpr(scope, *typeOf->expr).type; - // TODO: better error message CLI-39912 - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !DEPRECATED_canBeGeneric && isGeneric(ty)) - reportError(TypeError{annotation.location, GenericError{"typeof produced a polytype where only monotypes are allowed"}}); return ty; } else if (const auto& un = annotation.as()) { std::vector types; for (AstType* ann : un->types) - types.push_back(resolveType(scope, *ann, DEPRECATED_canBeGeneric)); + types.push_back(resolveType(scope, *ann)); return addType(UnionTypeVar{types}); } @@ -4733,7 +4532,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { std::vector types; for (AstType* ann : un->types) - types.push_back(resolveType(scope, *ann, DEPRECATED_canBeGeneric)); + types.push_back(resolveType(scope, *ann)); return addType(IntersectionTypeVar{types}); } @@ -4919,9 +4718,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (FFlag::LuauCloneCorrectlyBeforeMutatingTableType) { - // TODO: CLI-46926 it's a bad idea to rename the type whether we follow through the BoundTypeVar or not - TypeId target = FFlag::LuauFollowInTypeFunApply ? follow(instantiated) : instantiated; - + // TODO: CLI-46926 it's not a good idea to rename the type here + TypeId target = follow(instantiated); bool needsClone = follow(tf.type) == target; TableTypeVar* ttv = getMutableTableType(target); @@ -5152,31 +4950,18 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, Refi void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { - if (FFlag::LuauOrPredicate) + if (!sense) { - if (!sense) - { - OrPredicate orP{ - {NotPredicate{std::move(andP.lhs)}}, - {NotPredicate{std::move(andP.rhs)}}, - }; - - return resolve(orP, errVec, refis, scope, !sense); - } + OrPredicate orP{ + {NotPredicate{std::move(andP.lhs)}}, + {NotPredicate{std::move(andP.rhs)}}, + }; - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); + return resolve(orP, errVec, refis, scope, !sense); } - else - { - // And predicate is currently not resolvable when sense is false. 'not (a and b)' is synonymous with '(not a) or (not b)'. - // TODO: implement environment merging to permit this case. - if (!sense) - return; - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); - } + resolve(andP.lhs, errVec, refis, scope, sense); + resolve(andP.rhs, errVec, refis, scope, sense); } void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) @@ -5207,58 +4992,41 @@ void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMa void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { auto predicate = [&](TypeId option) -> std::optional { - if (FFlag::LuauTypeGuardPeelsAwaySubclasses) - { - // This by itself is not truly enough to determine that A is stronger than B or vice versa. - // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. - // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) - bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); - bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); - - // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. - if (!optionIsSubtype && targetIsSubtype) + // This by itself is not truly enough to determine that A is stronger than B or vice versa. + // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. + // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) + bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); + bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + + // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. + if (!optionIsSubtype && targetIsSubtype) + return sense ? isaP.ty : option; + + // If A is a subset of B, then if sense is true we pick A, otherwise we eliminate A. + if (optionIsSubtype && !targetIsSubtype) + return sense ? std::optional(option) : std::nullopt; + + // If neither has any relationship, we only return A if sense is false. + if (!optionIsSubtype && !targetIsSubtype) + return sense ? std::nullopt : std::optional(option); + + // If both are subtypes, then we're in one of the two situations: + // 1. Instance₁ <: Instance₂ ∧ Instance₂ <: Instance₁ + // 2. any <: Instance ∧ Instance <: any + // Right now, we have to look at the types to see if they were undecidables. + // By this point, we also know free tables are also subtypes and supertypes. + if (optionIsSubtype && targetIsSubtype) + { + // We can only have (any, Instance) because the rhs is never undecidable right now. + // So we can just return the right hand side immediately. + + // typeof(x) == "Instance" where x : any + auto ttv = get(option); + if (isUndecidable(option) || (ttv && ttv->state == TableState::Free)) return sense ? isaP.ty : option; - // If A is a subset of B, then if sense is true we pick A, otherwise we eliminate A. - if (optionIsSubtype && !targetIsSubtype) - return sense ? std::optional(option) : std::nullopt; - - // If neither has any relationship, we only return A if sense is false. - if (!optionIsSubtype && !targetIsSubtype) - return sense ? std::nullopt : std::optional(option); - - // If both are subtypes, then we're in one of the two situations: - // 1. Instance₁ <: Instance₂ ∧ Instance₂ <: Instance₁ - // 2. any <: Instance ∧ Instance <: any - // Right now, we have to look at the types to see if they were undecidables. - // By this point, we also know free tables are also subtypes and supertypes. - if (optionIsSubtype && targetIsSubtype) - { - // We can only have (any, Instance) because the rhs is never undecidable right now. - // So we can just return the right hand side immediately. - - // typeof(x) == "Instance" where x : any - auto ttv = get(option); - if (isUndecidable(option) || (ttv && ttv->state == TableState::Free)) - return sense ? isaP.ty : option; - - // typeof(x) == "Instance" where x : Instance - if (sense) - return isaP.ty; - } - } - else - { - auto lctv = get(option); - auto rctv = get(isaP.ty); - - if (isSubclass(lctv, rctv) == sense) - return option; - - if (isSubclass(rctv, lctv) == sense) - return isaP.ty; - - if (canUnify(option, isaP.ty, isaP.location).empty() == sense) + // typeof(x) == "Instance" where x : Instance + if (sense) return isaP.ty; } @@ -5457,7 +5225,7 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp TypePack* expectedPack = getMutable(expectedTypePack); LUAU_ASSERT(expectedPack); for (size_t i = 0; i < expectedLength; ++i) - expectedPack->head.push_back(FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + expectedPack->head.push_back(freshType(scope)); unify(expectedTypePack, tp, location); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 68a16ef04..228b19267 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -97,7 +97,7 @@ TypePackIterator begin(TypePackId tp) TypePackIterator end(TypePackId tp) { - return FFlag::LuauAddMissingFollow ? TypePackIterator{} : TypePackIterator{nullptr}; + return TypePackIterator{}; } bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) @@ -203,7 +203,7 @@ TypePackId follow(TypePackId tp) size_t size(TypePackId tp) { - if (auto pack = get(FFlag::LuauAddMissingFollow ? follow(tp) : tp)) + if (auto pack = get(follow(tp))) return size(*pack); else return 0; @@ -216,7 +216,7 @@ bool finite(TypePackId tp) if (auto pack = get(tp)) return pack->tail ? finite(*pack->tail) : true; - if (auto pack = get(tp)) + if (get(tp)) return false; return true; @@ -227,7 +227,7 @@ size_t size(const TypePack& tp) size_t result = tp.head.size(); if (tp.tail) { - const TypePack* tail = get(FFlag::LuauAddMissingFollow ? follow(*tp.tail) : *tp.tail); + const TypePack* tail = get(follow(*tp.tail)); if (tail) result += size(*tail); } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index e82f7519d..cd447ca23 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,8 +19,6 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) @@ -42,7 +40,7 @@ TypeId follow(TypeId t) }; auto force = [](TypeId ty) { - if (auto ltv = FFlag::LuauAddMissingFollow ? get_if(&ty->ty) : get(ty)) + if (auto ltv = get_if(&ty->ty)) { TypeId res = ltv->thunk(); if (get(res)) @@ -296,7 +294,7 @@ bool maybeGeneric(TypeId ty) { ty = follow(ty); if (auto ftv = get(ty)) - return FFlag::LuauRankNTypes || ftv->DEPRECATED_canBeGeneric; + return true; else if (auto ttv = get(ty)) { // TODO: recurse on table types CLI-39914 @@ -545,15 +543,30 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); +static TypeVar nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true}; +static TypeVar numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true}; +static TypeVar stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}; +static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}; +static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}; +static TypeVar anyType_{AnyTypeVar{}}; +static TypeVar errorType_{ErrorTypeVar{}}; +static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}}; + +static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; +static TypePackVar errorTypePack_{Unifiable::Error{}}; + SingletonTypes::SingletonTypes() - : arena(new TypeArena) - , nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true} - , numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true} - , stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true} - , booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true} - , threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true} - , anyType_{AnyTypeVar{}} - , errorType_{ErrorTypeVar{}} + : nilType(&nilType_) + , numberType(&numberType_) + , stringType(&stringType_) + , booleanType(&booleanType_) + , threadType(&threadType_) + , anyType(&anyType_) + , errorType(&errorType_) + , optionalNumberType(&optionalNumberType_) + , anyTypePack(&anyTypePack_) + , errorTypePack(&errorTypePack_) + , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, makeStringMetatable()}; @@ -749,9 +762,9 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) { - if (const PrimitiveTypeVar* ptv = get(ty)) + if (get(ty)) formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); - else if (const AnyTypeVar* atv = get(ty)) + else if (get(ty)) formatAppend(result, "n%d [label=\"any\"];\n", index); } else @@ -902,19 +915,19 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } - else if (const AnyTypeVar* atv = get(ty)) + else if (get(ty)) { formatAppend(result, "AnyTypeVar %d", index); finishNodeLabel(ty); finishNode(); } - else if (const PrimitiveTypeVar* ptv = get(ty)) + else if (get(ty)) { formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); finishNodeLabel(ty); finishNode(); } - else if (const ErrorTypeVar* etv = get(ty)) + else if (get(ty)) { formatAppend(result, "ErrorTypeVar %d", index); finishNodeLabel(ty); @@ -994,7 +1007,7 @@ void StateDot::visitChildren(TypePackId tp, int index) finishNodeLabel(tp); finishNode(); } - else if (const Unifiable::Error* etp = get(tp)) + else if (get(tp)) { formatAppend(result, "ErrorTypePack %d", index); finishNodeLabel(tp); @@ -1372,24 +1385,6 @@ UnionTypeVarIterator end(const UnionTypeVar* utv) return UnionTypeVarIterator{}; } -static std::vector DEPRECATED_filterMap(TypeId type, TypeIdPredicate predicate) -{ - std::vector result; - - if (auto utv = get(follow(type))) - { - for (TypeId option : utv) - { - if (auto out = predicate(follow(option))) - result.push_back(*out); - } - } - else if (auto out = predicate(follow(type))) - return {*out}; - - return result; -} - static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs"; @@ -1470,9 +1465,6 @@ std::optional> magicFunctionFormat( std::vector filterMap(TypeId type, TypeIdPredicate predicate) { - if (!FFlag::LuauTypeGuardPeelsAwaySubclasses) - return DEPRECATED_filterMap(type, predicate); - type = follow(type); if (auto utv = get(type)) diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index cef07833a..dc5546640 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Unifiable.h" -LUAU_FASTFLAG(LuauRankNTypes) - namespace Luau { namespace Unifiable @@ -14,14 +12,6 @@ Free::Free(TypeLevel level) { } -Free::Free(TypeLevel level, bool DEPRECATED_canBeGeneric) - : index(++nextIndex) - , level(level) - , DEPRECATED_canBeGeneric(DEPRECATED_canBeGeneric) -{ - LUAU_ASSERT(!FFlag::LuauRankNTypes); -} - int Free::nextIndex = 0; Generic::Generic() diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 2539650a4..82f621b66 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,17 +14,15 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); -LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); -LUAU_FASTFLAGVARIABLE(LuauDontMutatePersistentFunctions, false) -LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) -LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) LUAU_FASTFLAG(LuauShareTxnSeen); LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) +LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) +LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) namespace Luau { @@ -219,17 +217,12 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool *asMutable(subTy) = BoundTypeVar(superTy); } - if (!FFlag::LuauRankNTypes) - l->DEPRECATED_canBeGeneric &= r->DEPRECATED_canBeGeneric; - return; } - else if (l && r && FFlag::LuauGenericFunctions) + else if (l && r) { log(superTy); occursCheck(superTy, subTy); - if (!FFlag::LuauRankNTypes) - r->DEPRECATED_canBeGeneric &= l->DEPRECATED_canBeGeneric; r->level = min(r->level, l->level); *asMutable(superTy) = BoundTypeVar(subTy); return; @@ -240,7 +233,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto rightGeneric = get(subTy); - if (FFlag::LuauRankNTypes && rightGeneric && !rightGeneric->level.subsumes(l->level)) + if (rightGeneric && !rightGeneric->level.subsumes(l->level)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -266,31 +259,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto leftGeneric = get(superTy); - if (FFlag::LuauRankNTypes && leftGeneric && !leftGeneric->level.subsumes(r->level)) - { - // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); - return; - } - - // This is the old code which is just wrong - auto wrongGeneric = get(subTy); // Guaranteed to be null - if (!FFlag::LuauRankNTypes && FFlag::LuauGenericFunctions && wrongGeneric && r->level.subsumes(wrongGeneric->level)) + if (leftGeneric && !leftGeneric->level.subsumes(r->level)) { - // This code is unreachable! Should we just remove it? // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); return; } - // Check if we're unifying a monotype with a polytype - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !r->DEPRECATED_canBeGeneric && isGeneric(superTy)) - { - // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Failed to unify a polytype with a monotype"}}); - return; - } - if (!get(subTy)) { if (auto leftLevel = getMutableLevel(superTy)) @@ -333,6 +308,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // A | B <: T if A <: T and B <: T bool failed = false; std::optional unificationTooComplex; + std::optional firstFailedOption; size_t count = uv->options.size(); size_t i = 0; @@ -345,7 +321,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; else if (!innerState.errors.empty()) + { + // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' + if (FFlag::LuauExtendedTypeMismatchError && !firstFailedOption && !isNil(type)) + firstFailedOption = {innerState.errors.front()}; + failed = true; + } if (i != count - 1) innerState.log.rollback(); @@ -358,7 +340,12 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (failed) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + { + if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible", *firstFailedOption}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } } else if (const UnionTypeVar* uv = get(superTy)) { @@ -425,14 +412,49 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (!found) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + { + if (FFlag::LuauExtendedTypeMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } } else if (const IntersectionTypeVar* uv = get(superTy)) { - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) + if (FFlag::LuauExtendedTypeMismatchError) + { + std::optional unificationTooComplex; + std::optional firstFailedOption; + + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) + { + if (!firstFailedOption) + firstFailedOption = {innerState.errors.front()}; + } + + log.concat(std::move(innerState.log)); + } + + if (unificationTooComplex) + errors.push_back(*unificationTooComplex); + else if (firstFailedOption) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible", *firstFailedOption}}); + } + else { - tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) + { + tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + } } } else if (const IntersectionTypeVar* uv = get(subTy)) @@ -480,7 +502,12 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (!found) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + { + if (FFlag::LuauExtendedTypeMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } } else if (get(superTy) && get(subTy)) tryUnifyPrimitives(superTy, subTy); @@ -773,8 +800,8 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - const bool lFreeTail = l->tail && get(FFlag::LuauAddMissingFollow ? follow(*l->tail) : *l->tail) != nullptr; - const bool rFreeTail = r->tail && get(FFlag::LuauAddMissingFollow ? follow(*r->tail) : *r->tail) != nullptr; + const bool lFreeTail = l->tail && get(follow(*l->tail)) != nullptr; + const bool rFreeTail = r->tail && get(follow(*r->tail)) != nullptr; if (lFreeTail && rFreeTail) tryUnify_(*l->tail, *r->tail); else if (lFreeTail) @@ -812,7 +839,7 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal } // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && get(FFlag::LuauAddMissingFollow ? follow(*superIter) : *superIter)) + else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) { superIter.advance(); continue; @@ -887,24 +914,21 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal ice("passed non-function types to unifyFunction"); size_t numGenerics = lf->generics.size(); - if (FFlag::LuauGenericFunctions && numGenerics != rf->generics.size()) + if (numGenerics != rf->generics.size()) { numGenerics = std::min(lf->generics.size(), rf->generics.size()); errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } size_t numGenericPacks = lf->genericPacks.size(); - if (FFlag::LuauGenericFunctions && numGenericPacks != rf->genericPacks.size()) + if (numGenericPacks != rf->genericPacks.size()) { numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } - if (FFlag::LuauGenericFunctions) - { - for (size_t i = 0; i < numGenerics; i++) - log.pushSeen(lf->generics[i], rf->generics[i]); - } + for (size_t i = 0; i < numGenerics; i++) + log.pushSeen(lf->generics[i], rf->generics[i]); CountMismatch::Context context = ctx; @@ -931,22 +955,19 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal tryUnify_(lf->retType, rf->retType); } - if (lf->definition && !rf->definition && (!FFlag::LuauDontMutatePersistentFunctions || !subTy->persistent)) + if (lf->definition && !rf->definition && !subTy->persistent) { rf->definition = lf->definition; } - else if (!lf->definition && rf->definition && (!FFlag::LuauDontMutatePersistentFunctions || !superTy->persistent)) + else if (!lf->definition && rf->definition && !superTy->persistent) { lf->definition = rf->definition; } ctx = context; - if (FFlag::LuauGenericFunctions) - { - for (int i = int(numGenerics) - 1; 0 <= i; i--) - log.popSeen(lf->generics[i], rf->generics[i]); - } + for (int i = int(numGenerics) - 1; 0 <= i; i--) + log.popSeen(lf->generics[i], rf->generics[i]); } namespace @@ -1032,7 +1053,12 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, r->second.type); - checkChildUnifierTypeMismatch(innerState.errors, left, right); + + if (FFlag::LuauExtendedTypeMismatchError) + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + else + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) log.concat(std::move(innerState.log)); else @@ -1047,7 +1073,12 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, rt->indexer->indexResultType); - checkChildUnifierTypeMismatch(innerState.errors, left, right); + + if (FFlag::LuauExtendedTypeMismatchError) + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + else + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) log.concat(std::move(innerState.log)); else @@ -1083,7 +1114,12 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, lt->indexer->indexResultType); - checkChildUnifierTypeMismatch(innerState.errors, left, right); + + if (FFlag::LuauExtendedTypeMismatchError) + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + else + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) log.concat(std::move(innerState.log)); else @@ -1384,21 +1420,8 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio const auto& r = rt->props.find(it.first); if (r == rt->props.end()) { - if (FFlag::LuauSealedTableUnifyOptionalFix) - { - if (isOptional(it.second.type)) - continue; - } - else - { - if (get(it.second.type)) - { - const UnionTypeVar* possiblyOptional = get(it.second.type); - const std::vector& options = possiblyOptional->options; - if (options.end() != std::find_if(options.begin(), options.end(), isNil)) - continue; - } - } + if (isOptional(it.second.type)) + continue; missingPropertiesInSuper.push_back(it.first); @@ -1482,21 +1505,8 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio const auto& r = lt->props.find(it.first); if (r == lt->props.end()) { - if (FFlag::LuauSealedTableUnifyOptionalFix) - { - if (isOptional(it.second.type)) - continue; - } - else - { - if (get(it.second.type)) - { - const UnionTypeVar* possiblyOptional = get(it.second.type); - const std::vector& options = possiblyOptional->options; - if (options.end() != std::find_if(options.begin(), options.end(), isNil)) - continue; - } - } + if (isOptional(it.second.type)) + continue; extraPropertiesInSub.push_back(it.first); } @@ -1526,7 +1536,18 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse innerState.tryUnify_(lhs->table, rhs->table); innerState.tryUnify_(lhs->metatable, rhs->metatable); - checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); + if (FFlag::LuauExtendedTypeMismatchError) + { + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty()) + errors.push_back( + TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); + } + else + { + checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); + } log.concat(std::move(innerState.log)); } @@ -1613,10 +1634,34 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - tryUnify_(prop.type, singletonTypes.errorType); + + if (!FFlag::LuauExtendedClassMismatchError) + tryUnify_(prop.type, singletonTypes.errorType); } else - tryUnify_(prop.type, classProp->type); + { + if (FFlag::LuauExtendedClassMismatchError) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, classProp->type); + + checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); + + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); + } + else + { + ok = false; + innerState.log.rollback(); + } + } + else + { + tryUnify_(prop.type, classProp->type); + } + } } if (table->indexer) @@ -1649,45 +1694,24 @@ static void queueTypePack_DEPRECATED( while (true) { - if (FFlag::LuauAddMissingFollow) - a = follow(a); + a = follow(a); if (seenTypePacks.count(a)) break; seenTypePacks.insert(a); - if (FFlag::LuauAddMissingFollow) + if (get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; } - else + else if (auto tp = get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - - if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; } } } @@ -1698,45 +1722,24 @@ static void queueTypePack(std::vector& queue, DenseHashSet& while (true) { - if (FFlag::LuauAddMissingFollow) - a = follow(a); + a = follow(a); if (seenTypePacks.find(a)) break; seenTypePacks.insert(a); - if (FFlag::LuauAddMissingFollow) + if (get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; } - else + else if (auto tp = get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - - if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; } } } @@ -1990,33 +1993,6 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); } -std::optional Unifier::findMetatableEntry(TypeId type, std::string entry) -{ - type = follow(type); - - std::optional metatable = getMetatable(type); - if (!metatable) - return std::nullopt; - - TypeId unwrapped = follow(*metatable); - - if (get(unwrapped)) - return singletonTypes.anyType; - - const TableTypeVar* mtt = getTableType(unwrapped); - if (!mtt) - { - errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}}); - return std::nullopt; - } - - auto it = mtt->props.find(entry); - if (it != mtt->props.end()) - return it->second.type; - else - return std::nullopt; -} - void Unifier::occursCheck(TypeId needle, TypeId haystack) { std::unordered_set seen_DEPRECATED; @@ -2168,7 +2144,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { for (const auto& ty : a->head) { - if (auto f = get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) + if (auto f = get(follow(ty))) { occursCheck(seen_DEPRECATED, seen, needle, f->argTypes); occursCheck(seen_DEPRECATED, seen, needle, f->retType); @@ -2207,6 +2183,17 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId errors.push_back(TypeError{location, TypeMismatch{wantedType, givenType}}); } +void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) +{ + LUAU_ASSERT(FFlag::LuauExtendedTypeMismatchError || FFlag::LuauExtendedClassMismatchError); + + if (auto e = hasUnificationTooComplex(innerErrors)) + errors.push_back(*e); + else if (!innerErrors.empty()) + errors.push_back( + TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible", prop.c_str()), innerErrors.front()}}); +} + void Unifier::ice(const std::string& message, const Location& location) { sharedState.iceHandler->ice(message, location); diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 42c64dc92..39c7d9251 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -282,7 +282,6 @@ class Parser // `<' namelist `>' std::pair, AstArray> parseGenericTypeList(); - std::pair, AstArray> parseGenericTypeListIfFFlagParseGenericFunctions(); // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 846bc0ba9..a1bad65ef 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,13 +10,13 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsParserFix, false) -LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) +LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) +LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau { @@ -957,7 +957,7 @@ AstStat* Parser::parseAssignment(AstExpr* initial) { nextLexeme(); - AstExpr* expr = parsePrimaryExpr(/* asStatement= */ false); + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ FFlag::LuauFixAmbiguousErrorRecoveryInAssign); if (!isExprLValue(expr)) expr = reportExprError(expr->location, copy({expr}), "Assigned expression must be a variable or a field"); @@ -995,7 +995,7 @@ std::pair Parser::parseFunctionBody( { Location start = matchFunction.location; - auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions(); + auto [generics, genericPacks] = parseGenericTypeList(); Lexeme matchParen = lexer.current(); expectAndConsume('(', "function"); @@ -1343,19 +1343,18 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); - bool monomorphic = !(FFlag::LuauParseGenericFunctions && lexer.current().type == '<'); - - auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions(); + bool monomorphic = lexer.current().type != '<'; Lexeme begin = lexer.current(); - if (FFlag::LuauGenericFunctionsParserFix) - expectAndConsume('(', "function parameters"); - else - { - LUAU_ASSERT(begin.type == '('); - nextLexeme(); // ( - } + auto [generics, genericPacks] = parseGenericTypeList(); + + Lexeme parameterStart = lexer.current(); + + if (!FFlag::LuauParseGenericFunctionTypeBegin) + begin = parameterStart; + + expectAndConsume('(', "function parameters"); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; @@ -1366,7 +1365,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) if (lexer.current().type != ')') varargAnnotation = parseTypeList(params, names); - expectMatchAndConsume(')', begin, true); + expectMatchAndConsume(')', parameterStart, true); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; @@ -1585,7 +1584,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { return {parseTableTypeAnnotation(), {}}; } - else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<')) + else if (lexer.current().type == '(' || lexer.current().type == '<') { return parseFunctionTypeAnnotation(allowPack); } @@ -2315,19 +2314,6 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou return Name(nameError, location); } -std::pair, AstArray> Parser::parseGenericTypeListIfFFlagParseGenericFunctions() -{ - if (FFlag::LuauParseGenericFunctions) - return Parser::parseGenericTypeList(); - AstArray generics; - AstArray genericPacks; - generics.size = 0; - generics.data = nullptr; - genericPacks.size = 0; - genericPacks.data = nullptr; - return std::pair(generics, genericPacks); -} - std::pair, AstArray> Parser::parseGenericTypeList() { TempVector names{scratchName}; @@ -2342,7 +2328,7 @@ std::pair, AstArray> Parser::parseGenericTypeList() while (true) { AstName name = parseName().name; - if (FFlag::LuauParseGenericFunctions && lexer.current().type == Lexeme::Dot3) + if (lexer.current().type == Lexeme::Dot3) { seenPack = true; nextLexeme(); @@ -2379,15 +2365,12 @@ AstArray Parser::parseTypeParams() Lexeme begin = lexer.current(); nextLexeme(); - bool seenPack = false; while (true) { if (FFlag::LuauParseTypePackTypeParameters) { if (shouldParseTypePackAnnotation(lexer)) { - seenPack = true; - auto typePack = parseTypePackAnnotation(); if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them @@ -2399,8 +2382,6 @@ AstArray Parser::parseTypeParams() if (typePack) { - seenPack = true; - if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them parameters.push_back({{}, typePack}); } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index ed0552d74..9ab10aaf0 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -34,8 +34,10 @@ static void report(ReportFormat format, const char* name, const Luau::Location& } } -static void reportError(ReportFormat format, const char* name, const Luau::TypeError& error) +static void reportError(ReportFormat format, const Luau::TypeError& error) { + const char* name = error.moduleName.c_str(); + if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, name, error.location, "SyntaxError", syntaxError->message.c_str()); else @@ -49,7 +51,10 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate) { - Luau::CheckResult cr = frontend.check(name); + Luau::CheckResult cr; + + if (frontend.isDirty(name)) + cr = frontend.check(name); if (!frontend.getSourceModule(name)) { @@ -58,7 +63,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat } for (auto& error : cr.errors) - reportError(format, name, error); + reportError(format, error); Luau::LintResult lr = frontend.lint(name); @@ -115,7 +120,12 @@ struct CliFileResolver : Luau::FileResolver { if (Luau::AstExprConstantString* expr = node->as()) { - Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".lua"; + Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau"; + if (!moduleExists(name)) + { + // fall back to .lua if a module with .luau doesn't exist + name = std::string(expr->value.data, expr->value.size) + ".lua"; + } return {{name}}; } @@ -236,8 +246,15 @@ int main(int argc, char** argv) if (isDirectory(argv[i])) { traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + // Look for .luau first and if absent, fall back to .lua + if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) + { + failed += !analyzeFile(frontend, name.c_str(), format, annotate); + } + else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + { failed += !analyzeFile(frontend, name.c_str(), format, annotate); + } }); } else diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 4968d0800..5c904cca4 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -13,6 +13,17 @@ #include +#ifdef _WIN32 +#include +#include +#endif + +enum class CompileFormat +{ + Default, + Binary +}; + static int lua_loadstring(lua_State* L) { size_t l = 0; @@ -51,9 +62,13 @@ static int lua_require(lua_State* L) return finishrequire(L); lua_pop(L, 1); - std::optional source = readFile(name + ".lua"); + std::optional source = readFile(name + ".luau"); if (!source) - luaL_argerrorL(L, 1, ("error loading " + name).c_str()); + { + source = readFile(name + ".lua"); // try .lua if .luau doesn't exist + if (!source) + luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error + } // module needs to run in a new thread, isolated from the rest lua_State* GL = lua_mainthread(L); @@ -183,6 +198,11 @@ static std::string runCode(lua_State* L, const std::string& source) error += "\nstack backtrace:\n"; error += lua_debugtrace(T); +#ifdef __EMSCRIPTEN__ + // nicer formatting for errors in web repl + error = "Error:" + error; +#endif + fprintf(stdout, "%s", error.c_str()); } @@ -190,6 +210,39 @@ static std::string runCode(lua_State* L, const std::string& source) return std::string(); } +#ifdef __EMSCRIPTEN__ +extern "C" +{ + const char* executeScript(const char* source) + { + // setup flags + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + // create new state + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // setup state + setupState(L); + + // sandbox thread + luaL_sandboxthread(L); + + // static string for caching result (prevents dangling ptr on function exit) + static std::string result; + + // run code + collect error + result = runCode(L, source); + + return result.empty() ? NULL : result.c_str(); + } +} +#endif + +// Excluded from emscripten compilation to avoid -Wunused-function errors. +#ifndef __EMSCRIPTEN__ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) { std::string_view lookup = editBuffer + start; @@ -366,7 +419,7 @@ static void reportError(const char* name, const Luau::CompileError& error) report(name, error.getLocation(), "CompileError", error.what()); } -static bool compileFile(const char* name) +static bool compileFile(const char* name, CompileFormat format) { std::optional source = readFile(name); if (!source) @@ -383,7 +436,15 @@ static bool compileFile(const char* name) Luau::compileOrThrow(bcb, *source); - printf("%s", bcb.dumpEverything().c_str()); + switch (format) + { + case CompileFormat::Default: + printf("%s", bcb.dumpEverything().c_str()); + break; + case CompileFormat::Binary: + fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); + break; + } return true; } @@ -408,7 +469,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available modes:\n"); printf(" omitted: compile and run input files one by one\n"); - printf(" --compile: compile input files and output resulting bytecode\n"); + printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); printf("\n"); printf("Available options:\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); @@ -440,8 +501,19 @@ int main(int argc, char** argv) return 0; } - if (argc >= 2 && strcmp(argv[1], "--compile") == 0) + + if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { + CompileFormat format = CompileFormat::Default; + + if (strcmp(argv[1], "--compile=binary") == 0) + format = CompileFormat::Binary; + +#ifdef _WIN32 + if (format == CompileFormat::Binary) + _setmode(_fileno(stdout), _O_BINARY); +#endif + int failed = 0; for (int i = 2; i < argc; ++i) @@ -452,13 +524,15 @@ int main(int argc, char** argv) if (isDirectory(argv[i])) { traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !compileFile(name.c_str()); + if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) + failed += !compileFile(name.c_str(), format); + else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + failed += !compileFile(name.c_str(), format); }); } else { - failed += !compileFile(argv[i]); + failed += !compileFile(argv[i], format); } } @@ -511,5 +585,6 @@ int main(int argc, char** argv) return failed; } } +#endif diff --git a/CMakeLists.txt b/CMakeLists.txt index d6598f2a9..36014a983 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,17 +17,26 @@ add_library(Luau.VM STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) - add_executable(Luau.Analyze.CLI) + if(NOT EMSCRIPTEN) + add_executable(Luau.Analyze.CLI) + else() + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions") + endif() # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) - set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) + + if(NOT EMSCRIPTEN) + set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) + endif() endif() -if(LUAU_BUILD_TESTS) +if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) add_executable(Luau.UnitTest) add_executable(Luau.Conformance) endif() + include(Sources.cmake) target_compile_features(Luau.Ast PUBLIC cxx_std_17) @@ -53,10 +62,6 @@ if(MSVC) else() list(APPEND LUAU_OPTIONS -Wall) # All warnings list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors - - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - list(APPEND LUAU_OPTIONS -Wno-unused) # GCC considers variables declared/checked in if() as unused - endif() endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) @@ -65,7 +70,10 @@ target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) - target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) + + if(NOT EMSCRIPTEN) + target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) + endif() target_include_directories(Luau.Repl.CLI PRIVATE extern) target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) @@ -74,10 +82,20 @@ if(LUAU_BUILD_CLI) target_link_libraries(Luau.Repl.CLI PRIVATE pthread) endif() - target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + if(NOT EMSCRIPTEN) + target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + endif() + + if(EMSCRIPTEN) + # declare exported functions to emscripten + target_link_options(Luau.Repl.CLI PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -fexceptions) + + # custom output directory for wasm + js file + set_target_properties(Luau.Repl.CLI PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/docs/assets/luau) + endif() endif() -if(LUAU_BUILD_TESTS) +if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.UnitTest PRIVATE extern) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 4b03ed1c7..71631d101 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -467,6 +467,10 @@ enum LuauBuiltinFunction // vector ctor LBF_VECTOR, + + // bit32.count + LBF_BIT32_COUNTLZ, + LBF_BIT32_COUNTRZ, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index f8d671588..4f88e602e 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -36,6 +36,9 @@ struct CompileOptions // global builtin to construct vectors; disabled by default const char* vectorLib = nullptr; const char* vectorCtor = nullptr; + + // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these + const char** mutableGlobals = nullptr; }; class CompileError : public std::exception diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 7750a1d9f..9712f02f4 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -13,7 +13,9 @@ LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false) LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false) +LUAU_FASTFLAGVARIABLE(LuauGenericSpecialGlobals, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) +LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) namespace Luau { @@ -22,6 +24,7 @@ static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; +// TODO: Remove with LuauGenericSpecialGlobals static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"}; CompileError::CompileError(const Location& location, const std::string& message) @@ -1277,7 +1280,7 @@ struct Compiler { const Global* global = globals.find(expr->name); - return options.optimizationLevel >= 1 && (!global || (!global->written && !global->special)); + return options.optimizationLevel >= 1 && (!global || (!global->written && !global->writable)); } void compileExprIndexName(AstExprIndexName* expr, uint8_t target) @@ -2465,9 +2468,10 @@ struct Compiler } else if (node->is()) { + LUAU_ASSERT(!loops.empty()); + // before exiting out of the loop, we need to close all local variables that were captured in closures since loop start // normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here - LUAU_ASSERT(!loops.empty()); closeLocals(loops.back().localOffset); size_t label = bytecode.emitLabel(); @@ -2478,12 +2482,13 @@ struct Compiler } else if (AstStatContinue* stat = node->as()) { + LUAU_ASSERT(!loops.empty()); + if (loops.back().untilCondition) validateContinueUntil(stat, loops.back().untilCondition); // before continuing, we need to close all local variables that were captured in closures since loop start // normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here - LUAU_ASSERT(!loops.empty()); closeLocals(loops.back().localOffset); size_t label = bytecode.emitLabel(); @@ -2900,6 +2905,11 @@ struct Compiler break; case AstExprUnary::Len: + if (arg.type == Constant::Type_String) + { + result.type = Constant::Type_Number; + result.valueNumber = double(arg.valueString.size); + } break; default: @@ -3440,7 +3450,7 @@ struct Compiler struct Global { - bool special = false; + bool writable = false; bool written = false; }; @@ -3498,7 +3508,7 @@ struct Compiler { Global* g = globals.find(object->name); - return !g || (!g->special && !g->written) ? Builtin{object->name, expr->index} : Builtin(); + return !g || (!g->writable && !g->written) ? Builtin{object->name, expr->index} : Builtin(); } else { @@ -3629,6 +3639,10 @@ struct Compiler return LBF_BIT32_RROTATE; if (builtin.method == "rshift") return LBF_BIT32_RSHIFT; + if (builtin.method == "countlz" && FFlag::LuauBit32CountBuiltin) + return LBF_BIT32_COUNTLZ; + if (builtin.method == "countrz" && FFlag::LuauBit32CountBuiltin) + return LBF_BIT32_COUNTRZ; } if (builtin.object == "string") @@ -3696,13 +3710,24 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); - // since access to some global objects may result in values that change over time, we block table imports - for (const char* global : kSpecialGlobals) + // since access to some global objects may result in values that change over time, we block imports from non-readonly tables + if (FFlag::LuauGenericSpecialGlobals) { - AstName name = names.get(global); + if (AstName name = names.get("_G"); name.value) + compiler.globals[name].writable = true; - if (name.value) - compiler.globals[name].special = true; + if (options.mutableGlobals) + for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) + compiler.globals[name].writable = true; + } + else + { + for (const char* global : kSpecialGlobals) + { + if (AstName name = names.get(global); name.value) + compiler.globals[name].writable = true; + } } // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written @@ -3717,7 +3742,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found - if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1) + if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) { Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed); root->visit(&fenvVisitor); diff --git a/Makefile b/Makefile index 7788251d8..5d51b3d4e 100644 --- a/Makefile +++ b/Makefile @@ -49,7 +49,10 @@ OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(T CXXFLAGS=-g -Wall -Werror LDFLAGS= -CXXFLAGS+=-Wno-unused # temporary, for older gcc versions +# temporary, for older gcc versions as they treat var in `if (type var = val)` as unused +ifeq ($(findstring g++,$(shell $(CXX) --version)),g++) + CXXFLAGS+=-Wno-unused +endif # configuration-specific flags ifeq ($(config),release) @@ -134,12 +137,11 @@ $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET): # executable targets for fuzzing fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) + $(CXX) $^ $(LDFLAGS) -o $@ + fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator -fuzz-%: - $(CXX) $^ $(LDFLAGS) -o $@ - # static library targets $(AST_TARGET): $(AST_OBJECTS) $(COMPILER_TARGET): $(COMPILER_OBJECTS) diff --git a/VM/include/lua.h b/VM/include/lua.h index 2f93ad903..a9d3e875a 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -213,6 +213,8 @@ LUA_API int lua_resume(lua_State* L, lua_State* from, int narg); LUA_API int lua_resumeerror(lua_State* L, lua_State* from); LUA_API int lua_status(lua_State* L); LUA_API int lua_isyieldable(lua_State* L); +LUA_API void* lua_getthreaddata(lua_State* L); +LUA_API void lua_setthreaddata(lua_State* L, void* data); /* ** garbage-collection function and options @@ -346,6 +348,8 @@ struct lua_Debug * can only be changed when the VM is not running any code */ struct lua_Callbacks { + void* userdata; /* arbitrary userdata pointer that is never overwritten by Luau */ + void (*interrupt)(lua_State* L, int gc); /* gets called at safepoints (loop back edges, call/ret, gc) if set */ void (*panic)(lua_State* L, int errcode); /* gets called when an unprotected error is raised (if longjmp is used) */ diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index f2e97c669..7e742644f 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -703,6 +703,7 @@ void lua_setreadonly(lua_State* L, int objindex, bool value) const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); + api_check(L, t != hvalue(registry(L))); t->readonly = value; return; } @@ -987,6 +988,16 @@ int lua_status(lua_State* L) return L->status; } +void* lua_getthreaddata(lua_State* L) +{ + return L->userdata; +} + +void lua_setthreaddata(lua_State* L, void* data) +{ + L->userdata = data; +} + /* ** Garbage-collection function */ diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 0754a351d..c72fe6748 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -4,6 +4,8 @@ #include "lnumutils.h" +LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) + #define ALLONES ~0u #define NBITS int(8 * sizeof(unsigned)) @@ -177,6 +179,44 @@ static int b_replace(lua_State* L) return 1; } +static int b_countlz(lua_State* L) +{ + if (!FFlag::LuauBit32Count) + luaL_error(L, "bit32.countlz isn't enabled"); + + b_uint v = luaL_checkunsigned(L, 1); + + b_uint r = NBITS; + for (int i = 0; i < NBITS; ++i) + if (v & (1u << (NBITS - 1 - i))) + { + r = i; + break; + } + + lua_pushunsigned(L, r); + return 1; +} + +static int b_countrz(lua_State* L) +{ + if (!FFlag::LuauBit32Count) + luaL_error(L, "bit32.countrz isn't enabled"); + + b_uint v = luaL_checkunsigned(L, 1); + + b_uint r = NBITS; + for (int i = 0; i < NBITS; ++i) + if (v & (1u << i)) + { + r = i; + break; + } + + lua_pushunsigned(L, r); + return 1; +} + static const luaL_Reg bitlib[] = { {"arshift", b_arshift}, {"band", b_and}, @@ -190,6 +230,8 @@ static const luaL_Reg bitlib[] = { {"replace", b_replace}, {"rrotate", b_rrot}, {"rshift", b_rshift}, + {"countlz", b_countlz}, + {"countrz", b_countrz}, {NULL, NULL}, }; diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index e1c99b21a..9ab57ac9b 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -20,8 +20,9 @@ // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path // If luauF_* succeeds, it needs to return *all* requested arguments, filling results with nil as appropriate. // On input, nparams refers to the actual number of arguments (0+), whereas nresults contains LUA_MULTRET for arbitrary returns or 0+ for a -// fixed-length return Because of this, and the fact that "extra" returned values will be ignored, implementations below typically check that nresults -// is <= expected number, which covers the LUA_MULTRET case. +// fixed-length return +// Because of this, and the fact that "extra" returned values will be ignored, implementations below typically check that nresults is <= expected +// number, which covers the LUA_MULTRET case. static int luauF_assert(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { @@ -1030,6 +1031,52 @@ static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, Stk return -1; } +static int luauF_countlz(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + + unsigned n; + luai_num2unsigned(n, a1); + +#ifdef _MSC_VER + unsigned long rl; + int r = _BitScanReverse(&rl, n) ? 31 - int(rl) : 32; +#else + int r = n == 0 ? 32 : __builtin_clz(n); +#endif + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + +static int luauF_countrz(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + + unsigned n; + luai_num2unsigned(n, a1); + +#ifdef _MSC_VER + unsigned long rl; + int r = _BitScanForward(&rl, n) ? int(rl) : 32; +#else + int r = n == 0 ? 32 : __builtin_ctz(n); +#endif + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1096,4 +1143,7 @@ luau_FastFunction luauF_table[256] = { luauF_tunpack, luauF_vector, + + luauF_countlz, + luauF_countrz, }; diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 6553009ff..11f79d1a3 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,7 +13,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) -LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) LUAU_FASTFLAG(LuauArrayBoundary) @@ -677,117 +676,6 @@ static size_t atomic(lua_State* L) return work; } -static size_t singlestep(lua_State* L) -{ - size_t cost = 0; - global_State* g = L->global; - switch (g->gcstate) - { - case GCSpause: - { - markroot(L); /* start a new collection */ - LUAU_ASSERT(g->gcstate == GCSpropagate); - break; - } - case GCSpropagate: - { - if (g->gray) - { - g->gcstats.currcycle.markitems++; - - cost = propagatemark(g); - } - else - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; - - g->gcstate = GCSpropagateagain; - } - break; - } - case GCSpropagateagain: - { - if (g->gray) - { - g->gcstats.currcycle.markitems++; - - cost = propagatemark(g); - } - else /* no more `gray' objects */ - { - if (FFlag::LuauSeparateAtomic) - { - g->gcstate = GCSatomic; - } - else - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } - } - break; - } - case GCSatomic: - { - g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - cost = atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); - break; - } - case GCSsweepstring: - { - size_t traversedcount = 0; - sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount); - - // nothing more to sweep? - if (g->sweepstrgc >= g->strt.size) - { - // sweep string buffer list and preserve used string count - uint32_t nuse = L->global->strt.nuse; - sweepwholelist(L, &g->strbufgc, &traversedcount); - L->global->strt.nuse = nuse; - - g->gcstate = GCSsweep; // end sweep-string phase - } - - g->gcstats.currcycle.sweepitems += traversedcount; - - cost = GC_SWEEPCOST; - break; - } - case GCSsweep: - { - size_t traversedcount = 0; - g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); - - g->gcstats.currcycle.sweepitems += traversedcount; - - if (*g->sweepgc == NULL) - { /* nothing more to sweep? */ - shrinkbuffers(L); - g->gcstate = GCSpause; /* end collection */ - } - cost = GC_SWEEPMAX * GC_SWEEPCOST; - break; - } - default: - LUAU_ASSERT(!"Unexpected GC state"); - } - - return cost; -} - static size_t gcstep(lua_State* L, size_t limit) { size_t cost = 0; @@ -980,37 +868,12 @@ void luaC_step(lua_State* L, bool assist) int lastgcstate = g->gcstate; double lasttimestamp = lua_clock(); - if (FFlag::LuauConsolidatedStep) - { - size_t work = gcstep(L, lim); + size_t work = gcstep(L, lim); - if (assist) - g->gcstats.currcycle.assistwork += work; - else - g->gcstats.currcycle.explicitwork += work; - } + if (assist) + g->gcstats.currcycle.assistwork += work; else - { - // always perform at least one single step - do - { - lim -= singlestep(L); - - // if we have switched to a different state, capture the duration of last stage - // this way we reduce the number of timer calls we make - if (lastgcstate != g->gcstate) - { - GC_INTERRUPT(lastgcstate); - - double now = lua_clock(); - - recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist); - - lasttimestamp = now; - lastgcstate = g->gcstate; - } - } while (lim > 0 && g->gcstate != GCSpause); - } + g->gcstats.currcycle.explicitwork += work; recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); @@ -1037,14 +900,7 @@ void luaC_step(lua_State* L, bool assist) g->GCthreshold -= debt; } - if (FFlag::LuauConsolidatedStep) - { - GC_INTERRUPT(lastgcstate); - } - else - { - GC_INTERRUPT(g->gcstate); - } + GC_INTERRUPT(lastgcstate); } void luaC_fullgc(lua_State* L) @@ -1070,10 +926,7 @@ void luaC_fullgc(lua_State* L) while (g->gcstate != GCSpause) { LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); - if (FFlag::LuauConsolidatedStep) - gcstep(L, SIZE_MAX); - else - singlestep(L); + gcstep(L, SIZE_MAX); } finishGcCycleStats(g); @@ -1084,10 +937,7 @@ void luaC_fullgc(lua_State* L) markroot(L); while (g->gcstate != GCSpause) { - if (FFlag::LuauConsolidatedStep) - gcstep(L, SIZE_MAX); - else - singlestep(L); + gcstep(L, SIZE_MAX); } /* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */ shrinkbuffersfull(L); diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index a9db37272..80a34483a 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,6 +8,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauStrPackUBCastFix, false) + /* macro to `unsign' a character */ #define uchar(c) ((unsigned char)(c)) @@ -1404,10 +1406,20 @@ static int str_pack(lua_State* L) } case Kuint: { /* unsigned integers */ - unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, n, h.islittle, size, 0); + if (FFlag::LuauStrPackUBCastFix) + { + long long n = (long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, (unsigned long long)n, h.islittle, size, 0); + } + else + { + unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, n, h.islittle, size, 0); + } break; } case Kfloat: diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 883442ae0..07d22d596 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -30,6 +30,7 @@ LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false) #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) +static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast in gval2slot is incorrect"); // TKey is bitpacked for memory efficiency so we need to validate bit counts for worst case static_assert(TKey{{NULL}, 0, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt"); static_assert(TKey{{NULL}, 0, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); diff --git a/VM/src/ltable.h b/VM/src/ltable.h index f98d87b1d..45061443e 100644 --- a/VM/src/ltable.h +++ b/VM/src/ltable.h @@ -9,7 +9,6 @@ #define gval(n) (&(n)->val) #define gnext(n) ((n)->key.next) -static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast below is incorrect"); #define gval2slot(t, v) int(cast_to(LuaNode*, static_cast(v)) - t->node) LUAI_FUNC const TValue* luaH_getnum(Table* t, int key); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index de5788eb6..370258189 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -9,8 +9,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false) - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -491,9 +489,6 @@ static int tclear(lua_State* L) static int tfreeze(lua_State* L) { - if (!FFlag::LuauTableFreeze) - luaG_runerror(L, "table.freeze is disabled"); - luaL_checktype(L, 1, LUA_TTABLE); luaL_argcheck(L, !lua_getreadonly(L, 1), 1, "table is already frozen"); luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); @@ -506,9 +501,6 @@ static int tfreeze(lua_State* L) static int tisfrozen(lua_State* L) { - if (!FFlag::LuauTableFreeze) - luaG_runerror(L, "table.isfrozen is disabled"); - luaL_checktype(L, 1, LUA_TTABLE); lua_pushboolean(L, lua_getreadonly(L, 1)); diff --git a/bench/tests/chess.lua b/bench/tests/chess.lua index 87b9abfd4..f6ae2cc6b 100644 --- a/bench/tests/chess.lua +++ b/bench/tests/chess.lua @@ -205,38 +205,48 @@ function Bitboard:empty() return self.h == 0 and self.l == 0 end -function Bitboard:ctz() - local target = self.l - local offset = 0 - local result = 0 - - if target == 0 then - target = self.h - result = 32 +if not bit32.countrz then + local function ctz(v) + if v == 0 then return 32 end + local offset = 0 + while bit32.extract(v, offset) == 0 do + offset = offset + 1 + end + return offset end - - if target == 0 then - return 64 + function Bitboard:ctz() + local result = ctz(self.l) + if result == 32 then + return ctz(self.h) + 32 + else + return result + end end - - while bit32.extract(target, offset) == 0 do - offset = offset + 1 + function Bitboard:ctzafter(start) + start = start + 1 + if start < 32 then + for i=start,31 do + if bit32.extract(self.l, i) == 1 then return i end + end + end + for i=math.max(32,start),63 do + if bit32.extract(self.h, i-32) == 1 then return i end + end + return 64 end - - return result + offset -end - -function Bitboard:ctzafter(start) - start = start + 1 - if start < 32 then - for i=start,31 do - if bit32.extract(self.l, i) == 1 then return i end +else + function Bitboard:ctz() + local result = bit32.countrz(self.l) + if result == 32 then + return bit32.countrz(self.h) + 32 + else + return result end end - for i=math.max(32,start),63 do - if bit32.extract(self.h, i-32) == 1 then return i end + function Bitboard:ctzafter(start) + local masked = self:band(Bitboard.full:lshift(start+1)) + return masked:ctz() end - return 64 end @@ -245,7 +255,7 @@ function Bitboard:lshift(amt) if amt == 0 then return self end if amt > 31 then - return Bitboard.from(0, bit32.lshift(self.l, amt-31)) + return Bitboard.from(0, bit32.lshift(self.l, amt-32)) end local l = bit32.lshift(self.l, amt) @@ -832,12 +842,12 @@ end local testCases = {} local function addTest(...) table.insert(testCases, {...}) end -addTest(StartingFen, 3, 8902) -addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039) -addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812) -addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467) -addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486) -addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079) +addTest(StartingFen, 2, 400) +addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 1, 48) +addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 2, 191) +addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 2, 264) +addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 1, 44) +addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 1, 46) local function chess() diff --git a/fuzz/luau.proto b/fuzz/luau.proto index 41a1d077f..c78fcf31c 100644 --- a/fuzz/luau.proto +++ b/fuzz/luau.proto @@ -19,6 +19,7 @@ message Expr { ExprTable table = 13; ExprUnary unary = 14; ExprBinary binary = 15; + ExprIfElse ifelse = 16; } } @@ -149,6 +150,12 @@ message ExprBinary { required Expr right = 3; } +message ExprIfElse { + required Expr cond = 1; + required Expr then = 2; + required Expr else = 3; +} + message LValue { oneof lvalue_oneof { ExprLocal local = 1; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 6c230b67f..c85fac7d3 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -11,6 +11,7 @@ #include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" #include "Luau/ToString.h" +#include "Luau/Transpiler.h" #include "lua.h" #include "lualib.h" @@ -23,6 +24,7 @@ const bool kFuzzLinter = true; const bool kFuzzTypeck = true; const bool kFuzzVM = true; const bool kFuzzTypes = true; +const bool kFuzzTranspile = true; static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); @@ -242,6 +244,11 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) } } + if (kFuzzTranspile && parseResult.root) + { + transpileWithTypes(*parseResult.root); + } + // run resulting bytecode if (kFuzzVM && bytecode.size()) { diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index 2c861a553..e61b69365 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -476,6 +476,16 @@ struct ProtoToLuau print(expr.right()); } + void print(const luau::ExprIfElse& expr) + { + source += " if "; + print(expr.cond()); + source += " then "; + print(expr.then()); + source += " else "; + print(expr.else_()); + } + void print(const luau::LValue& expr) { if (expr.has_local()) diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index dd49e675e..aa53a92b2 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -45,7 +45,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "prop") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") { ScopedFastFlag sffs[] = { - {"LuauDontMutatePersistentFunctions", true}, {"LuauPersistDefinitionFileTypes", true}, }; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 44b8362df..8a7798f3a 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1287,9 +1287,6 @@ local e: (n: n@5 TEST_CASE_FIXTURE(ACFixture, "generic_types") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - check(R"( function f(a: T@1 local b: string = "don't trip" diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index bbac3302c..7f03019c3 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -13,6 +13,7 @@ LUAU_FASTFLAG(LuauPreloadClosures) LUAU_FASTFLAG(LuauPreloadClosuresFenv) LUAU_FASTFLAG(LuauPreloadClosuresUpval) +LUAU_FASTFLAG(LuauGenericSpecialGlobals) using namespace Luau; @@ -1168,6 +1169,17 @@ RETURN R0 1 )"); } +TEST_CASE("ConstantFoldStringLen") +{ + CHECK_EQ("\n" + compileFunction0("return #'string', #'', #'a', #('b')"), R"( +LOADN R0 6 +LOADN R1 0 +LOADN R2 1 +LOADN R3 1 +RETURN R0 4 +)"); +} + TEST_CASE("ConstantFoldCompare") { // ordered comparisons @@ -3659,4 +3671,118 @@ RETURN R0 0 )"); } +TEST_CASE("LuauGenericSpecialGlobals") +{ + const char* source = R"( +print() +Game.print() +Workspace.print() +_G.print() +game.print() +plugin.print() +script.print() +shared.print() +workspace.print() +)"; + + { + ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", false}; + + // Check Roblox globals are here + CHECK_EQ("\n" + compileFunction0(source), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R1 3 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 5 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 9 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 11 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 13 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 15 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 17 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +RETURN R0 0 +)"); + } + + ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", true}; + + // Check Roblox globals are no longer here + CHECK_EQ("\n" + compileFunction0(source), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R0 3 +CALL R0 0 0 +GETIMPORT R0 5 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R0 9 +CALL R0 0 0 +GETIMPORT R0 11 +CALL R0 0 0 +GETIMPORT R0 13 +CALL R0 0 0 +GETIMPORT R0 15 +CALL R0 0 0 +GETIMPORT R0 17 +CALL R0 0 0 +RETURN R0 0 +)"); + + // Check we can add them back + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + const char* mutableGlobals[] = {"Game", "Workspace", "game", "plugin", "script", "shared", "workspace", NULL}; + options.mutableGlobals = &mutableGlobals[0]; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R1 3 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 5 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 9 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 11 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 13 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 15 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 17 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +RETURN R0 0 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 06b3c5237..c1b790b9c 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -240,8 +240,6 @@ TEST_CASE("Math") TEST_CASE("Table") { - ScopedFastFlag sff("LuauTableFreeze", true); - runConformance("nextvar.lua"); } @@ -322,6 +320,8 @@ TEST_CASE("GC") TEST_CASE("Bitwise") { + ScopedFastFlag sff("LuauBit32Count", true); + runConformance("bitwise.lua"); } @@ -359,6 +359,8 @@ TEST_CASE("PCall") TEST_CASE("Pack") { + ScopedFastFlag sff{ "LuauStrPackUBCastFix", true }; + runConformance("tpack.lua"); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 37f1b60b8..7ba40c503 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -479,10 +479,6 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { - ScopedFastFlag sff{"LuauLinterUnknownTypeVectorAware", true}; - - SourceModule sm; - unfreeze(typeChecker.globalTypes); TableTypeVar::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, @@ -1400,8 +1396,6 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { - ScopedFastFlag sff("LuauLinterTableMoveZero", true); - LintResult result = lintTyped(R"( local t = {} local tt = {} diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index a80718e47..e3e6ce6d8 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauFixAmbiguousErrorRecoveryInAssign) + using namespace Luau; namespace @@ -625,10 +627,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_messages") )"), "Cannot have more than one table indexer"); - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauGenericFunctionsParserFix", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CHECK_EQ(getParseError(R"( type T = foo )"), @@ -1624,6 +1622,20 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_confusing_function_call") "statements"); CHECK(result3.errors.size() == 1); + + auto result4 = matchParseError(R"( + local t = {} + function f() return t end + t.x, (f) + ().y = 5, 6 + )", + "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " + "statements"); + + if (FFlag::LuauFixAmbiguousErrorRecoveryInAssign) + CHECK(result4.errors.size() == 1); + else + CHECK(result4.errors.size() == 5); } TEST_CASE_FIXTURE(Fixture, "parse_error_varargs") @@ -1824,9 +1836,6 @@ TEST_CASE_FIXTURE(Fixture, "variadic_definition_parsing") TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") { - // Doesn't need LuauGenericFunctions - ScopedFastFlag sffs{"LuauParseGenericFunctions", true}; - ParseResult result = parseEx(R"( function f(...: a...) end @@ -1861,9 +1870,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") TEST_CASE_FIXTURE(Fixture, "generic_function_declaration_parsing") { - // Doesn't need LuauGenericFunctions - ScopedFastFlag sffs{"LuauParseGenericFunctions", true}; - ParseResult result = parseEx(R"( declare function f() )"); @@ -1953,12 +1959,7 @@ TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments") matchParseError("type MyFunc = (a: number, b: string, c: number) -> (d: number, e: string, f: number)", "Expected '->' when parsing function type, got "); - { - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - ScopedFastFlag luauGenericFunctionsParserFix{"LuauGenericFunctionsParserFix", true}; - - matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); - } + matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); } TEST_SUITE_END(); @@ -2362,8 +2363,6 @@ type Fn = ( CHECK_EQ("Expected '->' when parsing function type, got ')'", e.getErrors().front().getMessage()); } - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - try { parse(R"(type Fn = (any, string | number | ()) -> any)"); @@ -2397,8 +2396,6 @@ TEST_CASE_FIXTURE(Fixture, "AstName_comparison") TEST_CASE_FIXTURE(Fixture, "generic_type_list_recovery") { - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - try { parse(R"( @@ -2521,7 +2518,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); AstStat* stat = parse(R"( @@ -2534,4 +2530,9 @@ type C = Packed<(number, X...)> REQUIRE(stat != nullptr); } +TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") +{ + matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); +} + TEST_SUITE_END(); diff --git a/tests/Predicate.test.cpp b/tests/Predicate.test.cpp index bb5a93c54..7081693e2 100644 --- a/tests/Predicate.test.cpp +++ b/tests/Predicate.test.cpp @@ -33,8 +33,6 @@ TEST_SUITE_BEGIN("Predicate"); TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - RefinementMap m{ {"b", typeChecker.stringType}, {"c", typeChecker.numberType}, @@ -61,8 +59,6 @@ TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - RefinementMap m{ {"a", typeChecker.stringType}, {"b", typeChecker.stringType}, @@ -89,8 +85,6 @@ TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") TEST_CASE_FIXTURE(Fixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - RefinementMap m{ {"a", typeChecker.stringType}, {"b", typeChecker.numberType}, diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index e18bf7cdd..b076e9ad7 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -259,9 +259,6 @@ TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names") TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names_generic") { - ScopedFastFlag luauGenericFunctions{"LuauGenericFunctions", true}; - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - CheckResult result = check("local function f(n: number, ...: a...): (a...) return ... end"); LUAU_REQUIRE_NO_ERRORS(result); @@ -340,10 +337,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") { - ScopedFastFlag sff[] = { - {"LuauGenericFunctions", true}, - }; - CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -468,8 +461,6 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_inters TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - ScopedFastFlag luauInstantiatedTypeParamRecursion{"LuauInstantiatedTypeParamRecursion", true}; - TypeVar tableTy{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tableTy); ttv->name = "Table"; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index bfff60f9d..928c03a31 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -21,7 +21,7 @@ local function isPortal(element) return false end - return element.component==Core.Portal + return element.component == Core.Portal end )"; @@ -223,12 +223,24 @@ TEST_CASE("escaped_strings") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("escaped_strings_2") +{ + const std::string code = R"( local s="\a\b\f\n\r\t\v\'\"\\" )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("need_a_space_between_number_literals_and_dots") { const std::string code = R"( return point and math.ceil(point* 100000* 100)/ 100000 .. '%'or '' )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("binary_keywords") +{ + const std::string code = "local c = a0 ._ or b0 ._"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("do_blocks") { const std::string code = R"( @@ -364,10 +376,10 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") )"; std::string expected = R"( - local a:(string,number,...string)->(string,...number)=function(a:string,b:number,...:...string): (string,...number) + local a:(string,number,...string)->(string,...number)=function(a:string,b:number,...:string): (string,...number) end - local b:(...string)->(...number)=function(...:...string): ...number + local b:(...string)->(...number)=function(...:string): ...number end local c:()->()=function(): () @@ -400,4 +412,238 @@ TEST_CASE_FIXTURE(Fixture, "function_type_location") CHECK_EQ(expected, actual); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion") +{ + std::string code = "local a = 5 :: number"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") +{ + ScopedFastFlag luauIfElseExpressionBaseSupport("LuauIfElseExpressionBaseSupport", true); + + std::string code = "local a = if 1 then 2 else 3"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_type_reference_import") +{ + fileResolver.source["game/A"] = R"( +export type Type = { a: number } +return {} + )"; + + std::string code = R"( +local Import = require(game.A) +local a: Import.Type + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_type_packs") +{ + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + std::string code = R"( +type Packed = (T...)->(T...) +local a: Packed<> +local b: Packed<(number, string)> + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested") +{ + std::string code = "local a: ((number)->(string))|((string)->(string))"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_2") +{ + std::string code = "local a: (number&string)|(string&boolean)"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_3") +{ + std::string code = "local a: nil | (string & number)"; + + CHECK_EQ("local a: ( string & number)?", transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested") +{ + std::string code = "local a: ((number)->(string))&((string)->(string))"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested_2") +{ + std::string code = "local a: (number|string)&(string|boolean)"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_varargs") +{ + std::string code = "local function f(...) return ... end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_index_expr") +{ + std::string code = "local a = {1, 2, 3} local b = a[2]"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_unary") +{ + std::string code = R"( +local a = 1 +local b = -1 +local c = true +local d = not c +local e = 'hello' +local d = #e + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_break_continue") +{ + std::string code = R"( +local a, b, c +repeat + if a then break end + if b then continue end +until c + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_compound_assignmenr") +{ + std::string code = R"( +local a = 1 +a += 2 +a -= 3 +a *= 4 +a /= 5 +a %= 6 +a ^= 7 +a ..= ' - result' + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") +{ + std::string code = "a, b, c = 1, 2, 3"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_generic_function") +{ + ScopedFastFlag luauParseGenericFunctionTypeBegin("LuauParseGenericFunctionTypeBegin", true); + + std::string code = R"( +local function foo(a: T, ...: S...) return 1 end +local f: (T, S...)->(number) = foo + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_reverse") +{ + std::string code = "local a: nil | number"; + + CHECK_EQ("local a: number?", transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple") +{ + std::string code = "for k,v in next,{}do print(k,v) end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_expr") +{ + std::string code = "local a = f:-"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("local a = (error-expr: f.%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_stat") +{ + std::string code = "-"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("(error-stat: (error-expr))", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_type") +{ + std::string code = "local a: "; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("local a:%error-type%", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_parse_error") +{ + std::string code = "local a = -"; + + auto result = transpile(code); + CHECK_EQ("", result.code); + CHECK_EQ("Expected identifier when parsing expression, got ", result.parseError); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_to_string") +{ + std::string code = "local a: string = 'hello'"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + REQUIRE(parseResult.root); + REQUIRE(parseResult.root->body.size == 1); + AstStatLocal* statLocal = parseResult.root->body.data[0]->as(); + REQUIRE(statLocal); + CHECK_EQ("local a: string = 'hello'", toString(statLocal)); + REQUIRE(statLocal->vars.size == 1); + AstLocal* local = statLocal->vars.data[0]; + REQUIRE(local->annotation); + CHECK_EQ("string", toString(local->annotation)); + REQUIRE(statLocal->values.size == 1); + AstExpr* expr = statLocal->values.data[0]; + CHECK_EQ("'hello'", toString(expr)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index f580604ca..c27f8083b 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -247,9 +247,6 @@ TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") { - ScopedFastFlag sffs3{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( type Node = { value: T, child: Node? } diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 17e32e9f9..1e2eae147 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -444,8 +444,6 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local co = coroutine.create(function() end) )"); @@ -456,8 +454,6 @@ TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local function nifty(x, y) print(x, y) @@ -476,8 +472,6 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( --!nonstrict local function nifty(x, y) @@ -822,8 +816,6 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local f = math.sin local function g(x) return math.sin(x) end diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index eabf7e65d..1ff23fe69 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -232,8 +232,6 @@ TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class") TEST_CASE_FIXTURE(ClassFixture, "can_read_prop_of_base_class_using_string") { - ScopedFastFlag luauClassPropertyAccessAsString("LuauClassPropertyAccessAsString", true); - CheckResult result = check(R"( local c = ChildClass.New() local x = 1 + c["BaseField"] @@ -244,8 +242,6 @@ TEST_CASE_FIXTURE(ClassFixture, "can_read_prop_of_base_class_using_string") TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class_using_string") { - ScopedFastFlag luauClassPropertyAccessAsString("LuauClassPropertyAccessAsString", true); - CheckResult result = check(R"( local c = ChildClass.New() c["BaseField"] = 444 @@ -451,4 +447,25 @@ b.X = 2 -- real Vector2.X is also read-only CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors[3])); } +TEST_CASE_FIXTURE(ClassFixture, "detailed_class_unification_error") +{ + ScopedFastFlag luauExtendedClassMismatchError{"LuauExtendedClassMismatchError", true}; + + CheckResult result = check(R"( +local function foo(v) + return v.X :: number + string.len(v.Y) +end + +local a: Vector2 +local b = foo +b(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'Vector2' could not be converted into '{- X: a, Y: string -}' +caused by: + Property 'Y' is not compatible. Type 'number' could not be converted into 'string')", + toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 41e3e45ad..2652486be 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -171,9 +171,6 @@ TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") TEST_CASE_FIXTURE(Fixture, "declaring_generic_functions") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - loadDefinition(R"( declare function f(a: a, b: b): string declare function g(...: a...): b... diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 581375a12..de2f01544 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -13,8 +13,6 @@ TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function id(x:a): a return x @@ -27,8 +25,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_function") TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a): a return x @@ -41,10 +37,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") TEST_CASE_FIXTURE(Fixture, "check_generic_typepack_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauGenericVariadicsUnification", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function id(...: a...): (a...) return ... end local x: string, y: boolean = id("hi", true) @@ -56,8 +48,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_typepack_function") TEST_CASE_FIXTURE(Fixture, "types_before_typepacks") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -66,8 +56,6 @@ TEST_CASE_FIXTURE(Fixture, "types_before_typepacks") TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a):a return x end local f: (a)->a = id @@ -79,7 +67,6 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x end print("This is bogus") -- TODO: CLI-39916 @@ -92,7 +79,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x end print("This is bogus") -- TODO: CLI-39916 @@ -104,8 +90,6 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local t = {} t.m = function(x: a):a return x end @@ -117,8 +101,6 @@ TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local t: { m: (number)->number } = { m = function(x:number) return x+1 end } local function id(x:a):a return x end @@ -129,8 +111,6 @@ TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") TEST_CASE_FIXTURE(Fixture, "check_nested_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function f() local function id(x:a): a @@ -145,8 +125,6 @@ TEST_CASE_FIXTURE(Fixture, "check_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "check_recursive_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a):a local y: string = id("hi") @@ -159,8 +137,6 @@ TEST_CASE_FIXTURE(Fixture, "check_recursive_generic_function") TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local id2 local function id1(x:a):a @@ -179,8 +155,6 @@ TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( type T = { id: (a) -> a } local x: T = { id = function(x:a):a return x end } @@ -192,8 +166,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") TEST_CASE_FIXTURE(Fixture, "generic_factories") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( type T = { id: (a) -> a } type Factory = { build: () -> T } @@ -215,10 +187,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_factories") TEST_CASE_FIXTURE(Fixture, "factories_of_generics") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; - CheckResult result = check(R"( type T = { id: (a) -> a } type Factory = { build: () -> T } @@ -241,7 +209,6 @@ TEST_CASE_FIXTURE(Fixture, "factories_of_generics") TEST_CASE_FIXTURE(Fixture, "infer_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function id(x) return x @@ -265,7 +232,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x @@ -289,7 +255,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f() local function id(x) @@ -304,7 +269,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local x = {} function x:id(x) return x end @@ -316,7 +280,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local x = {} function x:id(x) return x end @@ -331,8 +294,6 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") TEST_CASE_FIXTURE(Fixture, "infer_generic_property") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauRankNTypes", true}; CheckResult result = check(R"( local t = {} t.m = function(x) return x end @@ -344,9 +305,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_property") TEST_CASE_FIXTURE(Fixture, "function_arguments_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function f(g: (a)->a) local x: number = g(37) @@ -358,9 +316,6 @@ TEST_CASE_FIXTURE(Fixture, "function_arguments_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "function_results_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function f() : (a)->a local function id(x:a):a return x end @@ -372,9 +327,6 @@ TEST_CASE_FIXTURE(Fixture, "function_results_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "type_parameters_can_be_polytypes") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function id(x:a):a return x end local f: (a)->a = id(id) @@ -384,7 +336,6 @@ TEST_CASE_FIXTURE(Fixture, "type_parameters_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f(y) -- this will only typecheck if we infer z: any @@ -406,7 +357,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f(y) local z = y @@ -423,12 +373,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") { - ScopedFastFlag sffs[] = { - {"LuauGenericFunctions", true}, - {"LuauParseGenericFunctions", true}, - {"LuauRankNTypes", true}, - }; - CheckResult result = check(R"( type T = { m: (a) -> T } function f(t : T) @@ -440,10 +384,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") TEST_CASE_FIXTURE(Fixture, "dont_unify_bound_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; - CheckResult result = check(R"( type F = () -> (a, b) -> a type G = (b, b) -> b @@ -470,7 +410,6 @@ TEST_CASE_FIXTURE(Fixture, "mutable_state_polymorphism") // Replaying the classic problem with polymorphism and mutable state in Luau // See, e.g. Tofte (1990) // https://www.sciencedirect.com/science/article/pii/089054019090018D. - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( --!strict -- Our old friend the polymorphic identity function @@ -508,7 +447,6 @@ TEST_CASE_FIXTURE(Fixture, "mutable_state_polymorphism") TEST_CASE_FIXTURE(Fixture, "rank_N_types_via_typeof") { - ScopedFastFlag sffs{"LuauGenericFunctions", false}; CheckResult result = check(R"( --!strict local function id(x) return x end @@ -531,8 +469,6 @@ TEST_CASE_FIXTURE(Fixture, "rank_N_types_via_typeof") TEST_CASE_FIXTURE(Fixture, "duplicate_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function f(x:a):a return x end )"); @@ -541,7 +477,6 @@ TEST_CASE_FIXTURE(Fixture, "duplicate_generic_types") TEST_CASE_FIXTURE(Fixture, "duplicate_generic_type_packs") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -550,7 +485,6 @@ TEST_CASE_FIXTURE(Fixture, "duplicate_generic_type_packs") TEST_CASE_FIXTURE(Fixture, "typepacks_before_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -559,9 +493,6 @@ TEST_CASE_FIXTURE(Fixture, "typepacks_before_types") TEST_CASE_FIXTURE(Fixture, "variadic_generics") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a) end @@ -573,9 +504,6 @@ TEST_CASE_FIXTURE(Fixture, "variadic_generics") TEST_CASE_FIXTURE(Fixture, "generic_type_pack_syntax") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a...): (a...) return ... end )"); @@ -586,10 +514,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_type_pack_syntax") TEST_CASE_FIXTURE(Fixture, "generic_type_pack_parentheses") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauGenericVariadicsUnification", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a...): any return (...) end )"); @@ -599,9 +523,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_type_pack_parentheses") TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: T...) return ... @@ -626,9 +547,6 @@ TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") TEST_CASE_FIXTURE(Fixture, "reject_clashing_generic_and_pack_names") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f() end )"); @@ -641,8 +559,6 @@ TEST_CASE_FIXTURE(Fixture, "reject_clashing_generic_and_pack_names") TEST_CASE_FIXTURE(Fixture, "instantiation_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( function f(z) local o = {} @@ -665,8 +581,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiation_sharing_types") TEST_CASE_FIXTURE(Fixture, "quantification_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( function f(x) return {5} end function g(x, y) return f(x) end @@ -680,8 +594,6 @@ TEST_CASE_FIXTURE(Fixture, "quantification_sharing_types") TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( type T = { x: {a}, y: {number} } local o1: T = { x = {true}, y = {5} } @@ -697,7 +609,6 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") { - ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true}; ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true}; CheckResult result = check(R"( diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 9685f4f35..893bc2b30 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -341,4 +341,43 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X & Y & Z + +local a: XYZ = 3 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' +caused by: + Not all intersection parts are compatible. Type 'number' could not be converted into 'X')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X & Y & Z + +local a: XYZ +local b: number = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 419da8ad1..e5c14dde8 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -194,9 +194,6 @@ TEST_CASE_FIXTURE(Fixture, "normal_conditional_expression_has_refinements") // Luau currently doesn't yet know how to allow assignments when the binding was refined. TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( type Node = { value: T, child: Node? } @@ -596,11 +593,9 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; - ScopedFastFlag luauFollowInTypeFunApply{"LuauFollowInTypeFunApply", true}; - ScopedFastFlag luauInstantiatedTypeParamRecursion{"LuauInstantiatedTypeParamRecursion", true}; // Mutability in type function application right now can create strange recursive types - // TODO: instantiation right now is problematic, it this example should either leave the Table type alone + // TODO: instantiation right now is problematic, in this example should either leave the Table type alone // or it should rename the type to 'Self' so that the result will be 'Self' CheckResult result = check(R"( type Table = { a: number } diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 36dcaa959..733fc39b3 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -7,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauWeakEqConstraint) -LUAU_FASTFLAG(LuauOrPredicate) LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -133,11 +132,8 @@ TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); - if (FFlag::LuauOrPredicate) - { - CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); - } + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); } TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") @@ -283,6 +279,8 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( local t: {x: number?} = {x = nil} @@ -293,7 +291,10 @@ TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_ty )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type '{| x: number? |}' could not be converted into '{| x: number |}'", toString(result.errors[0])); + CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' +caused by: + Property 'x' is not compatible. Type 'number?' could not be converted into 'number')", + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") @@ -749,8 +750,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") { - ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; - CheckResult result = check(R"( local function f(x: Part | Folder | string) if typeof(x) == "Instance" then @@ -769,8 +768,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") { - ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; - CheckResult result = check(R"( local function f(x: Part | Folder | Instance | string | Vector3 | any) if typeof(x) == "Instance" then @@ -789,8 +786,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( --!nonstrict @@ -811,11 +806,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; - CheckResult result = check(R"( local function f(x: Part | Folder | string) if typeof(x) ~= "Instance" or not x:IsA("Part") then @@ -890,8 +880,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_s TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if (not a) or (not b) then @@ -909,8 +897,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b2") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if not (a and b) then @@ -928,8 +914,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b2") TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if (not a) and (not b) then @@ -947,8 +931,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if not (a or b) then @@ -966,8 +948,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") TEST_CASE_FIXTURE(Fixture, "either_number_or_string") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(x: any) if type(x) == "number" or type(x) == "string" then @@ -983,8 +963,6 @@ TEST_CASE_FIXTURE(Fixture, "either_number_or_string") TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(t: {x: boolean}?) if not t or t.x then @@ -1000,8 +978,6 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local a: (number | string)? assert(a) @@ -1018,8 +994,6 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. CheckResult result = check(R"( local function f(b: string | { x: string }, a) @@ -1039,8 +1013,6 @@ TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: string | number | boolean) if type(a) ~= "number" and type(a) ~= "string" then diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index f1451a815..c3694be79 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1950,4 +1950,76 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") +{ + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type A = { x: number, y: number } +type B = { x: number, y: string } + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") +{ + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type AS = { x: number, y: number } +type BS = { x: number, y: string } + +type A = { a: boolean, b: AS } +type B = { a: boolean, b: BS } + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") +{ + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); +local b1 = setmetatable({ x = 2, y = "hello" }, { __call = function(s) end }); +local c1: typeof(a1) = b1 + +local a2 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); +local b2 = setmetatable({ x = 2, y = 4 }, { __call = function(s, t) end }); +local c2: typeof(a2) = b2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' +caused by: + Type '{| x: number, y: string |}' could not be converted into '{| x: number, y: number |}' +caused by: + Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); + + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 453817574..30d9130a5 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -3926,8 +3926,6 @@ local b: number = 1 or a TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - CheckResult result = check(R"( --!strict local tbl = {} @@ -4493,10 +4491,6 @@ f(function(x) print(x) end) TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") { - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end return sum(2, 3, function(a, b) return a + b end) @@ -4525,10 +4519,6 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") { - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - CheckResult result = check(R"( local function g1(a: T, f: (T) -> T) return f(a) end local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end @@ -4579,10 +4569,6 @@ local a: TableWithFunc = { x = 3, y = 4, f = function(a, b) return a + b end } TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") { - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end @@ -4600,8 +4586,6 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local function f(): {string|number} return {1, "b", 3} @@ -4625,8 +4609,6 @@ end TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local function f() return {4, "b", 3} :: {string|number} @@ -4638,8 +4620,6 @@ end TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local a: (number, number) -> number = function(a, b) return a - b end @@ -4655,8 +4635,6 @@ b, c = {2, "s"}, {"b", 4} TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types_mutable_lval") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local a = {} a.x = 2 @@ -4668,8 +4646,6 @@ a = setmetatable(a, { __call = function(x) end }) TEST_CASE_FIXTURE(Fixture, "refine_and_or") { - ScopedFastFlag sff{"LuauSlightlyMoreFlexibleBinaryPredicates", true}; - CheckResult result = check(R"( local t: {x: number?}? = {x = nil} local u = t and t.x or 5 @@ -4682,10 +4658,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_and_or") TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") { - ScopedFastFlag sffs[] = { - {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - }; - CheckResult result = check(R"( local t: {x: number?}? = {x = nil} local u = t.x and t or 5 @@ -4698,10 +4670,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") { - ScopedFastFlag sffs[] = { - {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - }; - CheckResult result = check(R"( local t: {x: number?}? = {x = nil} local u = t and t.x == 5 or t.x == 31337 @@ -4714,7 +4682,7 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { - ScopedFastFlag luauFollowInTypeFunApply("LuauFollowInTypeFunApply", true); + ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; CheckResult result = check(R"( type A = { x: number } diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 1192a8ac0..2d697fc97 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -178,9 +178,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_wor TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( --!strict local function f(...: T): ...T @@ -199,8 +196,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unification") { - ScopedFastFlag sffs2("LuauGenericFunctions", true); - CheckResult result = check(R"( --!strict table.insert() diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 8dab2605b..c6de0abf5 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -296,7 +296,6 @@ end TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -361,7 +360,6 @@ local c: Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -395,7 +393,6 @@ local d: { a: typeof(c) } TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -434,7 +431,6 @@ type C = Import.Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -456,7 +452,6 @@ type Packed4 = (Packed3, T...) -> (Packed3, T...) TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -475,7 +470,6 @@ type E = X<(number, ...string)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -507,7 +501,6 @@ type I = W TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -534,7 +527,6 @@ type F = X<(string, ...number)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -557,10 +549,8 @@ type D = Y TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - ScopedFastFlag luauInstantiatedTypeParamRecursion("LuauInstantiatedTypeParamRecursion", true); // For correct toString block CheckResult result = check(R"( type Y = { f: (T...) -> (U...) } @@ -577,7 +567,6 @@ local b: Y<(), ()> TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -599,7 +588,6 @@ type C = Y<(number), boolean> TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 34c25a9fe..9f29b6428 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -400,8 +400,6 @@ local e = a.z TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") { - ScopedFastFlag luauSealedTableUnifyOptionalFix("LuauSealedTableUnifyOptionalFix", true); - CheckResult result = check(R"( local x: { x: number } = { x = 3 } type A = number? @@ -426,4 +424,43 @@ y = x LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_union_part") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X | Y | Z + +local a: XYZ +local b: { w: number } = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'X | Y | Z' could not be converted into '{| w: number |}' +caused by: + Not all union options are compatible. Table type 'X' not compatible with type '{| w: number |}' because the former is missing field 'w')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_union_all") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X | Y | Z + +local a: XYZ = { w = 4 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X | Y | Z'; none of the union options are compatible)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 930c1a39b..91efa8188 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -11,8 +11,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauGenericFunctions); - TEST_SUITE_BEGIN("TypeVarTests"); TEST_CASE_FIXTURE(Fixture, "primitives_are_equal") diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua index 6efa5960e..13be3f942 100644 --- a/tests/conformance/bitwise.lua +++ b/tests/conformance/bitwise.lua @@ -113,6 +113,20 @@ assert(bit32.replace(0, -1, 4) == 2^4) assert(bit32.replace(-1, 0, 31) == 2^31 - 1) assert(bit32.replace(-1, 0, 1, 2) == 2^32 - 7) +-- testing countlz/countrc +assert(bit32.countlz(0) == 32) +assert(bit32.countlz(42) == 26) +assert(bit32.countlz(0xffffffff) == 0) +assert(bit32.countlz(0x80000000) == 0) +assert(bit32.countlz(0x7fffffff) == 1) + +assert(bit32.countrz(0) == 32) +assert(bit32.countrz(1) == 0) +assert(bit32.countrz(42) == 1) +assert(bit32.countrz(0x80000000) == 31) +assert(bit32.countrz(0x40000000) == 30) +assert(bit32.countrz(0x7fffffff) == 0) + --[[ This test verifies a fix in luauF_replace() where if the 4th parameter was not a number, but the first three are numbers, it will @@ -136,5 +150,7 @@ assert(bit32.bxor("1", 3) == 2) assert(bit32.bxor(1, "3") == 2) assert(bit32.btest(1, "3") == true) assert(bit32.btest("1", 3) == true) +assert(bit32.countlz("42") == 26) +assert(bit32.countrz("42") == 1) return('OK') From 8fe0dc0b6d553dc053a43f5ed6607c9c4d1a26c0 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 18:23:34 -0800 Subject: [PATCH 04/14] Fix build --- VM/src/lbitlib.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index c72fe6748..907c43c42 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -2,6 +2,7 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lualib.h" +#include "lcommon.h" #include "lnumutils.h" LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) From 863d3ff6ffa64398a6d13dc8c189bae30fefd557 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 19:42:50 -0800 Subject: [PATCH 05/14] Attempt to work around non-sensical error --- Analysis/src/TypeInfer.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8fad1af91..b27f3e170 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5030,7 +5030,8 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } - return std::nullopt; + std::optional res = std::nullopt; + return res; }; std::optional ty = resolveLValue(refis, scope, isaP.lvalue); From 3c3541aba84d9209b6098a5c6ae01727ab11ec32 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 20:36:53 -0800 Subject: [PATCH 06/14] Add a comment --- Analysis/src/TypeInfer.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b27f3e170..a6696efde 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5030,6 +5030,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } + // local variable works around an odd gcc 9.3 warning: may be used uninitialized std::optional res = std::nullopt; return res; }; From 60e6e86adb5c7a687153989bc471fc05fc4637e8 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 18 Nov 2021 14:21:07 -0800 Subject: [PATCH 07/14] Sync to upstream/release/505 --- Analysis/include/Luau/Documentation.h | 11 +- Analysis/include/Luau/ToString.h | 11 +- Analysis/include/Luau/TypeInfer.h | 24 +- Analysis/include/Luau/TypeVar.h | 87 ++++- Analysis/include/Luau/Unifiable.h | 2 + Analysis/include/Luau/Unifier.h | 1 + Analysis/src/Autocomplete.cpp | 33 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 1 + Analysis/src/Error.cpp | 8 +- Analysis/src/Frontend.cpp | 4 +- Analysis/src/Module.cpp | 16 +- Analysis/src/ToString.cpp | 115 +++++- Analysis/src/Transpiler.cpp | 55 --- Analysis/src/TypeAttach.cpp | 16 + Analysis/src/TypeInfer.cpp | 374 ++++++++++++------- Analysis/src/TypePack.cpp | 1 - Analysis/src/TypeVar.cpp | 50 ++- Analysis/src/Unifier.cpp | 64 +++- Ast/include/Luau/Ast.h | 32 ++ Ast/include/Luau/Parser.h | 1 + Ast/include/Luau/StringUtils.h | 2 + Ast/src/Ast.cpp | 22 ++ Ast/src/Parser.cpp | 68 +++- Ast/src/StringUtils.cpp | 58 +++ CLI/Analyze.cpp | 28 +- CLI/FileUtils.cpp | 46 ++- CLI/FileUtils.h | 3 + CLI/Repl.cpp | 81 ++--- CMakeLists.txt | 17 +- Compiler/include/Luau/Compiler.h | 4 +- Compiler/include/luacode.h | 39 ++ Compiler/src/Compiler.cpp | 46 +-- Compiler/src/lcode.cpp | 29 ++ Makefile | 10 +- Sources.cmake | 3 + VM/include/lua.h | 18 +- VM/include/lualib.h | 6 +- VM/src/lapi.cpp | 10 +- VM/src/lbaselib.cpp | 8 +- VM/src/lbitlib.cpp | 1 + VM/src/lcorolib.cpp | 36 +- VM/src/ldebug.cpp | 10 +- VM/src/ldo.cpp | 6 +- VM/src/linit.cpp | 2 +- VM/src/lstate.cpp | 28 ++ VM/src/lstrlib.cpp | 2 +- VM/src/lutf8lib.cpp | 2 +- bench/bench.py | 32 +- fuzz/proto.cpp | 2 +- tests/Autocomplete.test.cpp | 16 +- tests/Compiler.test.cpp | 63 +--- tests/Conformance.test.cpp | 99 ++--- tests/IostreamOptional.h | 3 +- tests/Module.test.cpp | 6 +- tests/ToString.test.cpp | 111 +++++- tests/TypeInfer.aliases.test.cpp | 2 +- tests/TypeInfer.annotations.test.cpp | 4 +- tests/TypeInfer.generics.test.cpp | 6 +- tests/TypeInfer.refinements.test.cpp | 3 +- tests/TypeInfer.singletons.test.cpp | 377 ++++++++++++++++++++ tests/TypeInfer.test.cpp | 87 +++-- tests/TypeInfer.tryUnify.test.cpp | 30 +- tests/TypeInfer.unionTypes.test.cpp | 5 +- tests/conformance/coroutine.lua | 54 +++ tests/conformance/debugger.lua | 9 + 65 files changed, 1820 insertions(+), 580 deletions(-) create mode 100644 Compiler/include/luacode.h create mode 100644 Compiler/src/lcode.cpp create mode 100644 tests/TypeInfer.singletons.test.cpp diff --git a/Analysis/include/Luau/Documentation.h b/Analysis/include/Luau/Documentation.h index 7b609b4fa..68ff3a7c2 100644 --- a/Analysis/include/Luau/Documentation.h +++ b/Analysis/include/Luau/Documentation.h @@ -12,10 +12,17 @@ namespace Luau struct FunctionDocumentation; struct TableDocumentation; struct OverloadedFunctionDocumentation; +struct BasicDocumentation; -using Documentation = Luau::Variant; +using Documentation = Luau::Variant; using DocumentationSymbol = std::string; +struct BasicDocumentation +{ + std::string documentation; + std::string learnMoreLink; +}; + struct FunctionParameterDocumentation { std::string name; @@ -29,6 +36,7 @@ struct FunctionDocumentation std::string documentation; std::vector parameters; std::vector returns; + std::string learnMoreLink; }; struct OverloadedFunctionDocumentation @@ -43,6 +51,7 @@ struct TableDocumentation { std::string documentation; Luau::DenseHashMap keys; + std::string learnMoreLink; }; using DocumentationDatabase = Luau::DenseHashMap; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index e5683fc40..50379c1cd 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -23,10 +23,11 @@ struct ToStringNameMap struct ToStringOptions { - bool exhaustive = false; // If true, we produce complete output rather than comprehensible output - bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. - bool functionTypeArguments = false; // If true, output function type argument names when they are available - bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool exhaustive = false; // If true, we produce complete output rather than comprehensible output + bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. + bool functionTypeArguments = false; // If true, output function type argument names when they are available + bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; @@ -64,6 +65,8 @@ inline std::string toString(TypePackId ty) std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts = {}); + // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression void dump(TypeId ty); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 306ac77d8..78d642c58 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -175,10 +175,10 @@ struct TypeChecker std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors); + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); - ExprResult reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, const std::vector& errors); @@ -282,6 +282,14 @@ struct TypeChecker // Wrapper for merge(l, r, toUnion) but without the lambda junk. void merge(RefinementMap& l, const RefinementMap& r); + // Produce an "emergency backup type" for recovery from type errors. + // This comes in two flavours, depening on whether or not we can make a good guess + // for an error recovery type. + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(const ScopePtr& scope); + TypePackId errorRecoveryTypePack(const ScopePtr& scope); + private: void prepareErrorsForDisplay(ErrorVec& errVec); void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data); @@ -297,6 +305,10 @@ struct TypeChecker TypeId freshType(const ScopePtr& scope); TypeId freshType(TypeLevel level); + // Produce a new singleton type var. + TypeId singletonType(bool value); + TypeId singletonType(std::string value); + // Returns nullopt if the predicate filters down the TypeId to 0 options. std::optional filterMap(TypeId type, TypeIdPredicate predicate); @@ -330,8 +342,8 @@ struct TypeChecker const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. - std::pair, std::vector> createGenericTypes( - const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + std::pair, std::vector> createGenericTypes(const ScopePtr& scope, std::optional levelOpt, + const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -347,7 +359,6 @@ struct TypeChecker void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; @@ -387,12 +398,9 @@ struct TypeChecker const TypeId booleanType; const TypeId threadType; const TypeId anyType; - - const TypeId errorType; const TypeId optionalNumberType; const TypePackId anyTypePack; - const TypePackId errorTypePack; private: int checkRecursionCount = 0; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 6bd7932db..093ea4319 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -108,6 +108,79 @@ struct PrimitiveTypeVar } }; +// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md +// Types for true and false +struct BoolSingleton +{ + bool value; + + bool operator==(const BoolSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const BoolSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// Types for "foo", "bar" etc. +struct StringSingleton +{ + std::string value; + + bool operator==(const StringSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const StringSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// No type for float singletons, partly because === isn't any equalivalence on floats +// (NaN != NaN). + +using SingletonVariant = Luau::Variant; + +struct SingletonTypeVar +{ + explicit SingletonTypeVar(const SingletonVariant& variant) + : variant(variant) + { + } + + explicit SingletonTypeVar(SingletonVariant&& variant) + : variant(std::move(variant)) + { + } + + // Default operator== is C++20. + bool operator==(const SingletonTypeVar& rhs) const + { + return variant == rhs.variant; + } + + bool operator!=(const SingletonTypeVar& rhs) const + { + return !(*this == rhs); + } + + SingletonVariant variant; +}; + +template +const T* get(const SingletonTypeVar* stv) +{ + if (stv) + return get_if(&stv->variant); + else + return nullptr; +} + struct FunctionArgument { Name name; @@ -332,8 +405,8 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -410,6 +483,9 @@ bool isGeneric(const TypeId ty); // Checks if a type may be instantiated to one containing generic type binders bool maybeGeneric(const TypeId ty); +// Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton +bool maybeSingleton(TypeId ty); + struct SingletonTypes { const TypeId nilType; @@ -418,16 +494,19 @@ struct SingletonTypes const TypeId booleanType; const TypeId threadType; const TypeId anyType; - const TypeId errorType; const TypeId optionalNumberType; const TypePackId anyTypePack; - const TypePackId errorTypePack; SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete; + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(); + TypePackId errorRecoveryTypePack(); + private: std::unique_ptr arena; TypeId makeStringMetatable(); diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index c2e07e466..b47610fca 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -105,6 +105,8 @@ struct Generic struct Error { + // This constructor has to be public, since it's used in TypeVar and TypePack, + // but shouldn't be called directly. Please use errorRecoveryType() instead. Error(); int index; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index be0aadd05..503034a1b 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -65,6 +65,7 @@ struct Unifier private: void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); void tryUnifyPrimitives(TypeId superTy, TypeId subTy); + void tryUnifySingletons(TypeId superTy, TypeId subTy); void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1c94bb684..6fc0b3f88 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -198,11 +199,24 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); - unifier.tryUnify(expectedType, actualType); + if (FFlag::LuauAutocompleteAvoidMutation) + { + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr); + actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr); + + auto errors = unifier.canUnify(expectedType, actualType); + return errors.empty(); + } + else + { + unifier.tryUnify(expectedType, actualType); - bool ok = unifier.errors.empty(); - unifier.log.rollback(); - return ok; + bool ok = unifier.errors.empty(); + unifier.log.rollback(); + return ok; + } }; auto expr = node->asExpr(); @@ -1496,11 +1510,9 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName if (!sourceModule) return {}; - TypeChecker& typeChecker = - (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = - (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); if (!module) return {}; @@ -1527,8 +1539,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->mode = Mode::Strict; sourceModule->commentLocations = std::move(result.commentLocations); - TypeChecker& typeChecker = - (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 96703ef16..9f5c82500 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -153,6 +153,7 @@ declare function gcinfo(): number wrap: ((A...) -> R...) -> any, yield: (A...) -> R..., isyieldable: () -> boolean, + close: (thread) -> (boolean, any?) } declare table: { diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 46ff2c72a..f80d50a7a 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -180,13 +180,13 @@ struct ErrorConverter switch (e.context) { case CountMismatch::Return: - return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + - std::to_string(e.actual) + " " + actualVerb + " returned here"; + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + + actualVerb + " returned here"; case CountMismatch::Result: // It is alright if right hand side produces more values than the // left hand side accepts. In this context consider only the opposite case. - return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + - std::to_string(e.actual) + " are required here"; + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + + " are required here"; case CountMismatch::Arg: if (FFlag::LuauTypeAliasPacks) return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 2f411274b..1e97705dc 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -22,7 +22,6 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) LUAU_FASTFLAG(LuauNewRequireTrace2) -LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau { @@ -458,8 +457,7 @@ CheckResult Frontend::check(const ModuleName& name) module->astTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); - if (FFlag::LuauClearScopes) - module->scopes.resize(1); + module->scopes.resize(1); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 880ffd2e5..32a0646ae 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -161,6 +161,7 @@ struct TypeCloner void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); void operator()(const PrimitiveTypeVar& t); + void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); void operator()(const MetatableTypeVar& t); @@ -199,7 +200,9 @@ struct TypePackCloner if (encounteredFreeType) *encounteredFreeType = true; - seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}}); + TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); + TypePackId cloned = dest.addTypePack(*err); + seenTypePacks[typePackId] = cloned; } void operator()(const Unifiable::Generic& t) @@ -251,8 +254,9 @@ void TypeCloner::operator()(const Unifiable::Free& t) { if (encounteredFreeType) *encounteredFreeType = true; - - seenTypes[typeId] = dest.addType(ErrorTypeVar{}); + TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); + TypeId cloned = dest.addType(*err); + seenTypes[typeId] = cloned; } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -270,11 +274,17 @@ void TypeCloner::operator()(const Unifiable::Error& t) { defaultClone(t); } + void TypeCloner::operator()(const PrimitiveTypeVar& t) { defaultClone(t); } +void TypeCloner::operator()(const SingletonTypeVar& t) +{ + defaultClone(t); +} + void TypeCloner::operator()(const FunctionTypeVar& t) { TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 885fd489b..735bfa503 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -350,6 +350,23 @@ struct TypeVarStringifier } } + void operator()(TypeId, const SingletonTypeVar& stv) + { + if (const BoolSingleton* bs = Luau::get(&stv)) + state.emit(bs->value ? "true" : "false"); + else if (const StringSingleton* ss = Luau::get(&stv)) + { + state.emit("\""); + state.emit(escape(ss->value)); + state.emit("\""); + } + else + { + LUAU_ASSERT(!"Unknown singleton type"); + throw std::runtime_error("Unknown singleton type"); + } + } + void operator()(TypeId, const FunctionTypeVar& ftv) { if (state.hasSeen(&ftv)) @@ -359,6 +376,7 @@ struct TypeVarStringifier return; } + // We should not be respecting opts.hideNamedFunctionTypeParameters here. if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0) { state.emit("<"); @@ -514,7 +532,14 @@ struct TypeVarStringifier break; } - state.emit(name); + if (isIdentifier(name)) + state.emit(name); + else + { + state.emit("[\""); + state.emit(escape(name)); + state.emit("\"]"); + } state.emit(": "); stringify(prop.type); comma = true; @@ -1084,6 +1109,94 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +{ + std::string s = prefix; + + auto toString_ = [&opts](TypeId ty) -> std::string { + ToStringResult res = toStringDetailed(ty, opts); + opts.nameMap = std::move(res.nameMap); + return res.name; + }; + + auto toStringPack_ = [&opts](TypePackId ty) -> std::string { + ToStringResult res = toStringDetailed(ty, opts); + opts.nameMap = std::move(res.nameMap); + return res.name; + }; + + if (!opts.hideNamedFunctionTypeParameters && (!ftv.generics.empty() || !ftv.genericPacks.empty())) + { + s += "<"; + + bool first = true; + for (TypeId g : ftv.generics) + { + if (!first) + s += ", "; + first = false; + s += toString_(g); + } + + for (TypePackId gp : ftv.genericPacks) + { + if (!first) + s += ", "; + first = false; + s += toStringPack_(gp); + } + + s += ">"; + } + + s += "("; + + auto argPackIter = begin(ftv.argTypes); + auto argNameIter = ftv.argNames.begin(); + + bool first = true; + while (argPackIter != end(ftv.argTypes)) + { + if (!first) + s += ", "; + first = false; + + // argNames is guaranteed to be equal to argTypes iff argNames is not empty. + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (!ftv.argNames.empty()) + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + s += toString_(*argPackIter); + + ++argPackIter; + if (!ftv.argNames.empty()) + { + LUAU_ASSERT(argNameIter != ftv.argNames.end()); + ++argNameIter; + } + } + + if (argPackIter.tail()) + { + if (auto vtp = get(*argPackIter.tail())) + s += ", ...: " + toString_(vtp->ty); + else + s += ", ...: " + toStringPack_(*argPackIter.tail()); + } + + s += "): "; + + size_t retSize = size(ftv.retType); + bool hasTail = !finite(ftv.retType); + if (retSize == 0 && !hasTail) + s += "()"; + else if ((retSize == 0 && hasTail) || (retSize == 1 && !hasTail)) + s += toStringPack_(ftv.retType); + else + s += "(" + toStringPack_(ftv.retType) + ")"; + + return s; +} + void dump(TypeId ty) { ToStringOptions opts; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 7d880af49..6627fbe36 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -14,61 +14,6 @@ LUAU_FASTFLAG(LuauTypeAliasPacks) namespace { - -std::string escape(std::string_view s) -{ - std::string r; - r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting - - for (uint8_t c : s) - { - if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') - r += c; - else - { - r += '\\'; - - switch (c) - { - case '\a': - r += 'a'; - break; - case '\b': - r += 'b'; - break; - case '\f': - r += 'f'; - break; - case '\n': - r += 'n'; - break; - case '\r': - r += 'r'; - break; - case '\t': - r += 't'; - break; - case '\v': - r += 'v'; - break; - case '\'': - r += '\''; - break; - case '\"': - r += '\"'; - break; - case '\\': - r += '\\'; - break; - default: - Luau::formatAppend(r, "%03u", c); - } - } - } - - return r; -} - bool isIdentifierStartChar(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_'; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 11aa7b394..af6d2543d 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -96,6 +96,22 @@ class TypeRehydrationVisitor return nullptr; } } + + AstType* operator()(const SingletonTypeVar& stv) + { + if (const BoolSingleton* bs = get(&stv)) + return allocator->alloc(Location(), bs->value); + else if (const StringSingleton* ss = get(&stv)) + { + AstArray value; + value.data = const_cast(ss->value.c_str()); + value.size = strlen(value.data); + return allocator->alloc(Location(), value); + } + else + return nullptr; + } + AstType* operator()(const AnyTypeVar&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8fad1af91..b2ae94c72 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,9 @@ LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) +LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) namespace Luau { @@ -211,10 +214,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(singletonTypes.booleanType) , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) - , errorType(singletonTypes.errorType) , optionalNumberType(singletonTypes.optionalNumberType) , anyTypePack(singletonTypes.anyTypePack) - , errorTypePack(singletonTypes.errorTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -484,7 +485,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - *asMutable(type) = ErrorTypeVar{}; + *asMutable(type) = *errorRecoveryType(anyType); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } } @@ -719,7 +720,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) else if (auto tail = valueIter.tail()) { if (get(*tail)) - right = errorType; + right = errorRecoveryType(scope); else if (auto vtp = get(*tail)) right = vtp->ty; else if (get(*tail)) @@ -961,7 +962,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) - unify(var, errorType, forin.location); + unify(var, errorRecoveryType(scope), forin.location); return check(loopScope, *forin.body); } @@ -979,7 +980,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) const FunctionTypeVar* iterFunc = get(iterTy); if (!iterFunc) { - TypeId varTy = get(iterTy) ? anyType : errorType; + TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) unify(var, varTy, forin.location); @@ -1152,9 +1153,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); if (FFlag::LuauTypeAliasPacks) - bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType}; + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; else - bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + bindingsMap[name] = TypeFun{binding->typeParams, errorRecoveryType(anyType)}; } else { @@ -1398,7 +1399,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult result; @@ -1407,12 +1408,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result = checkExpr(scope, *a->expr); else if (expr.is()) result = {nilType}; - else if (expr.is()) - result = {booleanType}; + else if (const AstExprConstantBool* bexpr = expr.as()) + { + if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + result = {singletonType(bexpr->value)}; + else + result = {booleanType}; + } + else if (const AstExprConstantString* sexpr = expr.as()) + { + if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; + else + result = {stringType}; + } else if (expr.is()) result = {numberType}; - else if (expr.is()) - result = {stringType}; else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1485,7 +1496,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint // ice("AstExprLocal exists but no binding definition for it?", expr.location); reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) @@ -1497,7 +1508,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) @@ -1509,7 +1520,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa std::vector types = flatten(varargPack).first; return {!types.empty() ? types[0] : nilType}; } - else if (auto ftp = get(varargPack)) + else if (get(varargPack)) { TypeId head = freshType(scope); TypePackId tail = freshTypePack(scope); @@ -1517,14 +1528,14 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa return {head}; } if (get(varargPack)) - return {errorType}; + return {errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return {vtp->ty}; else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); - return {errorType}; + return {errorRecoveryType(scope)}; } else ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); @@ -1539,7 +1550,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (auto ftp = get(retPack)) + else if (get(retPack)) { TypeId head = freshType(scope); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); @@ -1547,7 +1558,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa return {head, std::move(result.predicates)}; } if (get(retPack)) - return {errorType, std::move(result.predicates)}; + return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) @@ -1572,7 +1583,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) return {*ty}; - return {errorType}; + return {errorRecoveryType(scope)}; } std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) @@ -1876,6 +1887,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; + const UnionTypeVar* expectedUnion = nullptr; std::optional expectedIndexType; std::optional expectedIndexResultType; @@ -1894,6 +1906,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa } } } + else if (FFlag::LuauExpectedTypesOfProperties) + if (const UnionTypeVar* utv = get(follow(*expectedType))) + expectedUnion = utv; } for (size_t i = 0; i < expr.items.size; ++i) @@ -1916,6 +1931,18 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; } + else if (FFlag::LuauExpectedTypesOfProperties && expectedUnion) + { + std::vector expectedResultTypes; + for (TypeId expectedOption : expectedUnion) + if (const TableTypeVar* ttv = get(follow(expectedOption))) + if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) + expectedResultTypes.push_back(prop->second.type); + if (expectedResultTypes.size() == 1) + expectedResultType = expectedResultTypes[0]; + else if (expectedResultTypes.size() > 1) + expectedResultType = addType(UnionTypeVar{expectedResultTypes}); + } } else { @@ -1958,21 +1985,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) - return {errorType}; + retType = errorRecoveryType(retType); - return {first(retType).value_or(nilType)}; + return {retType}; } reportError(expr.location, GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); - return {errorType}; + return {errorRecoveryType(scope)}; } reportErrors(tryUnify(numberType, operandType, expr.location)); @@ -1984,7 +2012,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn operandType = stripFromNilAndReport(operandType, expr.location); if (get(operandType)) - return {errorType}; + return {errorRecoveryType(scope)}; if (get(operandType)) return {numberType}; // Not strictly correct: metatables permit overriding this @@ -2044,7 +2072,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b if (unify(a, b, location)) return a; - return errorType; + return errorRecoveryType(anyType); } if (*a == *b) @@ -2166,11 +2194,13 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); + // TODO: this check seems odd, the second part is redundant + // is it meant to be if (leftMetatable && rightMetatable && leftMetatable != rightMetatable) if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) { reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + return errorRecoveryType(booleanType); } if (leftMetatable) @@ -2188,7 +2218,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!state.errors.empty()) { reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } } @@ -2206,7 +2236,7 @@ TypeId TypeChecker::checkRelationalOperation( { reportError( expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } @@ -2214,14 +2244,14 @@ TypeId TypeChecker::checkRelationalOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); - return errorType; + return errorRecoveryType(booleanType); } if (needsMetamethod) { reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", toString(lhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + return errorRecoveryType(booleanType); } return booleanType; @@ -2266,7 +2296,8 @@ TypeId TypeChecker::checkBinaryOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - return errorType; + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } // If we know nothing at all about the lhs type, we can usually say nothing about the result. @@ -2296,18 +2327,33 @@ TypeId TypeChecker::checkBinaryOperation( auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypePackId arguments = addTypePack({lhst, rhst}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); reportErrors(state.errors); + bool hasErrors = !state.errors.empty(); - if (!state.errors.empty()) - return errorType; + if (FFlag::LuauErrorRecoveryType && hasErrors) + { + // If there are unification errors, the return type may still be unknown + // so we loosen the argument types to see if that helps. + TypePackId fallbackArguments = freshTypePack(scope); + TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); + state.log.rollback(); + state.errors.clear(); + state.tryUnify(fallbackFunctionType, actualFunctionType, /*isFunctionCall*/ true); + if (!state.errors.empty()) + state.log.rollback(); + } - return first(retType).value_or(nilType); + TypeId retType = first(retTypePack).value_or(nilType); + if (hasErrors) + retType = errorRecoveryType(retType); + + return retType; }; std::string op = opToMetaTableEntry(expr.op); @@ -2321,7 +2367,8 @@ TypeId TypeChecker::checkBinaryOperation( reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), toString(lhsType).c_str(), toString(rhsType).c_str())}); - return errorType; + + return errorRecoveryType(scope); } switch (expr.op) @@ -2414,11 +2461,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy ExprResult result = checkExpr(scope, *expr.expr, annotationType); ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + reportErrors(errorVec); if (!errorVec.empty()) - { - reportErrors(errorVec); - return {errorType, std::move(result.predicates)}; - } + annotationType = errorRecoveryType(annotationType); return {annotationType, std::move(result.predicates)}; } @@ -2434,7 +2479,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr // any type errors that may arise from it are going to be useless. currentModule->errors.resize(oldSize); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) @@ -2476,7 +2521,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { for (AstExpr* expr : a->expressions) checkExpr(scope, *expr); - return std::pair(errorType, nullptr); + return {errorRecoveryType(scope), nullptr}; } else ice("Unexpected AST node in checkLValue", expr.location); @@ -2488,7 +2533,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope return {*ty, nullptr}; reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); - return {errorType, nullptr}; + return {errorRecoveryType(scope), nullptr}; } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) @@ -2545,24 +2590,25 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { Unifier state = mkUnifier(expr.location); state.tryUnify(indexer->indexType, stringType); + TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { state.log.rollback(); reportError(expr.location, UnknownProperty{lhs, name}); - return std::pair(errorType, nullptr); + retType = errorRecoveryType(retType); } - return std::pair(indexer->indexResultType, nullptr); + return std::pair(retType, nullptr); } else if (lhsTable->state == TableState::Sealed) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } else { reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } } else if (const ClassTypeVar* lhsClass = get(lhs)) @@ -2571,7 +2617,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } return std::pair(prop->type, nullptr); @@ -2585,12 +2631,12 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (isTableIntersection(lhs)) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } } reportError(TypeError{expr.location, NotATable{lhs}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) @@ -2615,7 +2661,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } return std::pair(prop->type, nullptr); } @@ -2626,7 +2672,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { reportError(TypeError{expr.expr->location, NotATable{exprType}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } if (value) @@ -2678,7 +2724,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) if (isNonstrictMode()) return globalScope->bindings[name].typeId; - return errorType; + return errorRecoveryType(scope); } else { @@ -2705,20 +2751,21 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) TableTypeVar* ttv = getMutableTableType(lhsType); if (!ttv) { - if (!isTableIntersection(lhsType)) + if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) + // This error now gets reported when we check the function body. reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); - return errorType; + return errorRecoveryType(scope); } // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check if (lhsType->persistent || ttv->state == TableState::Sealed) - return errorType; + return errorRecoveryType(scope); Name name = indexName->index.value; if (ttv->props.count(name)) - return errorType; + return errorRecoveryType(scope); Property& property = ttv->props[name]; @@ -2728,9 +2775,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) return property.type; } else if (funName.is()) - { - return errorType; - } + return errorRecoveryType(scope); else { ice("Unexpected AST node type", funName.location); @@ -2991,7 +3036,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A else if (expr.is()) { if (!scope->varargPack) - return {addTypePack({addType(ErrorTypeVar())})}; + return {errorRecoveryTypePack(scope)}; return {*scope->varargPack}; } @@ -3095,10 +3140,9 @@ void TypeChecker::checkArgumentList( if (get(tail)) { // Unify remaining parameters so we don't leave any free-types hanging around. - TypeId argTy = errorType; while (paramIter != endIter) { - state.tryUnify(*paramIter, argTy); + state.tryUnify(*paramIter, errorRecoveryType(anyType)); ++paramIter; } return; @@ -3157,7 +3201,7 @@ void TypeChecker::checkArgumentList( { while (argIter != endIter) { - unify(*argIter, errorType, state.location); + unify(*argIter, errorRecoveryType(scope), state.location); ++argIter; } // For this case, we want the error span to cover every errant extra parameter @@ -3246,7 +3290,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A // For each overload // Compare parameter and argument types // Report any errors (also speculate dot vs colon warnings!) - // If there are no errors, return the resulting return type + // Return the resulting return type (even if there are errors) + // If there are no matching overloads, unify with (a...) -> (b...) and return b... TypeId selfType = nullptr; TypeId functionType = nullptr; @@ -3268,8 +3313,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } else { - functionType = errorType; - actualFunctionType = errorType; + functionType = errorRecoveryType(scope); + actualFunctionType = functionType; } } else @@ -3296,7 +3341,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A TypePackId argPack = argListResult.type; if (get(argPack)) - return ExprResult{errorTypePack}; + return {errorRecoveryTypePack(scope)}; TypePack* args = getMutable(argPack); LUAU_ASSERT(args != nullptr); @@ -3314,19 +3359,34 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector errors; // errors encountered for each overload std::vector overloadsThatMatchArgCount; + std::vector overloadsThatDont; for (TypeId fn : overloads) { fn = follow(fn); - if (auto ret = checkCallOverload(scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, errors)) + if (auto ret = checkCallOverload( + scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) return *ret; } if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) return {retPack}; - return reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + + if (FFlag::LuauErrorRecoveryType) + { + const FunctionTypeVar* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return {errorRecoveryTypePack(overload->retType)}; + } + + return {errorRecoveryTypePack(retPack)}; } std::vector> TypeChecker::getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall) @@ -3382,7 +3442,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors) + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { fn = stripFromNilAndReport(fn, expr.func->location); @@ -3394,7 +3454,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (get(fn)) { - return {{addTypePack(TypePackVar{Unifiable::Error{}})}}; + return {{errorRecoveryTypePack(scope)}}; } if (get(fn)) @@ -3427,14 +3487,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope TypeId fn = *ty; fn = instantiate(scope, fn, expr.func->location); - return checkCallOverload( - scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); + return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, + overloadsThatMatchArgCount, overloadsThatDont, errors); } } reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); - unify(retPack, errorTypePack, expr.func->location); - return {{errorTypePack}}; + unify(retPack, errorRecoveryTypePack(scope), expr.func->location); + return {{errorRecoveryTypePack(retPack)}}; } // When this function type has magic functions and did return something, we select that overload instead. @@ -3476,6 +3536,8 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!argMismatch) overloadsThatMatchArgCount.push_back(fn); + else if (FFlag::LuauErrorRecoveryType) + overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); state.log.rollback(); @@ -3586,14 +3648,14 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return false; } -ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, - TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, - const std::vector& overloadsThatMatchArgCount, const std::vector& errors) +void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, + const std::vector& errors) { if (overloads.size() == 1) { reportErrors(std::get<0>(errors.front())); - return {errorTypePack}; + return; } std::vector overloadTypes = overloadsThatMatchArgCount; @@ -3622,7 +3684,7 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr // If only one overload matched, we don't need this error because we provided the previous errors. if (overloadsThatMatchArgCount.size() == 1) - return {errorTypePack}; + return; } std::string s; @@ -3655,7 +3717,7 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr reportError(expr.func->location, ExtraInformation{"Other overloads are also not viable: " + s}); // No viable overload - return {errorTypePack}; + return; } ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, @@ -3740,7 +3802,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { reportError(TypeError{location, UnknownRequire{}}); - return errorType; + return errorRecoveryType(anyType); } return anyType; @@ -3758,14 +3820,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module reportError(TypeError{location, UnknownRequire{reportedModulePath}}); } - return errorType; + return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } std::optional moduleType = first(module->getModuleScope()->returnType); @@ -3773,7 +3835,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } SeenTypes seenTypes; @@ -4078,7 +4140,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!qty.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } if (ty == *qty) @@ -4101,7 +4163,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } } @@ -4116,7 +4178,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(anyType); } } @@ -4131,7 +4193,7 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo else { reportError(location, UnificationTooComplex{}); - return errorTypePack; + return errorRecoveryTypePack(anyTypePack); } } @@ -4279,6 +4341,38 @@ TypeId TypeChecker::freshType(TypeLevel level) return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } +TypeId TypeChecker::singletonType(bool value) +{ + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value}))); +} + +TypeId TypeChecker::singletonType(std::string value) +{ + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(StringSingleton{std::move(value)}))); +} + +TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) +{ + return singletonTypes.errorRecoveryType(); +} + +TypeId TypeChecker::errorRecoveryType(TypeId guess) +{ + return singletonTypes.errorRecoveryType(guess); +} + +TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) +{ + return singletonTypes.errorRecoveryTypePack(); +} + +TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) +{ + return singletonTypes.errorRecoveryTypePack(guess); +} + std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); @@ -4350,7 +4444,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size != 1 || !lit->parameters.data[0].type) { reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(anyType); } ToStringOptions opts; @@ -4368,7 +4462,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (!tf) { if (lit->name == Parser::errorName) - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); std::string typeName; if (lit->hasPrefix) @@ -4380,7 +4474,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) @@ -4390,14 +4484,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) { reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); - return addType(ErrorTypeVar{}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } - else if (FFlag::LuauTypeAliasPacks) + + if (FFlag::LuauTypeAliasPacks) { if (!lit->hasParameterList && !tf->typePackParams.empty()) { reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - return addType(ErrorTypeVar{}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } std::vector typeParams; @@ -4445,7 +4542,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { reportError( TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); - return addType(ErrorTypeVar{}); + + if (FFlag::LuauErrorRecoveryType) + { + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) + typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); + } + else + return errorRecoveryType(scope); } if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) @@ -4464,6 +4571,14 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation for (const auto& param : lit->parameters) typeParams.push_back(resolveType(scope, *param.type)); + if (FFlag::LuauErrorRecoveryType) + { + // If there aren't enough type parameters, pad them out with error recovery types + // (we've already reported the error) + while (typeParams.size() < lit->parameters.size) + typeParams.push_back(errorRecoveryType(scope)); + } + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) { // If the generic parameters and the type arguments are the same, we are about to @@ -4483,8 +4598,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation props[prop.name.value] = {resolveType(scope, *prop.type)}; if (const auto& indexer = table->indexer) - tableIndexer = TableIndexer( - resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); return addType(TableTypeVar{ props, tableIndexer, scope->level, @@ -4536,14 +4650,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return addType(IntersectionTypeVar{types}); } - else if (annotation.is()) + else if (const auto& tsb = annotation.as()) + { + return singletonType(tsb->value); + } + else if (const auto& tss = annotation.as()) { - return addType(ErrorTypeVar{}); + return singletonType(std::string(tss->value.data, tss->value.size)); } + else if (annotation.is()) + return errorRecoveryType(scope); else { reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } } @@ -4584,7 +4704,7 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack else reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); - return addTypePack(TypePackVar{Unifiable::Error{}}); + return errorRecoveryTypePack(scope); } return *genericTy; @@ -4706,12 +4826,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (!maybeInstantiated.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) { reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); - return errorType; + return errorRecoveryType(scope); } TypeId instantiated = *maybeInstantiated; @@ -4773,8 +4893,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, return instantiated; } -std::pair, std::vector> TypeChecker::createGenericTypes( - const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +std::pair, std::vector> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, + const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); @@ -5030,7 +5150,9 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } - return std::nullopt; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; }; std::optional ty = resolveLValue(refis, scope, isaP.lvalue); @@ -5041,7 +5163,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement addRefinement(refis, isaP.lvalue, *result); else { - addRefinement(refis, isaP.lvalue, errorType); + addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); } } @@ -5105,7 +5227,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec addRefinement(refis, typeguardP.lvalue, *result); else { - addRefinement(refis, typeguardP.lvalue, errorType); + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); if (sense) errVec.push_back( TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); @@ -5116,7 +5238,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec auto fail = [&](const TypeErrorData& err) { errVec.push_back(TypeError{typeguardP.location, err}); - addRefinement(refis, typeguardP.lvalue, errorType); + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); }; if (!typeguardP.isTypeof) @@ -5137,28 +5259,6 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); } -void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) -{ - if (!sense) - return; - - static std::vector primitives{ - "string", "number", "boolean", "nil", "thread", - "table", // no op. Requires special handling. - "function", // no op. Requires special handling. - "userdata", // no op. Requires special handling. - }; - - if (auto typeFun = globalScope->lookupType(typeguardP.kind); - typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty())) - { - if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - else if (typeguardP.isTypeof) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - } -} - void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 228b19267..d3221c732 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -286,5 +286,4 @@ TypePack* asMutable(const TypePack* tp) { return const_cast(tp); } - } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index cd447ca23..924bf082a 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) +LUAU_FASTFLAG(LuauErrorRecoveryType) namespace Luau { @@ -293,7 +294,7 @@ bool isGeneric(TypeId ty) bool maybeGeneric(TypeId ty) { ty = follow(ty); - if (auto ftv = get(ty)) + if (get(ty)) return true; else if (auto ttv = get(ty)) { @@ -305,6 +306,18 @@ bool maybeGeneric(TypeId ty) return isGeneric(ty); } +bool maybeSingleton(TypeId ty) +{ + ty = follow(ty); + if (get(ty)) + return true; + if (const UnionTypeVar* utv = get(ty)) + for (TypeId option : utv) + if (get(follow(option))) + return true; + return false; +} + FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) : argTypes(argTypes) , retType(retType) @@ -562,10 +575,8 @@ SingletonTypes::SingletonTypes() , booleanType(&booleanType_) , threadType(&threadType_) , anyType(&anyType_) - , errorType(&errorType_) , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) - , errorTypePack(&errorTypePack_) , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); @@ -634,6 +645,32 @@ TypeId SingletonTypes::makeStringMetatable() return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } +TypeId SingletonTypes::errorRecoveryType() +{ + return &errorType_; +} + +TypePackId SingletonTypes::errorRecoveryTypePack() +{ + return &errorTypePack_; +} + +TypeId SingletonTypes::errorRecoveryType(TypeId guess) +{ + if (FFlag::LuauErrorRecoveryType) + return guess; + else + return &errorType_; +} + +TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) +{ + if (FFlag::LuauErrorRecoveryType) + return guess; + else + return &errorTypePack_; +} + SingletonTypes singletonTypes; void persist(TypeId ty) @@ -1141,6 +1178,11 @@ struct QVarFinder return false; } + bool operator()(const SingletonTypeVar&) const + { + return false; + } + bool operator()(const FunctionTypeVar& ftv) const { if (hasGeneric(ftv.argTypes)) @@ -1412,7 +1454,7 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha else if (strchr(options, data[i])) result.push_back(typechecker.numberType); else - result.push_back(typechecker.errorType); + result.push_back(typechecker.errorRecoveryType(typechecker.anyType)); } } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 82f621b66..e1a52be4e 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,7 +22,9 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) LUAU_FASTFLAG(LuauShareTxnSeen); LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) +LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) +LUAU_FASTFLAG(LuauErrorRecoveryType); namespace Luau { @@ -211,6 +213,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { occursCheck(subTy, superTy); + // The occurrence check might have caused superTy no longer to be a free type if (!get(subTy)) { log(subTy); @@ -221,10 +224,20 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } else if (l && r) { - log(superTy); + if (!FFlag::LuauErrorRecoveryType) + log(superTy); occursCheck(superTy, subTy); r->level = min(r->level, l->level); - *asMutable(superTy) = BoundTypeVar(subTy); + + // The occurrence check might have caused superTy no longer to be a free type + if (!FFlag::LuauErrorRecoveryType) + *asMutable(superTy) = BoundTypeVar(subTy); + else if (!get(superTy)) + { + log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + } + return; } else if (l) @@ -240,6 +253,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool return; } + // The occurrence check might have caused superTy no longer to be a free type if (!get(superTy)) { if (auto rightLevel = getMutableLevel(subTy)) @@ -251,6 +265,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool log(superTy); *asMutable(superTy) = BoundTypeVar(subTy); } + return; } else if (r) @@ -512,6 +527,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (get(superTy) && get(subTy)) tryUnifyPrimitives(superTy, subTy); + else if (FFlag::LuauSingletonTypes && (get(superTy) || get(superTy)) && get(subTy)) + tryUnifySingletons(superTy, subTy); + else if (get(superTy) && get(subTy)) tryUnifyFunctions(superTy, subTy, isFunctionCall); @@ -723,17 +741,18 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal { occursCheck(superTp, subTp); + // The occurrence check might have caused superTp no longer to be a free type if (!get(superTp)) { log(superTp); *asMutable(superTp) = Unifiable::Bound(subTp); } } - else if (get(subTp)) { occursCheck(subTp, superTp); + // The occurrence check might have caused superTp no longer to be a free type if (!get(subTp)) { log(subTp); @@ -874,13 +893,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(singletonTypes.errorType, *superIter); + tryUnify_(singletonTypes.errorRecoveryType(), *superIter); superIter.advance(); } while (subIter.good()) { - tryUnify_(singletonTypes.errorType, *subIter); + tryUnify_(singletonTypes.errorRecoveryType(), *subIter); subIter.advance(); } @@ -906,6 +925,27 @@ void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } +void Unifier::tryUnifySingletons(TypeId superTy, TypeId subTy) +{ + const PrimitiveTypeVar* lp = get(superTy); + const SingletonTypeVar* ls = get(superTy); + const SingletonTypeVar* rs = get(subTy); + + if ((!lp && !ls) || !rs) + ice("passed non singleton/primitive types to unifySingletons"); + + if (ls && *ls == *rs) + return; + + if (lp && lp->type == PrimitiveTypeVar::Boolean && get(rs) && variance == Covariant) + return; + + if (lp && lp->type == PrimitiveTypeVar::String && get(rs) && variance == Covariant) + return; + + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); +} + void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) { FunctionTypeVar* lf = getMutable(superTy); @@ -1023,7 +1063,8 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) } // And vice versa if we're invariant - if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && + lt->state != TableState::Free) { for (const auto& [propName, subProp] : rt->props) { @@ -1038,7 +1079,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) return; } } - + // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. @@ -1634,9 +1675,8 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, singletonTypes.errorType); + tryUnify_(prop.type, singletonTypes.errorRecoveryType()); } else { @@ -1952,7 +1992,7 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { LUAU_ASSERT(get(any)); - const TypeId anyTy = singletonTypes.errorType; + const TypeId anyTy = singletonTypes.errorRecoveryType(); if (FFlag::LuauTypecheckOpts) { @@ -2046,7 +2086,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = ErrorTypeVar{}; + *asMutable(needle) = *singletonTypes.errorRecoveryType(); return; } @@ -2134,7 +2174,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = ErrorTypeVar{}; + *asMutable(needle) = *singletonTypes.errorRecoveryTypePack(); return; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index a2189f7b7..5b4bfa033 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -255,6 +255,14 @@ class AstVisitor { return visit((class AstType*)node); } + virtual bool visit(class AstTypeSingletonBool* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeSingletonString* node) + { + return visit((class AstType*)node); + } virtual bool visit(class AstTypeError* node) { return visit((class AstType*)node); @@ -1158,6 +1166,30 @@ class AstTypeError : public AstType unsigned messageIndex; }; +class AstTypeSingletonBool : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonBool) + + AstTypeSingletonBool(const Location& location, bool value); + + void visit(AstVisitor* visitor) override; + + bool value; +}; + +class AstTypeSingletonString : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonString) + + AstTypeSingletonString(const Location& location, const AstArray& value); + + void visit(AstVisitor* visitor) override; + + const AstArray value; +}; + class AstTypePack : public AstNode { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 39c7d9251..87ebc48b5 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -286,6 +286,7 @@ class Parser // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); + std::optional> parseCharArray(); AstExpr* parseString(); AstLocal* pushLocal(const Binding& binding); diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index 4f7673fab..6ecf06062 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -34,4 +34,6 @@ bool equalsLower(std::string_view lhs, std::string_view rhs); size_t hashRange(const char* data, size_t size); +std::string escape(std::string_view s); +bool isIdentifier(std::string_view s); } // namespace Luau diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index b1209faa1..e709894d9 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -841,6 +841,28 @@ void AstTypeIntersection::visit(AstVisitor* visitor) } } +AstTypeSingletonBool::AstTypeSingletonBool(const Location& location, bool value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonBool::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstTypeSingletonString::AstTypeSingletonString(const Location& location, const AstArray& value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonString::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + AstTypeError::AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex) : AstType(ClassIndex(), location) , types(types) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index a1bad65ef..bc63e37dc 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) +LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau @@ -1278,7 +1279,27 @@ AstType* Parser::parseTableTypeAnnotation() while (lexer.current().type != '}') { - if (lexer.current().type == '[') + if (FFlag::LuauParseSingletonTypes && lexer.current().type == '[' && + (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) + { + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "table field"); + + AstType* type = parseTypeAnnotation(); + + // TODO: since AstName conains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back({AstName(chars->data), begin.location, type}); + else + report(begin.location, "String literal contains malformed escape sequence"); + } + else if (lexer.current().type == '[') { if (indexer) { @@ -1528,6 +1549,32 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) nextLexeme(); return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedTrue) + { + nextLexeme(); + return {allocator.alloc(begin, true)}; + } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedFalse) + { + nextLexeme(); + return {allocator.alloc(begin, false)}; + } + else if (FFlag::LuauParseSingletonTypes && (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString)) + { + if (std::optional> value = parseCharArray()) + { + AstArray svalue = *value; + return {allocator.alloc(begin, svalue)}; + } + else + return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; + } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::BrokenString) + { + Location location = lexer.current().location; + nextLexeme(); + return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, "Malformed string")}; + } else if (lexer.current().type == Lexeme::Name) { std::optional prefix; @@ -2416,7 +2463,7 @@ AstArray Parser::parseTypeParams() return copy(parameters); } -AstExpr* Parser::parseString() +std::optional> Parser::parseCharArray() { LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString); @@ -2426,11 +2473,8 @@ AstExpr* Parser::parseString() { if (!Lexer::fixupQuotedString(scratchData)) { - Location location = lexer.current().location; - nextLexeme(); - - return reportExprError(location, {}, "String literal contains malformed escape sequence"); + return std::nullopt; } } else @@ -2438,12 +2482,18 @@ AstExpr* Parser::parseString() Lexer::fixupMultilineString(scratchData); } - Location start = lexer.current().location; AstArray value = copy(scratchData); - nextLexeme(); + return value; +} - return allocator.alloc(start, value); +AstExpr* Parser::parseString() +{ + Location location = lexer.current().location; + if (std::optional> value = parseCharArray()) + return allocator.alloc(location, *value); + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); } AstLocal* Parser::pushLocal(const Binding& binding) diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 24b2283a0..9c7fed316 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -225,4 +225,62 @@ size_t hashRange(const char* data, size_t size) return hash; } +bool isIdentifier(std::string_view s) +{ + return (s.find_first_not_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_") == std::string::npos); +} + +std::string escape(std::string_view s) +{ + std::string r; + r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting + + for (uint8_t c : s) + { + if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') + r += c; + else + { + r += '\\'; + + switch (c) + { + case '\a': + r += 'a'; + break; + case '\b': + r += 'b'; + break; + case '\f': + r += 'f'; + break; + case '\n': + r += 'n'; + break; + case '\r': + r += 'r'; + break; + case '\t': + r += 't'; + break; + case '\v': + r += 'v'; + break; + case '\'': + r += '\''; + break; + case '\"': + r += '\"'; + break; + case '\\': + r += '\\'; + break; + default: + Luau::formatAppend(r, "%03u", c); + } + } + } + + return r; +} } // namespace Luau diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 9ab10aaf0..ebdd78966 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -236,32 +236,12 @@ int main(int argc, char** argv) Luau::registerBuiltinTypes(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); - int failed = 0; + std::vector files = getSourceFiles(argc, argv); - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; + int failed = 0; - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - // Look for .luau first and if absent, fall back to .lua - if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) - { - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - } - else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - { - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - } - }); - } - else - { - failed += !analyzeFile(frontend, argv[i], format, annotate); - } - } + for (const std::string& path : files) + failed += !analyzeFile(frontend, path.c_str(), format, annotate); if (!configResolver.configErrors.empty()) { diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index 0702b74f1..b3c9557bb 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -142,6 +142,7 @@ static bool traverseDirectoryRec(const std::string& path, const std::function getParentPath(const std::string& path) return ""; } + +static std::string getExtension(const std::string& path) +{ + std::string::size_type dot = path.find_last_of(".\\/"); + + if (dot == std::string::npos || path[dot] != '.') + return ""; + + return path.substr(dot); +} + +std::vector getSourceFiles(int argc, char** argv) +{ + std::vector files; + + for (int i = 1; i < argc; ++i) + { + if (argv[i][0] == '-') + continue; + + if (isDirectory(argv[i])) + { + traverseDirectory(argv[i], [&](const std::string& name) { + std::string ext = getExtension(name); + + if (ext == ".lua" || ext == ".luau") + files.push_back(name); + }); + } + else + { + files.push_back(argv[i]); + } + } + + return files; +} diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index f7fbe8afc..da11f512d 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -4,6 +4,7 @@ #include #include #include +#include std::optional readFile(const std::string& name); @@ -12,3 +13,5 @@ bool traverseDirectory(const std::string& path, const std::function getParentPath(const std::string& path); + +std::vector getSourceFiles(int argc, char** argv); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 5c904cca4..b29cd6f9c 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -20,7 +20,7 @@ enum class CompileFormat { - Default, + Text, Binary }; @@ -33,7 +33,7 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0) return 1; lua_pushnil(L); @@ -80,7 +80,7 @@ static int lua_require(lua_State* L) // now we can compile & run module on the new thread std::string bytecode = Luau::compile(*source); - if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { int status = lua_resume(ML, L, 0); @@ -151,7 +151,7 @@ static std::string runCode(lua_State* L, const std::string& source) { std::string bytecode = Luau::compile(source); - if (luau_load(L, "=stdin", bytecode.data(), bytecode.size()) != 0) + if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) { size_t len; const char* msg = lua_tolstring(L, -1, &len); @@ -370,7 +370,7 @@ static bool runFile(const char* name, lua_State* GL) std::string bytecode = Luau::compile(*source); int status = 0; - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { status = lua_resume(L, NULL, 0); } @@ -379,11 +379,7 @@ static bool runFile(const char* name, lua_State* GL) status = LUA_ERRSYNTAX; } - if (status == 0) - { - return true; - } - else + if (status != 0) { std::string error; @@ -400,8 +396,10 @@ static bool runFile(const char* name, lua_State* GL) error += lua_debugtrace(L); fprintf(stderr, "%s", error.c_str()); - return false; } + + lua_pop(GL, 1); + return status == 0; } static void report(const char* name, const Luau::Location& location, const char* type, const char* message) @@ -431,14 +429,18 @@ static bool compileFile(const char* name, CompileFormat format) try { Luau::BytecodeBuilder bcb; - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); - bcb.setDumpSource(*source); + + if (format == CompileFormat::Text) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); + bcb.setDumpSource(*source); + } Luau::compileOrThrow(bcb, *source); switch (format) { - case CompileFormat::Default: + case CompileFormat::Text: printf("%s", bcb.dumpEverything().c_str()); break; case CompileFormat::Binary: @@ -504,7 +506,7 @@ int main(int argc, char** argv) if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { - CompileFormat format = CompileFormat::Default; + CompileFormat format = CompileFormat::Text; if (strcmp(argv[1], "--compile=binary") == 0) format = CompileFormat::Binary; @@ -514,27 +516,12 @@ int main(int argc, char** argv) _setmode(_fileno(stdout), _O_BINARY); #endif + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 2; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) - failed += !compileFile(name.c_str(), format); - else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !compileFile(name.c_str(), format); - }); - } - else - { - failed += !compileFile(argv[i], format); - } - } + for (const std::string& path : files) + failed += !compileFile(path.c_str(), format); return failed; } @@ -548,33 +535,25 @@ int main(int argc, char** argv) int profile = 0; for (int i = 1; i < argc; ++i) + { + if (argv[i][0] != '-') + continue; + if (strcmp(argv[i], "--profile") == 0) profile = 10000; // default to 10 KHz else if (strncmp(argv[i], "--profile=", 10) == 0) profile = atoi(argv[i] + 10); + } if (profile) profilerStart(L, profile); - int failed = 0; + std::vector files = getSourceFiles(argc, argv); - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; + int failed = 0; - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !runFile(name.c_str(), L); - }); - } - else - { - failed += !runFile(argv[i], L); - } - } + for (const std::string& path : files) + failed += !runFile(path.c_str(), L); if (profile) { diff --git a/CMakeLists.txt b/CMakeLists.txt index 36014a983..9c69521ed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,7 @@ project(Luau LANGUAGES CXX) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) +option(LUAU_WERROR "Warnings as errors" OFF) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) @@ -57,11 +58,18 @@ set(LUAU_OPTIONS) if(MSVC) list(APPEND LUAU_OPTIONS /D_CRT_SECURE_NO_WARNINGS) # We need to use the portable CRT functions. - list(APPEND LUAU_OPTIONS /WX) # Warnings are errors list(APPEND LUAU_OPTIONS /MP) # Distribute single project compilation across multiple cores else() list(APPEND LUAU_OPTIONS -Wall) # All warnings - list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors +endif() + +# Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere +if(LUAU_WERROR) + if(MSVC) + list(APPEND LUAU_OPTIONS /WX) # Warnings are errors + else() + list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors + endif() endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) @@ -79,7 +87,10 @@ if(LUAU_BUILD_CLI) target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) if(UNIX) - target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + find_library(LIBPTHREAD pthread) + if (LIBPTHREAD) + target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + endif() endif() if(NOT EMSCRIPTEN) diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 4f88e602e..65e962dac 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -13,11 +13,9 @@ class AstNameTable; class BytecodeBuilder; class BytecodeEncoder; +// Note: this structure is duplicated in luacode.h, don't forget to change these in sync! struct CompileOptions { - // default bytecode version target; can be used to compile code for older clients - int bytecodeVersion = 1; - // 0 - no optimization // 1 - baseline optimization level that doesn't prevent debuggability // 2 - includes optimizations that harm debuggability such as inlining diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h new file mode 100644 index 000000000..e235a2e77 --- /dev/null +++ b/Compiler/include/luacode.h @@ -0,0 +1,39 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +/* Can be used to reconfigure visibility/exports for public APIs */ +#ifndef LUACODE_API +#define LUACODE_API extern +#endif + +typedef struct lua_CompileOptions lua_CompileOptions; + +struct lua_CompileOptions +{ + // 0 - no optimization + // 1 - baseline optimization level that doesn't prevent debuggability + // 2 - includes optimizations that harm debuggability such as inlining + int optimizationLevel; // default=1 + + // 0 - no debugging support + // 1 - line info & function names only; sufficient for backtraces + // 2 - full debug info with local & upvalue names; necessary for debugger + int debugLevel; // default=1 + + // 0 - no code coverage support + // 1 - statement coverage + // 2 - statement and expression coverage (verbose) + int coverageLevel; // default=0 + + // global builtin to construct vectors; disabled by default + const char* vectorLib; + const char* vectorCtor; + + // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these + const char** mutableGlobals; +}; + +/* compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy */ +LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9712f02f4..5b93c1dc0 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -11,9 +11,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) -LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false) -LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false) -LUAU_FASTFLAGVARIABLE(LuauGenericSpecialGlobals, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) @@ -24,9 +21,6 @@ static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; -// TODO: Remove with LuauGenericSpecialGlobals -static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"}; - CompileError::CompileError(const Location& location, const std::string& message) : location(location) , message(message) @@ -466,7 +460,7 @@ struct Compiler bool shared = false; - if (FFlag::LuauPreloadClosuresUpval) + if (FFlag::LuauPreloadClosures) { // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it @@ -482,18 +476,6 @@ struct Compiler } } } - // Optimization: when closure has no upvalues, instead of allocating it every time we can share closure objects - // (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it is used) - else if (FFlag::LuauPreloadClosures && options.optimizationLevel >= 1 && f->upvals.empty() && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); - return; - } - } if (!shared) bytecode.emitAD(LOP_NEWCLOSURE, target, pid); @@ -3298,8 +3280,7 @@ struct Compiler bool visit(AstStatLocalFunction* node) override { // record local->function association for some optimizations - if (FFlag::LuauPreloadClosuresUpval) - self->locals[node->name].func = node->func; + self->locals[node->name].func = node->func; return true; } @@ -3711,24 +3692,13 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables - if (FFlag::LuauGenericSpecialGlobals) - { - if (AstName name = names.get("_G"); name.value) - compiler.globals[name].writable = true; + if (AstName name = names.get("_G"); name.value) + compiler.globals[name].writable = true; - if (options.mutableGlobals) - for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) - if (AstName name = names.get(*ptr); name.value) - compiler.globals[name].writable = true; - } - else - { - for (const char* global : kSpecialGlobals) - { - if (AstName name = names.get(global); name.value) + if (options.mutableGlobals) + for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) compiler.globals[name].writable = true; - } - } // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written Compiler::AssignmentVisitor assignmentVisitor(&compiler); @@ -3742,7 +3712,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found - if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) + if (options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) { Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed); root->visit(&fenvVisitor); diff --git a/Compiler/src/lcode.cpp b/Compiler/src/lcode.cpp new file mode 100644 index 000000000..ee150b172 --- /dev/null +++ b/Compiler/src/lcode.cpp @@ -0,0 +1,29 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "luacode.h" + +#include "Luau/Compiler.h" + +#include + +char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize) +{ + LUAU_ASSERT(outsize); + + Luau::CompileOptions opts; + + if (options) + { + static_assert(sizeof(lua_CompileOptions) == sizeof(Luau::CompileOptions), "C and C++ interface must match"); + memcpy(static_cast(&opts), options, sizeof(opts)); + } + + std::string result = compile(std::string(source, size), opts); + + char* copy = static_cast(malloc(result.size())); + if (!copy) + return nullptr; + + memcpy(copy, result.data(), result.size()); + *outsize = result.size(); + return copy; +} diff --git a/Makefile b/Makefile index 5d51b3d4e..cab3d43f1 100644 --- a/Makefile +++ b/Makefile @@ -46,14 +46,20 @@ endif OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) # common flags -CXXFLAGS=-g -Wall -Werror +CXXFLAGS=-g -Wall LDFLAGS= -# temporary, for older gcc versions as they treat var in `if (type var = val)` as unused +# some gcc versions treat var in `if (type var = val)` as unused +# some gcc versions treat variables used in constexpr if blocks as unused ifeq ($(findstring g++,$(shell $(CXX) --version)),g++) CXXFLAGS+=-Wno-unused endif +# enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere +ifneq ($(werror),) + CXXFLAGS+=-Werror +endif + # configuration-specific flags ifeq ($(config),release) CXXFLAGS+=-O2 -DNDEBUG diff --git a/Sources.cmake b/Sources.cmake index c30cf77d9..23b931c6b 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -25,9 +25,11 @@ target_sources(Luau.Compiler PRIVATE Compiler/include/Luau/Bytecode.h Compiler/include/Luau/BytecodeBuilder.h Compiler/include/Luau/Compiler.h + Compiler/include/luacode.h Compiler/src/BytecodeBuilder.cpp Compiler/src/Compiler.cpp + Compiler/src/lcode.cpp ) # Luau.Analysis Sources @@ -204,6 +206,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.intersectionTypes.test.cpp tests/TypeInfer.provisional.test.cpp tests/TypeInfer.refinements.test.cpp + tests/TypeInfer.singletons.test.cpp tests/TypeInfer.tables.test.cpp tests/TypeInfer.test.cpp tests/TypeInfer.tryUnify.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index a9d3e875a..1568d191f 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -102,6 +102,8 @@ LUA_API lua_State* lua_newstate(lua_Alloc f, void* ud); LUA_API void lua_close(lua_State* L); LUA_API lua_State* lua_newthread(lua_State* L); LUA_API lua_State* lua_mainthread(lua_State* L); +LUA_API void lua_resetthread(lua_State* L); +LUA_API int lua_isthreadreset(lua_State* L); /* ** basic stack manipulation @@ -162,8 +164,7 @@ LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l); LUA_API void lua_pushstring(lua_State* L, const char* s); LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp); LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...); -LUA_API void lua_pushcfunction( - lua_State* L, lua_CFunction fn, const char* debugname = NULL, int nup = 0, lua_Continuation cont = NULL); +LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont); LUA_API void lua_pushboolean(lua_State* L, int b); LUA_API void lua_pushlightuserdata(lua_State* L, void* p); LUA_API int lua_pushthread(lua_State* L); @@ -178,9 +179,9 @@ LUA_API void lua_rawget(lua_State* L, int idx); LUA_API void lua_rawgeti(lua_State* L, int idx, int n); LUA_API void lua_createtable(lua_State* L, int narr, int nrec); -LUA_API void lua_setreadonly(lua_State* L, int idx, bool value); +LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); LUA_API int lua_getreadonly(lua_State* L, int idx); -LUA_API void lua_setsafeenv(lua_State* L, int idx, bool value); +LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled); LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); @@ -200,7 +201,7 @@ LUA_API int lua_setfenv(lua_State* L, int idx); /* ** `load' and `call' functions (load and run Luau bytecode) */ -LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env = 0); +LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env); LUA_API void lua_call(lua_State* L, int nargs, int nresults); LUA_API int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc); @@ -293,6 +294,8 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) #define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1) +#define lua_pushcfunction(L, fn, debugname) lua_pushcclosurek(L, fn, debugname, 0, NULL) +#define lua_pushcclosure(L, fn, debugname, nup) lua_pushcclosurek(L, fn, debugname, nup, NULL) #define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s)) #define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s)) @@ -319,8 +322,8 @@ LUA_API const char* lua_setlocal(lua_State* L, int level, int n); LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n); LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); -LUA_API void lua_singlestep(lua_State* L, bool singlestep); -LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable); +LUA_API void lua_singlestep(lua_State* L, int enabled); +LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); /* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ LUA_API const char* lua_debugtrace(lua_State* L); @@ -361,6 +364,7 @@ struct lua_Callbacks void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */ void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */ }; +typedef struct lua_Callbacks lua_Callbacks; LUA_API lua_Callbacks* lua_callbacks(lua_State* L); diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 30cffaff8..fa836955c 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -8,11 +8,12 @@ #define luaL_typeerror(L, narg, tname) luaL_typeerrorL(L, narg, tname) #define luaL_argerror(L, narg, extramsg) luaL_argerrorL(L, narg, extramsg) -typedef struct luaL_Reg +struct luaL_Reg { const char* name; lua_CFunction func; -} luaL_Reg; +}; +typedef struct luaL_Reg luaL_Reg; LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l); LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* e); @@ -75,6 +76,7 @@ struct luaL_Buffer struct TString* storage; char buffer[LUA_BUFFERSIZE]; }; +typedef struct luaL_Buffer luaL_Buffer; // when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack // in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 7e742644f..a79ba0d40 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -593,7 +593,7 @@ const char* lua_pushfstringL(lua_State* L, const char* fmt, ...) return ret; } -void lua_pushcfunction(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) +void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) { luaC_checkGC(L); luaC_checkthreadsleep(L); @@ -698,13 +698,13 @@ void lua_createtable(lua_State* L, int narray, int nrec) return; } -void lua_setreadonly(lua_State* L, int objindex, bool value) +void lua_setreadonly(lua_State* L, int objindex, int enabled) { const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); api_check(L, t != hvalue(registry(L))); - t->readonly = value; + t->readonly = bool(enabled); return; } @@ -717,12 +717,12 @@ int lua_getreadonly(lua_State* L, int objindex) return res; } -void lua_setsafeenv(lua_State* L, int objindex, bool value) +void lua_setsafeenv(lua_State* L, int objindex, int enabled) { const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); - t->safeenv = value; + t->safeenv = bool(enabled); return; } diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 87fc16318..61798e2bc 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -436,8 +436,8 @@ static const luaL_Reg base_funcs[] = { static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFunction u) { - lua_pushcfunction(L, u); - lua_pushcfunction(L, f, name, 1); + lua_pushcfunction(L, u, NULL); + lua_pushcclosure(L, f, name, 1); lua_setfield(L, -2, name); } @@ -456,10 +456,10 @@ LUALIB_API int luaopen_base(lua_State* L) auxopen(L, "ipairs", luaB_ipairs, luaB_inext); auxopen(L, "pairs", luaB_pairs, luaB_next); - lua_pushcfunction(L, luaB_pcally, "pcall", 0, luaB_pcallcont); + lua_pushcclosurek(L, luaB_pcally, "pcall", 0, luaB_pcallcont); lua_setfield(L, -2, "pcall"); - lua_pushcfunction(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); + lua_pushcclosurek(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); lua_setfield(L, -2, "xpcall"); return 1; diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index c72fe6748..907c43c42 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -2,6 +2,7 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lualib.h" +#include "lcommon.h" #include "lnumutils.h" LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 9724c0e72..0178fae84 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,6 +5,8 @@ #include "lstate.h" #include "lvm.h" +LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false) + #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -208,8 +210,7 @@ static int cowrap(lua_State* L) { cocreate(L); - lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont); - + lua_pushcclosurek(L, auxwrapy, NULL, 1, auxwrapcont); return 1; } @@ -232,6 +233,34 @@ static int coyieldable(lua_State* L) return 1; } +static int coclose(lua_State* L) +{ + if (!FFlag::LuauCoroutineClose) + luaL_error(L, "coroutine.close is not enabled"); + + lua_State* co = lua_tothread(L, 1); + luaL_argexpected(L, co, 1, "thread"); + + int status = auxstatus(L, co); + if (status != CO_DEAD && status != CO_SUS) + luaL_error(L, "cannot close %s coroutine", statnames[status]); + + if (co->status == LUA_OK || co->status == LUA_YIELD) + { + lua_pushboolean(L, true); + lua_resetthread(co); + return 1; + } + else + { + lua_pushboolean(L, false); + if (lua_gettop(co)) + lua_xmove(co, L, 1); /* move error message */ + lua_resetthread(co); + return 2; + } +} + static const luaL_Reg co_funcs[] = { {"create", cocreate}, {"running", corunning}, @@ -239,6 +268,7 @@ static const luaL_Reg co_funcs[] = { {"wrap", cowrap}, {"yield", coyield}, {"isyieldable", coyieldable}, + {"close", coclose}, {NULL, NULL}, }; @@ -246,7 +276,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); - lua_pushcfunction(L, coresumey, "resume", 0, coresumecont); + lua_pushcclosurek(L, coresumey, "resume", 0, coresumecont); lua_setfield(L, -2, "resume"); return 1; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 1890e6823..d77f84ef9 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -316,7 +316,7 @@ void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable) p->debuginsn[j] = LUAU_INSN_OP(p->code[j]); } - uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->code[i]); + uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->debuginsn[i]); // patch just the opcode byte, leave arguments alone p->code[i] &= ~0xff; @@ -357,17 +357,17 @@ int luaG_getline(Proto* p, int pc) return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc]; } -void lua_singlestep(lua_State* L, bool singlestep) +void lua_singlestep(lua_State* L, int enabled) { - L->singlestep = singlestep; + L->singlestep = bool(enabled); } -void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable) +void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) { const TValue* func = luaA_toobject(L, funcindex); api_check(L, ttisfunction(func) && !clvalue(func)->isC); - luaG_breakpoint(L, clvalue(func)->l.p, line, enable); + luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); } static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 328b47e69..1259d4619 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) +LUAU_FASTFLAG(LuauCoroutineClose) /* ** {====================================================== @@ -300,7 +301,10 @@ static void resume(lua_State* L, void* ud) if (L->status == 0) { // start coroutine - LUAU_ASSERT(L->ci == L->base_ci && firstArg > L->base); + LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base); + if (FFlag::LuauCoroutineClose && firstArg == L->base) + luaG_runerror(L, "cannot resume dead coroutine"); + if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA) return; diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index bf5e738f9..4e40165ab 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -22,7 +22,7 @@ LUALIB_API void luaL_openlibs(lua_State* L) const luaL_Reg* lib = lualibs; for (; lib->func; lib++) { - lua_pushcfunction(L, lib->func); + lua_pushcfunction(L, lib->func, NULL); lua_pushstring(L, lib->name); lua_call(L, 1, 0); } diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 0b2dfb692..24e970635 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -124,6 +124,34 @@ void luaE_freethread(lua_State* L, lua_State* L1) luaM_free(L, L1, sizeof(lua_State), L1->memcat); } +void lua_resetthread(lua_State* L) +{ + /* close upvalues before clearing anything */ + luaF_close(L, L->stack); + /* clear call frames */ + CallInfo* ci = L->base_ci; + ci->func = L->stack; + ci->base = ci->func + 1; + ci->top = ci->base + LUA_MINSTACK; + setnilvalue(ci->func); + L->ci = ci; + luaD_reallocCI(L, BASIC_CI_SIZE); + /* clear thread state */ + L->status = LUA_OK; + L->base = L->ci->base; + L->top = L->ci->base; + L->nCcalls = L->baseCcalls = 0; + /* clear thread stack */ + luaD_reallocstack(L, BASIC_STACK_SIZE); + for (int i = 0; i < L->stacksize; i++) + setnilvalue(L->stack + i); +} + +int lua_isthreadreset(lua_State* L) +{ + return L->ci == L->base_ci && L->base == L->top && L->status == LUA_OK; +} + lua_State* lua_newstate(lua_Alloc f, void* ud) { int i; diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 80a34483a..b576f8093 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -748,7 +748,7 @@ static int gmatch(lua_State* L) luaL_checkstring(L, 2); lua_settop(L, 2); lua_pushinteger(L, 0); - lua_pushcfunction(L, gmatch_aux, NULL, 3); + lua_pushcclosure(L, gmatch_aux, NULL, 3); return 1; } diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp index 6a0262962..378de3d0d 100644 --- a/VM/src/lutf8lib.cpp +++ b/VM/src/lutf8lib.cpp @@ -265,7 +265,7 @@ static int iter_aux(lua_State* L) static int iter_codes(lua_State* L) { luaL_checkstring(L, 1); - lua_pushcfunction(L, iter_aux); + lua_pushcfunction(L, iter_aux, NULL); lua_pushvalue(L, 1); lua_pushinteger(L, 0); return 3; diff --git a/bench/bench.py b/bench/bench.py index b23ca8913..39f219f31 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -25,8 +25,8 @@ import scipy from scipy import stats except ModuleNotFoundError: - print("scipy package is required") - exit(1) + print("Warning: scipy package is not installed, confidence values will not be available") + stats = None scriptdir = os.path.dirname(os.path.realpath(__file__)) defaultVm = 'luau.exe' if os.name == "nt" else './luau' @@ -200,11 +200,14 @@ def finalizeResult(result): result.sampleStdDev = math.sqrt(sumOfSquares / (result.count - 1)) result.unbiasedEst = result.sampleStdDev * result.sampleStdDev - # Two-tailed distribution with 95% conf. - tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) + if stats: + # Two-tailed distribution with 95% conf. + tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) - # Compute confidence interval - result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + # Compute confidence interval + result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + else: + result.sampleConfidenceInterval = result.sampleStdDev else: result.sampleStdDev = 0 result.unbiasedEst = 0 @@ -377,14 +380,19 @@ def analyzeResult(subdir, main, comparisons): tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) degreesOfFreedom = 2 * main.count - 2 - # Two-tailed distribution with 95% conf. - tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) - - noSignificantDifference = tStat < tCritical + if stats: + # Two-tailed distribution with 95% conf. + tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) - pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) + noSignificantDifference = tStat < tCritical + pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) + else: + noSignificantDifference = None + pValue = -1 - if noSignificantDifference: + if noSignificantDifference is None: + verdict = "" + elif noSignificantDifference: verdict = "likely same" elif main.avg < compare.avg: verdict = "likely worse" diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index c85fac7d3..ae2399e49 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -257,7 +257,7 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) lua_State* L = lua_newthread(globalState); luaL_sandboxthread(L); - if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) { interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 8a7798f3a..5a7c86023 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -91,10 +91,6 @@ struct ACFixture : ACFixtureImpl { }; -struct UnfrozenACFixture : ACFixtureImpl -{ -}; - TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") @@ -1919,9 +1915,10 @@ local bar: @1= foo CHECK(!ac.entryMap.count("foo")); } -// CLI-45692: Remove UnfrozenACFixture here -TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis") +TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") { + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end @@ -1950,9 +1947,10 @@ local fp: @1= f CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// CLI-45692: Remove UnfrozenACFixture here -TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords") +TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") { + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + check(R"( local function a(x: boolean) end local function b(x: number?) end @@ -2484,7 +2482,7 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(UnfrozenFixture, "autocomplete_documentation_symbols") { loadDefinition(R"( declare y: { diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 7f03019c3..4ce8d08ae 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -11,8 +11,6 @@ #include LUAU_FASTFLAG(LuauPreloadClosures) -LUAU_FASTFLAG(LuauPreloadClosuresFenv) -LUAU_FASTFLAG(LuauPreloadClosuresUpval) LUAU_FASTFLAG(LuauGenericSpecialGlobals) using namespace Luau; @@ -2797,7 +2795,7 @@ CAPTURE UPVAL U1 RETURN R0 1 )"); - if (FFlag::LuauPreloadClosuresUpval) + if (FFlag::LuauPreloadClosures) { // recursive capture CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( @@ -3479,15 +3477,13 @@ CAPTURE VAL R0 RETURN R1 1 )"); - if (FFlag::LuauPreloadClosuresFenv) - { - // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion - CHECK_EQ("\n" + compileFunction(R"( + // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion + CHECK_EQ("\n" + compileFunction(R"( setfenv(1, {}) return function() print("hi") end )", - 1), - R"( + 1), + R"( GETIMPORT R0 1 LOADN R1 1 NEWTABLE R2 0 0 @@ -3496,23 +3492,21 @@ NEWCLOSURE R0 P0 RETURN R0 1 )"); - // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature - CHECK_EQ("\n" + compileFunction(R"( + // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature + CHECK_EQ("\n" + compileFunction(R"( if false then setfenv(1, {}) end return function() print("hi") end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 RETURN R0 1 )"); - } } TEST_CASE("SharedClosure") { ScopedFastFlag sff1("LuauPreloadClosures", true); - ScopedFastFlag sff2("LuauPreloadClosuresUpval", true); // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( @@ -3671,7 +3665,7 @@ RETURN R0 0 )"); } -TEST_CASE("LuauGenericSpecialGlobals") +TEST_CASE("MutableGlobals") { const char* source = R"( print() @@ -3685,43 +3679,6 @@ shared.print() workspace.print() )"; - { - ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", false}; - - // Check Roblox globals are here - CHECK_EQ("\n" + compileFunction0(source), R"( -GETIMPORT R0 1 -CALL R0 0 0 -GETIMPORT R1 3 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 5 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 7 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 9 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 11 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 13 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 15 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 17 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -RETURN R0 0 -)"); - } - - ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", true}; - // Check Roblox globals are no longer here CHECK_EQ("\n" + compileFunction0(source), R"( GETIMPORT R0 1 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index c1b790b9c..e495a2136 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1,5 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Compiler.h" +#include "lua.h" +#include "lualib.h" +#include "luacode.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/ModuleResolver.h" @@ -10,9 +12,6 @@ #include "doctest.h" #include "ScopedFlags.h" -#include "lua.h" -#include "lualib.h" - #include #include @@ -49,8 +48,12 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + size_t bytecodeSize = 0; + char* bytecode = luau_compile(s, l, nullptr, &bytecodeSize); + int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0); + free(bytecode); + + if (result == 0) return 1; lua_pushnil(L); @@ -179,21 +182,17 @@ static StateRef runConformance( std::string chunkname = "=" + std::string(name); - Luau::CompileOptions copts; + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; // default copts.debugLevel = 2; // for debugger tests copts.vectorCtor = "vector"; // for vector tests - std::string bytecode = Luau::compile(source, copts); - int status = 0; + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize); + int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); + free(bytecode); - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) - { - status = lua_resume(L, nullptr, 0); - } - else - { - status = LUA_ERRSYNTAX; - } + int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX; while (yield && (status == LUA_YIELD || status == LUA_BREAK)) { @@ -332,53 +331,61 @@ TEST_CASE("UTF8") TEST_CASE("Coroutine") { + ScopedFastFlag sff("LuauCoroutineClose", true); + runConformance("coroutine.lua"); } -TEST_CASE("PCall") +static int cxxthrow(lua_State* L) { - runConformance("pcall.lua", [](lua_State* L) { - lua_pushcfunction(L, [](lua_State* L) -> int { #if LUA_USE_LONGJMP - luaL_error(L, "oops"); + luaL_error(L, "oops"); #else - throw std::runtime_error("oops"); + throw std::runtime_error("oops"); #endif - }); +} + +TEST_CASE("PCall") +{ + runConformance("pcall.lua", [](lua_State* L) { + lua_pushcfunction(L, cxxthrow, "cxxthrow"); lua_setglobal(L, "cxxthrow"); - lua_pushcfunction(L, [](lua_State* L) -> int { - lua_State* co = lua_tothread(L, 1); - lua_xmove(L, co, 1); - lua_resumeerror(co, L); - return 0; - }); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + lua_State* co = lua_tothread(L, 1); + lua_xmove(L, co, 1); + lua_resumeerror(co, L); + return 0; + }, + "resumeerror"); lua_setglobal(L, "resumeerror"); }); } TEST_CASE("Pack") { - ScopedFastFlag sff{ "LuauStrPackUBCastFix", true }; - + ScopedFastFlag sff{"LuauStrPackUBCastFix", true}; + runConformance("tpack.lua"); } TEST_CASE("Vector") { runConformance("vector.lua", [](lua_State* L) { - lua_pushcfunction(L, lua_vector); + lua_pushcfunction(L, lua_vector, "vector"); lua_setglobal(L, "vector"); lua_pushvector(L, 0.0f, 0.0f, 0.0f); luaL_newmetatable(L, "vector"); lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index); + lua_pushcfunction(L, lua_vector_index, nullptr); lua_settable(L, -3); lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall); + lua_pushcfunction(L, lua_vector_namecall, nullptr); lua_settable(L, -3); lua_setreadonly(L, -1, true); @@ -513,15 +520,19 @@ TEST_CASE("Debugger") }; // add breakpoint() function - lua_pushcfunction(L, [](lua_State* L) -> int { - int line = luaL_checkinteger(L, 1); - - lua_Debug ar = {}; - lua_getinfo(L, 1, "f", &ar); - - lua_breakpoint(L, -1, line, true); - return 0; - }); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + int line = luaL_checkinteger(L, 1); + bool enabled = lua_isboolean(L, 2) ? lua_toboolean(L, 2) : true; + + lua_Debug ar = {}; + lua_getinfo(L, 1, "f", &ar); + + lua_breakpoint(L, -1, line, enabled); + return 0; + }, + "breakpoint"); lua_setglobal(L, "breakpoint"); }, [](lua_State* L) { @@ -744,7 +755,7 @@ TEST_CASE("ExceptionObject") if (nsize == 0) { free(ptr); - return NULL; + return nullptr; } else if (nsize > 512 * 1024) { diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index e55b5b0c3..e0756badd 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -4,7 +4,8 @@ #include #include -namespace std { +namespace std +{ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) { diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 18f55d2c1..7a3543c7c 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -203,6 +203,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); @@ -212,12 +214,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") bool encounteredFreeType = false; TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTy)); + CHECK_EQ("any", toString(clonedTy)); CHECK(encounteredFreeType); encounteredFreeType = false; TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTp)); + CHECK_EQ("...any", toString(clonedTp)); CHECK(encounteredFreeType); } diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index b076e9ad7..80a258f5b 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -198,7 +198,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table TypeVar tv{ttv}; - ToStringOptions o{/* exhaustive= */ false, /* useLineBreaks= */ false, /* functionTypeArguments= */ false, /* hideTableKind= */ false, 40}; + ToStringOptions o; + o.maxTableLength = 40; CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}"); } @@ -395,7 +396,7 @@ local function target(callback: nil) return callback(4, "hello") end )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(requireType("target")), "(nil) -> (*unknown*)"); + CHECK_EQ("(nil) -> (*unknown*)", toString(requireType("target"))); } TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") @@ -469,4 +470,110 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") CHECK_EQ(toString(tableTy), "Table
"); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") +{ + CheckResult result = check(R"( + local function id(x) return x end + )"); + + TypeId ty = requireType("id"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("id(x: a): a", toStringNamedFunction("id", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") +{ + CheckResult result = check(R"( + local function map(arr, fn) + local t = {} + for i = 0, #arr do + t[i] = fn(arr[i]) + end + return t + end + )"); + + TypeId ty = requireType("map"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); +} + +TEST_CASE("toStringNamedFunction_unit_f") +{ + TypePackVar empty{TypePack{}}; + FunctionTypeVar ftv{&empty, &empty, {}, false}; + CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") +{ + CheckResult result = check(R"( + local function f(x: a, ...): (a, a, b...) + return x, x, ... + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(x: a, ...: any): (a, a, b...)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") +{ + CheckResult result = check(R"( + local function f(): ...number + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): ...number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") +{ + CheckResult result = check(R"( + local function f(): (string, ...number) + return 'a', 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): (string, ...number)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") +{ + CheckResult result = check(R"( + local f: (number, y: number) -> number + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(_: number, y: number): number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") +{ + CheckResult result = check(R"( + local function f(x: T, g: (T) -> U)): () + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + ToStringOptions opts; + opts.hideNamedFunctionTypeParameters = true; + CHECK_EQ("f(x: T, g: (T) -> U): ()", toStringNamedFunction("f", *ftv, opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index c27f8083b..74ce155c2 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_typ CheckResult result = check(R"( type A = number type A = string -- Redefinition of type 'A', previously defined at line 1 - local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" + local foo: string = 1 -- "Type 'number' could not be converted into 'string'" )"); LUAU_REQUIRE_ERROR_COUNT(2, result); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2e4001641..091c2f012 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -381,6 +381,8 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr") TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( type A = B type B = A @@ -390,7 +392,7 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") )"); TypeId fType = requireType("aa"); - const ErrorTypeVar* ftv = get(follow(fType)); + const AnyTypeVar* ftv = get(follow(fType)); REQUIRE(ftv != nullptr); REQUIRE(!result.errors.empty()); } diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index de2f01544..88c2dc85d 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -289,7 +289,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") end )"); // TODO: Should typecheck but currently errors CLI-39916 - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_generic_property") @@ -352,7 +352,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") -- so this assignment should fail local b: boolean = f(true) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") @@ -368,7 +368,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") local y: number = id(37) end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 733fc39b3..fe8e7ff90 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -704,9 +704,10 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); else CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - if (FFlag::LuauQuantifyInPlace2) + if (FFlag::LuauQuantifyInPlace2) CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" else CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp new file mode 100644 index 000000000..5f95efd52 --- /dev/null +++ b/tests/TypeInfer.singletons.test.cpp @@ -0,0 +1,377 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeSingletons"); + +TEST_CASE_FIXTURE(Fixture, "bool_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = true + local b: false = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: "bar" = "bar" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = false + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'false' could not be converted into 'true'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "bar" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "\n" = "\000\r" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '"\000\r"' could not be converted into '"\n"')", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = true + local b: boolean = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: string = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "foo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "bar") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, "foo") + g(false, 37) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, 37) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "foo" + local b : MyEnum = "bar" + local c : MyEnum = "baz" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauExtendedTypeMismatchError", true}, + }; + + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "bang" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'; none of the union options are compatible", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type MyEnum1 = "foo" | "bar" + type MyEnum2 = MyEnum1 | "baz" + local a : MyEnum1 = "foo" + local b : MyEnum2 = a + local c : string = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauExpectedTypesOfProperties", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Dog = { tag = "Dog", howls = true } + local b : Animal = { tag = "Cat", meows = true } + local c : Animal = a + c = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", howls = true } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", meows = true } + a.tag = "Dog" + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type T = { + ["foo"] : number, + ["$$bar"] : string, + baz : boolean + } + local t: T = { + ["foo"] = 37, + ["$$bar"] = "hi", + baz = true + } + local a: number = t.foo + local b: string = t["$$bar"] + local c: boolean = t.baz + t.foo = 5 + t["$$bar"] = "lo" + t.baz = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type T = { + ["$$bar"] : string, + } + local t: T = { + ["$$bar"] = "hi", + } + t["$$bar"] = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type S = "bar" + type T = { + [("foo")] : number, + [S] : string, + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + local x: { ["<>"] : number } + x = { ["\n"] = 5 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + toString(result.errors[0])); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 30d9130a5..99fd8339c 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -362,7 +362,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") CHECK_EQ(2, result.errors.size()); TypeId p = requireType("p"); - CHECK_EQ(*p, *typeChecker.errorType); + CHECK_EQ("*unknown*", toString(p)); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") @@ -480,7 +480,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") @@ -496,7 +496,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") @@ -512,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") @@ -526,7 +526,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(typeChecker.errorType, requireType("a")); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -542,7 +542,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(typeChecker.errorType, requireType("a")); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") @@ -673,7 +673,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index") REQUIRE(nat); CHECK_EQ("string", toString(nat->ty)); - CHECK(get(requireType("t"))); + CHECK_EQ("*unknown*", toString(requireType("t"))); } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") @@ -1456,7 +1456,7 @@ TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") auto hootyType = requireType(bModule, "Hooty"); - CHECK_MESSAGE(get(follow(hootyType)) != nullptr, "Should be an error: " << toString(hootyType)); + CHECK_EQ("*unknown*", toString(hootyType)); } TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") @@ -2032,7 +2032,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); } -TEST_CASE_FIXTURE(Fixture, "error_types_propagate") +TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") { CheckResult result = check(R"( local err = (true).x @@ -2049,10 +2049,10 @@ TEST_CASE_FIXTURE(Fixture, "error_types_propagate") CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("x", err->key); - CHECK(nullptr != get(requireType("c"))); - CHECK(nullptr != get(requireType("d"))); - CHECK(nullptr != get(requireType("e"))); - CHECK(nullptr != get(requireType("f"))); + CHECK_EQ("*unknown*", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("*unknown*", toString(requireType("e"))); + CHECK_EQ("*unknown*", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") @@ -2068,7 +2068,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK(nullptr != get(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -2077,9 +2077,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - TypeId aType = requireType("a"); - - REQUIRE_MESSAGE(nullptr != get(aType), "Not an error: " << toString(aType)); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") @@ -2146,6 +2144,8 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( --!strict local Vec3 = {} @@ -2175,11 +2175,13 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( --!strict local Vec3 = {} @@ -2209,7 +2211,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "compare_numbers") @@ -2901,6 +2903,8 @@ end TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( local x: number = 9999 function x:y(z: number) @@ -2908,7 +2912,7 @@ function x:y(z: number) end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") @@ -2920,7 +2924,7 @@ function x:y(z: number) end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") @@ -3799,7 +3803,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") print(a) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); } @@ -4215,7 +4219,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + CHECK_EQ("*unknown*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4238,7 +4242,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + CHECK_EQ("*unknown*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4394,6 +4398,25 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") CHECK_EQ(toString(*it), "(number) -> number"); } +TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number, number) -> number) + local abc: Overload + local x = abc(true) + local y = abc(true,true) + local z = abc(true,true,true) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("number", toString(requireType("y"))); + // Should this be string|number? + CHECK_EQ("string", toString(requireType("z"))); +} + TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation @@ -4740,4 +4763,20 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") } } +TEST_CASE_FIXTURE(Fixture, "type_error_addition") +{ + CheckResult result = check(R"( +--!strict +local foo = makesandwich() +local bar = foo.nutrition + 100 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + // We should definitely get this error + CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0])); + // We get this error if makesandwich() returns a free type + // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 2d697fc97..9f9a007f1 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -121,9 +121,26 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId bType = requireType("b"); + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + function f(arg: number) return arg end + local a + local b + local c = f(a, b) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_MESSAGE(get(bType), "Should be an error: " << toString(bType)); + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("b"))); + CHECK_EQ("number", toString(requireType("c"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails") @@ -167,15 +184,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") CHECK(state.errors.empty()); } -TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_work") -{ - TypePackId variadicPack = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); - TypePackId errorPack = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{Unifiable::Error{}})}); - - state.tryUnify(variadicPack, errorPack); - REQUIRE_EQ(0, state.errors.size()); -} - TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 9f29b6428..48496b895 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -200,8 +200,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - TypeId r = requireType("r"); - CHECK_MESSAGE(get(r), "Expected error, got " << toString(r)); + CHECK_EQ("*unknown*", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") @@ -283,7 +282,7 @@ local c = b:foo(1, 2) CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "optional_union_follow") +TEST_CASE_FIXTURE(UnfrozenFixture, "optional_union_follow") { CheckResult result = check(R"( local y: number? = 2 diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 753296422..4d9b12953 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -319,4 +319,58 @@ for i=0,30 do assert(#T2 == 1 or T2[#T2] == 42) end +-- test coroutine.close +do + -- ok to close a dead coroutine + local co = coroutine.create(type) + assert(coroutine.resume(co, "testing 'coroutine.close'")) + assert(coroutine.status(co) == "dead") + local st, msg = coroutine.close(co) + assert(st and msg == nil) + -- also ok to close it again + st, msg = coroutine.close(co) + assert(st and msg == nil) + + + -- cannot close the running coroutine + coroutine.wrap(function() + local st, msg = pcall(coroutine.close, coroutine.running()) + assert(not st and string.find(msg, "running")) + end)() + + -- cannot close a "normal" coroutine + coroutine.wrap(function() + local co = coroutine.running() + coroutine.wrap(function () + local st, msg = pcall(coroutine.close, co) + assert(not st and string.find(msg, "normal")) + end)() + end)() + + -- closing a coroutine after an error + local co = coroutine.create(error) + local obj = {42} + local st, msg = coroutine.resume(co, obj) + assert(not st and msg == obj) + st, msg = coroutine.close(co) + assert(not st and msg == obj) + -- after closing, no more errors + st, msg = coroutine.close(co) + assert(st and msg == nil) + + -- closing a coroutine that has outstanding upvalues + local f + local co = coroutine.create(function() + local a = 42 + f = function() return a end + coroutine.yield() + a = 20 + end) + coroutine.resume(co) + assert(f() == 42) + st, msg = coroutine.close(co) + assert(st and msg == nil) + assert(f() == 42) +end + return'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index 5e69fc6bc..6ba99fb9b 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -45,4 +45,13 @@ breakpoint(38) -- break inside corobad() local co = coroutine.create(corobad) assert(coroutine.resume(co) == false) -- this breaks, resumes and dies! +function bar() + print("in bar") +end + +breakpoint(49) +breakpoint(49, false) -- validate that disabling breakpoints works + +bar() + return 'OK' From eed18acec8677b380fdbb7f424d79e1c7dae4273 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 2 Dec 2021 15:20:08 -0800 Subject: [PATCH 08/14] Sync to upstream/release/506 --- Analysis/include/Luau/FileResolver.h | 23 - Analysis/include/Luau/Module.h | 12 +- Analysis/include/Luau/ToDot.h | 31 ++ Analysis/include/Luau/TxnLog.h | 9 - Analysis/include/Luau/TypeInfer.h | 1 - Analysis/include/Luau/TypeVar.h | 15 - Analysis/include/Luau/Unifier.h | 25 +- Analysis/include/Luau/UnifierSharedState.h | 8 + Analysis/include/Luau/VisitTypeVar.h | 4 +- Analysis/src/Autocomplete.cpp | 64 ++- Analysis/src/BuiltinDefinitions.cpp | 6 +- Analysis/src/Error.cpp | 118 ++--- Analysis/src/Frontend.cpp | 27 +- Analysis/src/IostreamHelpers.cpp | 21 +- Analysis/src/JsonEncoder.cpp | 9 +- Analysis/src/Module.cpp | 132 ++--- Analysis/src/Quantify.cpp | 13 +- Analysis/src/RequireTracer.cpp | 195 +------ Analysis/src/Substitution.cpp | 29 +- Analysis/src/ToDot.cpp | 378 ++++++++++++++ Analysis/src/ToString.cpp | 162 +++--- Analysis/src/Transpiler.cpp | 15 +- Analysis/src/TxnLog.cpp | 37 +- Analysis/src/TypeAttach.cpp | 52 +- Analysis/src/TypeInfer.cpp | 325 +++++------- Analysis/src/TypeVar.cpp | 364 -------------- Analysis/src/Unifier.cpp | 306 ++--------- Ast/src/Parser.cpp | 78 +-- CLI/Analyze.cpp | 23 +- CLI/Repl.cpp | 39 -- CLI/Web.cpp | 106 ++++ CMakeLists.txt | 57 ++- Compiler/src/Compiler.cpp | 14 +- Makefile | 6 +- Sources.cmake | 10 + VM/include/lua.h | 11 +- VM/include/luaconf.h | 38 ++ VM/include/lualib.h | 6 + VM/src/lapi.cpp | 65 ++- VM/src/laux.cpp | 95 ++-- VM/src/lbaselib.cpp | 4 +- VM/src/lbitlib.cpp | 2 +- VM/src/lbuiltins.cpp | 12 +- VM/src/lcorolib.cpp | 2 +- VM/src/ldblib.cpp | 2 +- VM/src/ldo.cpp | 71 +-- VM/src/lgc.cpp | 548 +------------------- VM/src/lgcdebug.cpp | 558 +++++++++++++++++++++ VM/src/linit.cpp | 10 +- VM/src/lmathlib.cpp | 5 +- VM/src/lmem.cpp | 8 +- VM/src/lnumutils.h | 8 + VM/src/lobject.cpp | 2 +- VM/src/lobject.h | 23 +- VM/src/loslib.cpp | 2 +- VM/src/lstrlib.cpp | 2 +- VM/src/ltable.cpp | 19 +- VM/src/ltablib.cpp | 2 +- VM/src/lutf8lib.cpp | 2 +- VM/src/lvmexecute.cpp | 30 +- VM/src/lvmload.cpp | 6 +- VM/src/lvmutils.cpp | 18 +- bench/gc/test_LB_mandel.lua | 2 +- bench/tests/shootout/mandel.lua | 2 +- bench/tests/shootout/qt.lua | 10 +- fuzz/proto.cpp | 4 +- tests/AstQuery.test.cpp | 23 + tests/Autocomplete.test.cpp | 40 +- tests/Compiler.test.cpp | 168 +++++++ tests/Conformance.test.cpp | 114 +++-- tests/Fixture.cpp | 38 +- tests/Fixture.h | 5 +- tests/Frontend.test.cpp | 17 - tests/Linter.test.cpp | 16 + tests/Module.test.cpp | 66 ++- tests/Parser.test.cpp | 2 - tests/ToDot.test.cpp | 366 ++++++++++++++ tests/Transpiler.test.cpp | 3 - tests/TypeInfer.aliases.test.cpp | 2 - tests/TypeInfer.generics.test.cpp | 21 +- tests/TypeInfer.provisional.test.cpp | 2 - tests/TypeInfer.tables.test.cpp | 70 +++ tests/TypeInfer.test.cpp | 20 + tests/TypeInfer.typePacks.cpp | 33 -- tests/TypeVar.test.cpp | 45 ++ tests/conformance/apicalls.lua | 8 +- tests/conformance/basic.lua | 5 +- tests/conformance/closure.lua | 8 +- tests/conformance/constructs.lua | 2 +- tests/conformance/coroutine.lua | 2 +- tests/conformance/datetime.lua | 2 +- tests/conformance/debug.lua | 2 +- tests/conformance/errors.lua | 32 +- tests/conformance/gc.lua | 8 +- tests/conformance/nextvar.lua | 6 +- tests/conformance/pcall.lua | 2 +- tests/conformance/utf8.lua | 2 +- tests/conformance/vector.lua | 31 +- tools/svg.py | 9 +- 99 files changed, 2895 insertions(+), 2558 deletions(-) create mode 100644 Analysis/include/Luau/ToDot.h create mode 100644 Analysis/src/ToDot.cpp create mode 100644 CLI/Web.cpp create mode 100644 VM/src/lgcdebug.cpp create mode 100644 tests/ToDot.test.cpp diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 9b74fc12d..0fdcce161 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -51,13 +51,6 @@ struct FileResolver { return std::nullopt; } - - // DEPRECATED APIS - // These are going to be removed with LuauNewRequireTrace2 - virtual bool moduleExists(const ModuleName& name) const = 0; - virtual std::optional fromAstFragment(AstExpr* expr) const = 0; - virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; - virtual std::optional getParentModuleName(const ModuleName& name) const = 0; }; struct NullFileResolver : FileResolver @@ -66,22 +59,6 @@ struct NullFileResolver : FileResolver { return std::nullopt; } - bool moduleExists(const ModuleName& name) const override - { - return false; - } - std::optional fromAstFragment(AstExpr* expr) const override - { - return std::nullopt; - } - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override - { - return lhs; - } - std::optional getParentModuleName(const ModuleName& name) const override - { - return std::nullopt; - } }; } // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index d08448351..2e41674bf 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -78,9 +78,15 @@ void unfreeze(TypeArena& arena); using SeenTypes = std::unordered_map; using SeenTypePacks = std::unordered_map; -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); +struct CloneState +{ + int recursionCount = 0; + bool encounteredFreeType = false; +}; + +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); struct Module { diff --git a/Analysis/include/Luau/ToDot.h b/Analysis/include/Luau/ToDot.h new file mode 100644 index 000000000..ce518d3ae --- /dev/null +++ b/Analysis/include/Luau/ToDot.h @@ -0,0 +1,31 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +namespace Luau +{ +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +struct ToDotOptions +{ + bool showPointers = true; // Show pointer value in the node label + bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes +}; + +std::string toDot(TypeId ty, const ToDotOptions& opts); +std::string toDot(TypePackId tp, const ToDotOptions& opts); + +std::string toDot(TypeId ty); +std::string toDot(TypePackId tp); + +void dumpDot(TypeId ty); +void dumpDot(TypePackId tp); + +} // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 322abd198..29988a3b9 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -25,15 +25,6 @@ struct TxnLog { } - explicit TxnLog(const std::vector>& ownedSeen) - : originalSeenSize(ownedSeen.size()) - , ownedSeen(ownedSeen) - , sharedSeen(nullptr) - { - // This is deprecated! - LUAU_ASSERT(!FFlag::LuauShareTxnSeen); - } - TxnLog(const TxnLog&) = delete; TxnLog& operator=(const TxnLog&) = delete; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 78d642c58..9f553bc14 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -297,7 +297,6 @@ struct TypeChecker private: Unifier mkUnifier(const Location& location); - Unifier mkUnifier(const std::vector>& seen, const Location& location); // These functions are only safe to call when we are in the process of typechecking a module. diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 093ea4319..8c4c2f34f 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -517,21 +517,6 @@ extern SingletonTypes singletonTypes; void persist(TypeId ty); void persist(TypePackId tp); -struct ToDotOptions -{ - bool showPointers = true; // Show pointer value in the node label - bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes -}; - -std::string toDot(TypeId ty, const ToDotOptions& opts); -std::string toDot(TypePackId tp, const ToDotOptions& opts); - -std::string toDot(TypeId ty); -std::string toDot(TypePackId tp); - -void dumpDot(TypeId ty); -void dumpDot(TypePackId tp); - const TypeLevel* getLevel(TypeId ty); TypeLevel* getMutableLevel(TypeId ty); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 503034a1b..4588cdd8c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -19,12 +19,6 @@ enum Variance Invariant }; -struct UnifierCounters -{ - int recursionCount = 0; - int iterationCount = 0; -}; - struct Unifier { TypeArena* const types; @@ -37,20 +31,11 @@ struct Unifier Variance variance = Covariant; CountMismatch::Context ctx = CountMismatch::Arg; - UnifierCounters* counters; - UnifierCounters countersData; - - std::shared_ptr counters_DEPRECATED; - UnifierSharedState& sharedState; Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState); - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, - UnifierCounters* counters = nullptr); Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, - UnifierCounters* counters = nullptr); + Variance variance, UnifierSharedState& sharedState); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId superTy, TypeId subTy); @@ -92,9 +77,9 @@ struct Unifier public: // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); - void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack); + void occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); void occursCheck(TypePackId needle, TypePackId haystack); - void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack); + void occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); Unifier makeChildUnifier(); @@ -106,10 +91,6 @@ struct Unifier [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); - - // Remove with FFlagLuauCacheUnifyTableResults - DenseHashSet tempSeenTy_DEPRECATED{nullptr}; - DenseHashSet tempSeenTp_DEPRECATED{nullptr}; }; } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index f252a004b..88997c41a 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -24,6 +24,12 @@ struct TypeIdPairHash } }; +struct UnifierCounters +{ + int recursionCount = 0; + int iterationCount = 0; +}; + struct UnifierSharedState { UnifierSharedState(InternalErrorReporter* iceHandler) @@ -39,6 +45,8 @@ struct UnifierSharedState DenseHashSet tempSeenTy{nullptr}; DenseHashSet tempSeenTp{nullptr}; + + UnifierCounters counters; }; } // namespace Luau diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index a866655c9..740854b33 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -5,8 +5,6 @@ #include "Luau/TypeVar.h" #include "Luau/TypePack.h" -LUAU_FASTFLAG(LuauCacheUnifyTableResults) - namespace Luau { @@ -101,7 +99,7 @@ void visit(TypeId ty, F& f, Set& seen) // Some visitors want to see bound tables, that's why we visit the original type if (apply(ty, *ttv, seen, f)) { - if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo) + if (ttv->boundTo) { visit(*ttv->boundTo, f, seen); } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 6fc0b3f88..db2d1d0e5 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,9 +12,9 @@ #include #include -LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -203,8 +203,9 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { SeenTypes seenTypes; SeenTypePacks seenTypePacks; - expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr); - actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr); + CloneState cloneState; + expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, cloneState); + actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, cloneState); auto errors = unifier.canUnify(expectedType, actualType); return errors.empty(); @@ -229,28 +230,51 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*it); - if (canUnify(expectedType, ty)) - return TypeCorrectKind::Correct; + if (FFlag::LuauAutocompletePreferToCallFunctions) + { + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty)) + { + auto [retHead, retTail] = flatten(ftv->retType); - // We also want to suggest functions that return compatible result - const FunctionTypeVar* ftv = get(ty); + if (!retHead.empty() && canUnify(expectedType, retHead.front())) + return TypeCorrectKind::CorrectFunctionResult; - if (!ftv) - return TypeCorrectKind::None; + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(expectedType, vtp->ty)) + return TypeCorrectKind::CorrectFunctionResult; + } + } + + return canUnify(expectedType, ty) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + } + else + { + if (canUnify(expectedType, ty)) + return TypeCorrectKind::Correct; - auto [retHead, retTail] = flatten(ftv->retType); + // We also want to suggest functions that return compatible result + const FunctionTypeVar* ftv = get(ty); - if (!retHead.empty()) - return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + if (!ftv) + return TypeCorrectKind::None; - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) - { - if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; - } + auto [retHead, retTail] = flatten(ftv->retType); + + if (!retHead.empty()) + return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; - return TypeCorrectKind::None; + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail))) + return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + } + + return TypeCorrectKind::None; + } } enum class PropIndexType @@ -1413,7 +1437,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatWhile* statWhile = extractStat(finder.ancestry); statWhile && !statWhile->hasDo) return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; - else if (AstStatIf* statIf = node->as(); FFlag::ElseElseIfCompletionImprovements && statIf && !statIf->hasElse) + else if (AstStatIf* statIf = node->as(); statIf && !statIf->hasElse) { return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 62a06a3cd..bac94a2bd 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,8 +8,6 @@ #include -LUAU_FASTFLAG(LuauNewRequireTrace2) - /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. For instance, we do not yet have a way to talk @@ -473,9 +471,7 @@ static std::optional> magicFunctionRequire( if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; - const AstExpr* require = FFlag::LuauNewRequireTrace2 ? &expr : expr.args.data[0]; - - if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr)) return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; return std::nullopt; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index f80d50a7a..8334bd626 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,57 +7,14 @@ #include -LUAU_FASTFLAG(LuauTypeAliasPacks) - -static std::string wrongNumberOfArgsString_DEPRECATED(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) -{ - std::string s = "expects " + std::to_string(expectedCount) + " "; - - if (isTypeArgs) - s += "type "; - - s += "argument"; - if (expectedCount != 1) - s += "s"; - - s += ", but "; - - if (actualCount == 0) - { - s += "none"; - } - else - { - if (actualCount < expectedCount) - s += "only "; - - s += std::to_string(actualCount); - } - - s += (actualCount == 1) ? " is" : " are"; - - s += " specified"; - - return s; -} - static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { - std::string s; + std::string s = "expects "; - if (FFlag::LuauTypeAliasPacks) - { - s = "expects "; + if (isVariadic) + s += "at least "; - if (isVariadic) - s += "at least "; - - s += std::to_string(expectedCount) + " "; - } - else - { - s = "expects " + std::to_string(expectedCount) + " "; - } + s += std::to_string(expectedCount) + " "; if (argPrefix) s += std::string(argPrefix) + " "; @@ -188,10 +145,7 @@ struct ErrorConverter return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - if (FFlag::LuauTypeAliasPacks) - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); - else - return "Argument count mismatch. Function " + wrongNumberOfArgsString_DEPRECATED(e.expected, e.actual); + return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); } LUAU_ASSERT(!"Unknown context"); @@ -232,7 +186,7 @@ struct ErrorConverter std::string operator()(const Luau::IncorrectGenericParameterCount& e) const { std::string name = e.name; - if (!e.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !e.typeFun.typePackParams.empty())) + if (!e.typeFun.typeParams.empty() || !e.typeFun.typePackParams.empty()) { name += "<"; bool first = true; @@ -246,36 +200,25 @@ struct ErrorConverter name += toString(t); } - if (FFlag::LuauTypeAliasPacks) + for (TypePackId t : e.typeFun.typePackParams) { - for (TypePackId t : e.typeFun.typePackParams) - { - if (first) - first = false; - else - name += ", "; - - name += toString(t); - } + if (first) + first = false; + else + name += ", "; + + name += toString(t); } name += ">"; } - if (FFlag::LuauTypeAliasPacks) - { - if (e.typeFun.typeParams.size() != e.actualParameters) - return "Generic type '" + name + "' " + - wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); - + if (e.typeFun.typeParams.size() != e.actualParameters) return "Generic type '" + name + "' " + - wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); - } - else - { - return "Generic type '" + name + "' " + - wrongNumberOfArgsString_DEPRECATED(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); - } + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); + + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); } std::string operator()(const Luau::SyntaxError& e) const @@ -591,11 +534,8 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size()) return false; - if (FFlag::LuauTypeAliasPacks) - { - if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) - return false; - } + if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) + return false; for (size_t i = 0; i < typeFun.typeParams.size(); ++i) { @@ -603,13 +543,10 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC return false; } - if (FFlag::LuauTypeAliasPacks) + for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) { - for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) - { - if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) - return false; - } + if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + return false; } return true; @@ -733,14 +670,14 @@ bool containsParseErrorName(const TypeError& error) } template -void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) +void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState cloneState) { auto clone = [&](auto&& ty) { - return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks); + return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks, cloneState); }; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks); + copyError(e, destArena, seenTypes, seenTypePacks, cloneState); }; if constexpr (false) @@ -864,9 +801,10 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) { SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks); + copyError(e, destArena, seenTypes, seenTypePacks, cloneState); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 1e97705dc..e332f07d4 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,10 +18,7 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) -LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) -LUAU_FASTFLAG(LuauNewRequireTrace2) namespace Luau { @@ -96,10 +93,11 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); + TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; @@ -110,7 +108,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); + TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; @@ -427,15 +425,16 @@ CheckResult Frontend::check(const ModuleName& name) SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; for (const auto& [expr, strictTy] : strictModule->astTypes) - module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); for (const auto& [expr, strictTy] : strictModule->astOriginalCallTypes) - module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); for (const auto& [expr, strictTy] : strictModule->astExpectedTypes) - module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); } stats.timeCheck += getTimestamp() - timestamp; @@ -885,16 +884,13 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module // If we can't find the current module name, that's because we bypassed the frontend's initializer // and called typeChecker.check directly. (This is done by autocompleteSource, for example). // In that case, requires will always fail. - if (FFlag::LuauResolveModuleNameWithoutACurrentModule) - return std::nullopt; - else - throw std::runtime_error("Frontend::resolveModuleName: Unknown currentModuleName '" + currentModuleName + "'"); + return std::nullopt; } const auto& exprs = it->second.exprs; const ModuleInfo* info = exprs.find(&pathExpr); - if (!info || (!FFlag::LuauNewRequireTrace2 && info->name.empty())) + if (!info) return std::nullopt; return *info; @@ -911,10 +907,7 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - if (FFlag::LuauNewRequireTrace2) - return frontend->sourceNodes.count(moduleName) != 0; - else - return frontend->fileResolver->moduleExists(moduleName); + return frontend->sourceNodes.count(moduleName) != 0; } std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 3b2671213..ac46b5a49 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -2,8 +2,6 @@ #include "Luau/IostreamHelpers.h" #include "Luau/ToString.h" -LUAU_FASTFLAG(LuauTypeAliasPacks) - namespace Luau { @@ -94,7 +92,7 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "IncorrectGenericParameterCount { name = " << error.name; - if (!error.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !error.typeFun.typePackParams.empty())) + if (!error.typeFun.typeParams.empty() || !error.typeFun.typePackParams.empty()) { stream << "<"; bool first = true; @@ -108,17 +106,14 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo stream << toString(t); } - if (FFlag::LuauTypeAliasPacks) + for (TypePackId t : error.typeFun.typePackParams) { - for (TypePackId t : error.typeFun.typePackParams) - { - if (first) - first = false; - else - stream << ", "; - - stream << toString(t); - } + if (first) + first = false; + else + stream << ", "; + + stream << toString(t); } stream << ">"; diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 064accba5..c7f623eea 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -5,8 +5,6 @@ #include "Luau/StringUtils.h" #include "Luau/Common.h" -LUAU_FASTFLAG(LuauTypeAliasPacks) - namespace Luau { @@ -615,12 +613,7 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstStatTypeAlias", [&]() { PROP(name); PROP(generics); - - if (FFlag::LuauTypeAliasPacks) - { - PROP(genericPacks); - } - + PROP(genericPacks); PROP(type); PROP(exported); }); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 32a0646ae..b4b6eb425 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,20 +1,20 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Common.h" +#include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" -#include "Luau/Common.h" #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) -LUAU_FASTFLAG(LuauTypeAliasPacks) -LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false) +LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 0) namespace Luau { @@ -120,12 +120,6 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) return allocated; } -using SeenTypes = std::unordered_map; -using SeenTypePacks = std::unordered_map; - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType); - namespace { @@ -138,11 +132,12 @@ struct TypePackCloner; struct TypeCloner { - TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) + TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) : dest(dest) , typeId(typeId) , seenTypes(seenTypes) , seenTypePacks(seenTypePacks) + , cloneState(cloneState) { } @@ -150,8 +145,7 @@ struct TypeCloner TypeId typeId; SeenTypes& seenTypes; SeenTypePacks& seenTypePacks; - - bool* encounteredFreeType = nullptr; + CloneState& cloneState; template void defaultClone(const T& t); @@ -178,13 +172,14 @@ struct TypePackCloner TypePackId typePackId; SeenTypes& seenTypes; SeenTypePacks& seenTypePacks; - bool* encounteredFreeType = nullptr; + CloneState& cloneState; - TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) + TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) : dest(dest) , typePackId(typePackId) , seenTypes(seenTypes) , seenTypePacks(seenTypePacks) + , cloneState(cloneState) { } @@ -197,8 +192,7 @@ struct TypePackCloner void operator()(const Unifiable::Free& t) { - if (encounteredFreeType) - *encounteredFreeType = true; + cloneState.encounteredFreeType = true; TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); TypePackId cloned = dest.addTypePack(*err); @@ -218,13 +212,13 @@ struct TypePackCloner // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. void operator()(const Unifiable::Bound& t) { - TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); seenTypePacks[typePackId] = cloned; } void operator()(const VariadicTypePack& t) { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}}); + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); seenTypePacks[typePackId] = cloned; } @@ -236,10 +230,10 @@ struct TypePackCloner seenTypePacks[typePackId] = cloned; for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); if (t.tail) - destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, encounteredFreeType); + destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); } }; @@ -252,8 +246,7 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { - if (encounteredFreeType) - *encounteredFreeType = true; + cloneState.encounteredFreeType = true; TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); TypeId cloned = dest.addType(*err); seenTypes[typeId] = cloned; @@ -266,7 +259,7 @@ void TypeCloner::operator()(const Unifiable::Generic& t) void TypeCloner::operator()(const Unifiable::Bound& t) { - TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); seenTypes[typeId] = boundTo; } @@ -294,23 +287,23 @@ void TypeCloner::operator()(const FunctionTypeVar& t) seenTypes[typeId] = result; for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, encounteredFreeType)); + ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, encounteredFreeType)); + ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, encounteredFreeType); + ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, encounteredFreeType); + ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); } void TypeCloner::operator()(const TableTypeVar& t) { // If table is now bound to another one, we ignore the content of the original - if (FFlag::LuauCloneBoundTables && t.boundTo) + if (t.boundTo) { - TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); seenTypes[typeId] = boundTo; return; } @@ -326,34 +319,21 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->level = TypeLevel{0, 0}; for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; + ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), - clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)}; - - if (!FFlag::LuauCloneBoundTables) - { - if (t.boundTo) - ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); - } + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), + clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); - } + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); if (ttv->state == TableState::Free) { - if (FFlag::LuauCloneBoundTables || !t.boundTo) - { - if (encounteredFreeType) - *encounteredFreeType = true; - } + cloneState.encounteredFreeType = true; ttv->state = TableState::Sealed; } @@ -369,8 +349,8 @@ void TypeCloner::operator()(const MetatableTypeVar& t) MetatableTypeVar* mtv = getMutable(result); seenTypes[typeId] = result; - mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, encounteredFreeType); - mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType); + mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); + mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); } void TypeCloner::operator()(const ClassTypeVar& t) @@ -381,13 +361,13 @@ void TypeCloner::operator()(const ClassTypeVar& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; + ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.parent) - ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, encounteredFreeType); + ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType); + ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); } void TypeCloner::operator()(const AnyTypeVar& t) @@ -404,7 +384,7 @@ void TypeCloner::operator()(const UnionTypeVar& t) LUAU_ASSERT(option != nullptr); for (TypeId ty : t.options) - option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); } void TypeCloner::operator()(const IntersectionTypeVar& t) @@ -416,7 +396,7 @@ void TypeCloner::operator()(const IntersectionTypeVar& t) LUAU_ASSERT(option != nullptr); for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); } void TypeCloner::operator()(const LazyTypeVar& t) @@ -426,17 +406,18 @@ void TypeCloner::operator()(const LazyTypeVar& t) } // anonymous namespace -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { if (tp->persistent) return tp; + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + TypePackId& res = seenTypePacks[tp]; if (res == nullptr) { - TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks}; - cloner.encounteredFreeType = encounteredFreeType; + TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. } @@ -446,17 +427,18 @@ TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypeP return res; } -TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { if (typeId->persistent) return typeId; + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + TypeId& res = seenTypes[typeId]; if (res == nullptr) { - TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks}; - cloner.encounteredFreeType = encounteredFreeType; + TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. asMutable(res)->documentationSymbol = typeId->documentationSymbol; } @@ -467,19 +449,16 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks return res; } -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { TypeFun result; for (TypeId ty : typeFun.typeParams) - result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId tp : typeFun.typePackParams) - result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, encounteredFreeType)); - } + for (TypePackId tp : typeFun.typePackParams) + result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, cloneState)); - result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType); + result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); return result; } @@ -519,19 +498,18 @@ bool Module::clonePublicInterface() LUAU_ASSERT(interfaceTypes.typeVars.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); - bool encounteredFreeType = false; - - SeenTypePacks seenTypePacks; SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + CloneState cloneState; ScopePtr moduleScope = getModuleScope(); - moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, cloneState); if (moduleScope->varargPack) - moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); for (auto& pair : moduleScope->exportedTypeBindings) - pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, cloneState); for (TypeId ty : moduleScope->returnType) if (get(follow(ty))) @@ -540,7 +518,7 @@ bool Module::clonePublicInterface() freeze(internalTypes); freeze(interfaceTypes); - return encounteredFreeType; + return cloneState.encounteredFreeType; } } // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index bf6d81aa6..c773e208b 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,6 +4,8 @@ #include "Luau/VisitTypeVar.h" +LUAU_FASTFLAGVARIABLE(LuauQuantifyVisitOnce, false) + namespace Luau { @@ -79,7 +81,16 @@ struct Quantifier void quantify(ModulePtr module, TypeId ty, TypeLevel level) { Quantifier q{std::move(module), level}; - visitTypeVar(ty, q); + + if (FFlag::LuauQuantifyVisitOnce) + { + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, q, seen); + } + else + { + visitTypeVar(ty, q); + } FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index b72f53f99..8ed245fbb 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -4,182 +4,9 @@ #include "Luau/Ast.h" #include "Luau/Module.h" -LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) -LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace2, false) - namespace Luau { -namespace -{ - -struct RequireTracerOld : AstVisitor -{ - explicit RequireTracerOld(FileResolver* fileResolver, const ModuleName& currentModuleName) - : fileResolver(fileResolver) - , currentModuleName(currentModuleName) - { - LUAU_ASSERT(!FFlag::LuauNewRequireTrace2); - } - - FileResolver* const fileResolver; - ModuleName currentModuleName; - DenseHashMap locals{nullptr}; - RequireTraceResult result; - - std::optional fromAstFragment(AstExpr* expr) - { - if (auto g = expr->as(); g && g->name == "script") - return currentModuleName; - - return fileResolver->fromAstFragment(expr); - } - - bool visit(AstStatLocal* stat) override - { - for (size_t i = 0; i < stat->vars.size; ++i) - { - AstLocal* local = stat->vars.data[i]; - - if (local->annotation) - { - if (AstTypeTypeof* ann = local->annotation->as()) - ann->expr->visit(this); - } - - if (i < stat->values.size) - { - AstExpr* expr = stat->values.data[i]; - expr->visit(this); - - const ModuleInfo* info = result.exprs.find(expr); - if (info) - locals[local] = info->name; - } - } - - return false; - } - - bool visit(AstExprGlobal* global) override - { - std::optional name = fromAstFragment(global); - if (name) - result.exprs[global] = {*name}; - - return false; - } - - bool visit(AstExprLocal* local) override - { - const ModuleName* name = locals.find(local->local); - if (name) - result.exprs[local] = {*name}; - - return false; - } - - bool visit(AstExprIndexName* indexName) override - { - indexName->expr->visit(this); - - const ModuleInfo* info = result.exprs.find(indexName->expr); - if (info) - { - if (indexName->index == "parent" || indexName->index == "Parent") - { - if (auto parent = fileResolver->getParentModuleName(info->name)) - result.exprs[indexName] = {*parent}; - } - else - result.exprs[indexName] = {fileResolver->concat(info->name, indexName->index.value)}; - } - - return false; - } - - bool visit(AstExprIndexExpr* indexExpr) override - { - indexExpr->expr->visit(this); - - const ModuleInfo* info = result.exprs.find(indexExpr->expr); - const AstExprConstantString* str = indexExpr->index->as(); - if (info && str) - { - result.exprs[indexExpr] = {fileResolver->concat(info->name, std::string_view(str->value.data, str->value.size))}; - } - - indexExpr->index->visit(this); - - return false; - } - - bool visit(AstExprTypeAssertion* expr) override - { - return false; - } - - // If we see game:GetService("StringLiteral") or Game:GetService("StringLiteral"), then rewrite to game.StringLiteral. - // Else traverse arguments and trace requires to them. - bool visit(AstExprCall* call) override - { - for (AstExpr* arg : call->args) - arg->visit(this); - - call->func->visit(this); - - AstExprGlobal* globalName = call->func->as(); - if (globalName && globalName->name == "require" && call->args.size >= 1) - { - if (const ModuleInfo* moduleInfo = result.exprs.find(call->args.data[0])) - result.requires.push_back({moduleInfo->name, call->location}); - - return false; - } - - AstExprIndexName* indexName = call->func->as(); - if (!indexName) - return false; - - std::optional rootName = fromAstFragment(indexName->expr); - - if (FFlag::LuauTraceRequireLookupChild && !rootName) - { - if (const ModuleInfo* moduleInfo = result.exprs.find(indexName->expr)) - rootName = moduleInfo->name; - } - - if (!rootName) - return false; - - bool supportedLookup = indexName->index == "GetService" || - (FFlag::LuauTraceRequireLookupChild && (indexName->index == "FindFirstChild" || indexName->index == "WaitForChild")); - - if (!supportedLookup) - return false; - - if (call->args.size != 1) - return false; - - AstExprConstantString* name = call->args.data[0]->as(); - if (!name) - return false; - - std::string_view v{name->value.data, name->value.size}; - if (v.end() != std::find(v.begin(), v.end(), '/')) - return false; - - result.exprs[call] = {fileResolver->concat(*rootName, v)}; - - // 'WaitForChild' can be used on modules that are not available at the typecheck time, but will be available at runtime - // If we fail to find such module, we will not report an UnknownRequire error - if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") - result.exprs[call].optional = true; - - return false; - } -}; - struct RequireTracer : AstVisitor { RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName) @@ -188,7 +15,6 @@ struct RequireTracer : AstVisitor , currentModuleName(currentModuleName) , locals(nullptr) { - LUAU_ASSERT(FFlag::LuauNewRequireTrace2); } bool visit(AstExprTypeAssertion* expr) override @@ -328,24 +154,13 @@ struct RequireTracer : AstVisitor std::vector requires; }; -} // anonymous namespace - RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - if (FFlag::LuauNewRequireTrace2) - { - RequireTraceResult result; - RequireTracer tracer{result, fileResolver, currentModuleName}; - root->visit(&tracer); - tracer.process(); - return result; - } - else - { - RequireTracerOld tracer{fileResolver, currentModuleName}; - root->visit(&tracer); - return tracer.result; - } + RequireTraceResult result; + RequireTracer tracer{result, fileResolver, currentModuleName}; + root->visit(&tracer); + tracer.process(); + return result; } } // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index ca2b30f52..3d004bee3 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -7,8 +7,6 @@ #include LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) -LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false) -LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -39,11 +37,8 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId itp : ttv->instantiatedTypePackParams) - visitChild(itp); - } + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp); } else if (const MetatableTypeVar* mtv = get(ty)) { @@ -339,10 +334,10 @@ std::optional Substitution::substitute(TypeId ty) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + if (!ignoreChildren(oldTy)) replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + if (!ignoreChildren(oldTp)) replaceChildren(newTp); TypeId newTy = replace(ty); return newTy; @@ -359,10 +354,10 @@ std::optional Substitution::substitute(TypePackId tp) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + if (!ignoreChildren(oldTy)) replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + if (!ignoreChildren(oldTp)) replaceChildren(newTp); TypePackId newTp = replace(tp); return newTp; @@ -393,10 +388,7 @@ TypeId Substitution::clone(TypeId ty) clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - - if (FFlag::LuauTypeAliasPacks) - clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; clone.tags = ttv->tags; result = addType(std::move(clone)); } @@ -505,11 +497,8 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& itp : ttv->instantiatedTypeParams) itp = replace(itp); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId& itp : ttv->instantiatedTypePackParams) - itp = replace(itp); - } + for (TypePackId& itp : ttv->instantiatedTypePackParams) + itp = replace(itp); } else if (MetatableTypeVar* mtv = getMutable(ty)) { diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp new file mode 100644 index 000000000..df9d41881 --- /dev/null +++ b/Analysis/src/ToDot.cpp @@ -0,0 +1,378 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ToDot.h" + +#include "Luau/ToString.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" +#include "Luau/StringUtils.h" + +#include +#include + +namespace Luau +{ + +namespace +{ + +struct StateDot +{ + StateDot(ToDotOptions opts) + : opts(opts) + { + } + + ToDotOptions opts; + + std::unordered_set seenTy; + std::unordered_set seenTp; + std::unordered_map tyToIndex; + std::unordered_map tpToIndex; + int nextIndex = 1; + std::string result; + + bool canDuplicatePrimitive(TypeId ty); + + void visitChildren(TypeId ty, int index); + void visitChildren(TypePackId ty, int index); + + void visitChild(TypeId ty, int parentIndex, const char* linkName = nullptr); + void visitChild(TypePackId tp, int parentIndex, const char* linkName = nullptr); + + void startNode(int index); + void finishNode(); + + void startNodeLabel(); + void finishNodeLabel(TypeId ty); + void finishNodeLabel(TypePackId tp); +}; + +bool StateDot::canDuplicatePrimitive(TypeId ty) +{ + if (get(ty)) + return false; + + return get(ty) || get(ty); +} + +void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) +{ + if (!tyToIndex.count(ty) || (opts.duplicatePrimitives && canDuplicatePrimitive(ty))) + tyToIndex[ty] = nextIndex++; + + int index = tyToIndex[ty]; + + if (parentIndex != 0) + { + if (linkName) + formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, index, linkName); + else + formatAppend(result, "n%d -> n%d;\n", parentIndex, index); + } + + if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) + { + if (get(ty)) + formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); + else if (get(ty)) + formatAppend(result, "n%d [label=\"any\"];\n", index); + } + else + { + visitChildren(ty, index); + } +} + +void StateDot::visitChild(TypePackId tp, int parentIndex, const char* linkName) +{ + if (!tpToIndex.count(tp)) + tpToIndex[tp] = nextIndex++; + + if (parentIndex != 0) + { + if (linkName) + formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, tpToIndex[tp], linkName); + else + formatAppend(result, "n%d -> n%d;\n", parentIndex, tpToIndex[tp]); + } + + visitChildren(tp, tpToIndex[tp]); +} + +void StateDot::startNode(int index) +{ + formatAppend(result, "n%d [", index); +} + +void StateDot::finishNode() +{ + formatAppend(result, "];\n"); +} + +void StateDot::startNodeLabel() +{ + formatAppend(result, "label=\""); +} + +void StateDot::finishNodeLabel(TypeId ty) +{ + if (opts.showPointers) + formatAppend(result, "\n0x%p", ty); + // additional common attributes can be added here as well + result += "\""; +} + +void StateDot::finishNodeLabel(TypePackId tp) +{ + if (opts.showPointers) + formatAppend(result, "\n0x%p", tp); + // additional common attributes can be added here as well + result += "\""; +} + +void StateDot::visitChildren(TypeId ty, int index) +{ + if (seenTy.count(ty)) + return; + seenTy.insert(ty); + + startNode(index); + startNodeLabel(); + + if (const BoundTypeVar* btv = get(ty)) + { + formatAppend(result, "BoundTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(btv->boundTo, index); + } + else if (const FunctionTypeVar* ftv = get(ty)) + { + formatAppend(result, "FunctionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(ftv->argTypes, index, "arg"); + visitChild(ftv->retType, index, "ret"); + } + else if (const TableTypeVar* ttv = get(ty)) + { + if (ttv->name) + formatAppend(result, "TableTypeVar %s", ttv->name->c_str()); + else if (ttv->syntheticName) + formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str()); + else + formatAppend(result, "TableTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + if (ttv->boundTo) + return visitChild(*ttv->boundTo, index, "boundTo"); + + for (const auto& [name, prop] : ttv->props) + visitChild(prop.type, index, name.c_str()); + if (ttv->indexer) + { + visitChild(ttv->indexer->indexType, index, "[index]"); + visitChild(ttv->indexer->indexResultType, index, "[value]"); + } + for (TypeId itp : ttv->instantiatedTypeParams) + visitChild(itp, index, "typeParam"); + + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } + else if (const MetatableTypeVar* mtv = get(ty)) + { + formatAppend(result, "MetatableTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(mtv->table, index, "table"); + visitChild(mtv->metatable, index, "metatable"); + } + else if (const UnionTypeVar* utv = get(ty)) + { + formatAppend(result, "UnionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId opt : utv->options) + visitChild(opt, index); + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + formatAppend(result, "IntersectionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : itv->parts) + visitChild(part, index); + } + else if (const GenericTypeVar* gtv = get(ty)) + { + if (gtv->explicitName) + formatAppend(result, "GenericTypeVar %s", gtv->name.c_str()); + else + formatAppend(result, "GenericTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const FreeTypeVar* ftv = get(ty)) + { + formatAppend(result, "FreeTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "AnyTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "ErrorTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const ClassTypeVar* ctv = get(ty)) + { + formatAppend(result, "ClassTypeVar %s", ctv->name.c_str()); + finishNodeLabel(ty); + finishNode(); + + for (const auto& [name, prop] : ctv->props) + visitChild(prop.type, index, name.c_str()); + + if (ctv->parent) + visitChild(*ctv->parent, index, "[parent]"); + + if (ctv->metatable) + visitChild(*ctv->metatable, index, "[metatable]"); + } + else + { + LUAU_ASSERT(!"unknown type kind"); + finishNodeLabel(ty); + finishNode(); + } +} + +void StateDot::visitChildren(TypePackId tp, int index) +{ + if (seenTp.count(tp)) + return; + seenTp.insert(tp); + + startNode(index); + startNodeLabel(); + + if (const BoundTypePack* btp = get(tp)) + { + formatAppend(result, "BoundTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + visitChild(btp->boundTo, index); + } + else if (const TypePack* tpp = get(tp)) + { + formatAppend(result, "TypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + for (TypeId tv : tpp->head) + visitChild(tv, index); + if (tpp->tail) + visitChild(*tpp->tail, index, "tail"); + } + else if (const VariadicTypePack* vtp = get(tp)) + { + formatAppend(result, "VariadicTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + visitChild(vtp->ty, index); + } + else if (const FreeTypePack* ftp = get(tp)) + { + formatAppend(result, "FreeTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else if (const GenericTypePack* gtp = get(tp)) + { + if (gtp->explicitName) + formatAppend(result, "GenericTypePack %s", gtp->name.c_str()); + else + formatAppend(result, "GenericTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else if (get(tp)) + { + formatAppend(result, "ErrorTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else + { + LUAU_ASSERT(!"unknown type pack kind"); + finishNodeLabel(tp); + finishNode(); + } +} + +} // namespace + +std::string toDot(TypeId ty, const ToDotOptions& opts) +{ + StateDot state{opts}; + + state.result = "digraph graphname {\n"; + state.visitChild(ty, 0); + state.result += "}"; + + return state.result; +} + +std::string toDot(TypePackId tp, const ToDotOptions& opts) +{ + StateDot state{opts}; + + state.result = "digraph graphname {\n"; + state.visitChild(tp, 0); + state.result += "}"; + + return state.result; +} + +std::string toDot(TypeId ty) +{ + return toDot(ty, {}); +} + +std::string toDot(TypePackId tp) +{ + return toDot(tp, {}); +} + +void dumpDot(TypeId ty) +{ + printf("%s\n", toDot(ty).c_str()); +} + +void dumpDot(TypePackId tp) +{ + printf("%s\n", toDot(tp).c_str()); +} + +} // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 735bfa503..6322096c4 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,7 +11,7 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) -LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauFunctionArgumentNameSize, false) namespace Luau { @@ -59,11 +59,8 @@ struct FindCyclicTypes for (TypeId itp : ttv.instantiatedTypeParams) visitTypeVar(itp, *this, seen); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId itp : ttv.instantiatedTypePackParams) - visitTypeVar(itp, *this, seen); - } + for (TypePackId itp : ttv.instantiatedTypePackParams) + visitTypeVar(itp, *this, seen); return exhaustive; } @@ -248,58 +245,45 @@ struct TypeVarStringifier void stringify(const std::vector& types, const std::vector& typePacks) { - if (types.size() == 0 && (!FFlag::LuauTypeAliasPacks || typePacks.size() == 0)) + if (types.size() == 0 && typePacks.size() == 0) return; - if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) + if (types.size() || typePacks.size()) state.emit("<"); - if (FFlag::LuauTypeAliasPacks) - { - bool first = true; - - for (TypeId ty : types) - { - if (!first) - state.emit(", "); - first = false; + bool first = true; - stringify(ty); - } + for (TypeId ty : types) + { + if (!first) + state.emit(", "); + first = false; - bool singleTp = typePacks.size() == 1; + stringify(ty); + } - for (TypePackId tp : typePacks) - { - if (isEmpty(tp) && singleTp) - continue; + bool singleTp = typePacks.size() == 1; - if (!first) - state.emit(", "); - else - first = false; + for (TypePackId tp : typePacks) + { + if (isEmpty(tp) && singleTp) + continue; - if (!singleTp) - state.emit("("); + if (!first) + state.emit(", "); + else + first = false; - stringify(tp); + if (!singleTp) + state.emit("("); - if (!singleTp) - state.emit(")"); - } - } - else - { - for (size_t i = 0; i < types.size(); ++i) - { - if (i > 0) - state.emit(", "); + stringify(tp); - stringify(types[i]); - } + if (!singleTp) + state.emit(")"); } - if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) + if (types.size() || typePacks.size()) state.emit(">"); } @@ -767,12 +751,23 @@ struct TypePackStringifier else state.emit(", "); - LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); - - if (!elemNames.empty() && elemNames[elemIndex]) + if (FFlag::LuauFunctionArgumentNameSize) { - state.emit(elemNames[elemIndex]->name); - state.emit(": "); + if (elemIndex < elemNames.size() && elemNames[elemIndex]) + { + state.emit(elemNames[elemIndex]->name); + state.emit(": "); + } + } + else + { + LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); + + if (!elemNames.empty() && elemNames[elemIndex]) + { + state.emit(elemNames[elemIndex]->name); + state.emit(": "); + } } elemIndex++; @@ -929,38 +924,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (FFlag::LuauTypeAliasPacks) - { - tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); - } - else - { - if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) - return result; - - result.name += "<"; - - bool first = true; - for (TypeId ty : ttv->instantiatedTypeParams) - { - if (!first) - result.name += ", "; - else - first = false; - - tvs.stringify(ty); - } - - if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) - { - result.truncated = true; - result.name += "... "; - } - else - { - result.name += ">"; - } - } + tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); return result; } @@ -1161,17 +1125,37 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV s += ", "; first = false; - // argNames is guaranteed to be equal to argTypes iff argNames is not empty. - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (!ftv.argNames.empty()) - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; - s += toString_(*argPackIter); + if (FFlag::LuauFunctionArgumentNameSize) + { + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) + { + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + ++argNameIter; + } + else + { + s += "_: "; + } + } + else + { + // argNames is guaranteed to be equal to argTypes iff argNames is not empty. + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (!ftv.argNames.empty()) + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + } + s += toString_(*argPackIter); ++argPackIter; - if (!ftv.argNames.empty()) + + if (!FFlag::LuauFunctionArgumentNameSize) { - LUAU_ASSERT(argNameIter != ftv.argNames.end()); - ++argNameIter; + if (!ftv.argNames.empty()) + { + LUAU_ASSERT(argNameIter != ftv.argNames.end()); + ++argNameIter; + } } } diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 6627fbe36..8e13ea5be 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,8 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauTypeAliasPacks) - namespace { bool isIdentifierStartChar(char c) @@ -787,7 +785,7 @@ struct Printer writer.keyword("type"); writer.identifier(a->name.value); - if (a->generics.size > 0 || (FFlag::LuauTypeAliasPacks && a->genericPacks.size > 0)) + if (a->generics.size > 0 || a->genericPacks.size > 0) { writer.symbol("<"); CommaSeparatorInserter comma(writer); @@ -798,14 +796,11 @@ struct Printer writer.identifier(o.value); } - if (FFlag::LuauTypeAliasPacks) + for (auto o : a->genericPacks) { - for (auto o : a->genericPacks) - { - comma(); - writer.identifier(o.value); - writer.symbol("..."); - } + comma(); + writer.identifier(o.value); + writer.symbol("..."); } writer.symbol(">"); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 383bb050d..f6a61581e 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false) - namespace Luau { @@ -36,11 +34,8 @@ void TxnLog::rollback() for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) std::swap(it->first->boundTo, it->second); - if (FFlag::LuauShareTxnSeen) - { - LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); - sharedSeen->resize(originalSeenSize); - } + LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); + sharedSeen->resize(originalSeenSize); } void TxnLog::concat(TxnLog rhs) @@ -53,45 +48,25 @@ void TxnLog::concat(TxnLog rhs) tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); rhs.tableChanges.clear(); - - if (!FFlag::LuauShareTxnSeen) - { - ownedSeen.swap(rhs.ownedSeen); - rhs.ownedSeen.clear(); - } } bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (FFlag::LuauShareTxnSeen) - return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); - else - return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair)); + return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); } void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (FFlag::LuauShareTxnSeen) - sharedSeen->push_back(sortedPair); - else - ownedSeen.push_back(sortedPair); + sharedSeen->push_back(sortedPair); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (FFlag::LuauShareTxnSeen) - { - LUAU_ASSERT(sortedPair == sharedSeen->back()); - sharedSeen->pop_back(); - } - else - { - LUAU_ASSERT(sortedPair == ownedSeen.back()); - ownedSeen.pop_back(); - } + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); } } // namespace Luau diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index af6d2543d..9e61c7924 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAG(LuauTypeAliasPacks) - static char* allocateString(Luau::Allocator& allocator, std::string_view contents) { char* result = (char*)allocator.allocate(contents.size() + 1); @@ -131,12 +129,9 @@ class TypeRehydrationVisitor parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}}; } - if (FFlag::LuauTypeAliasPacks) + for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) { - for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) - { - parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; - } + parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; } return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters); @@ -250,20 +245,7 @@ class TypeRehydrationVisitor AstTypePack* argTailAnnotation = nullptr; if (argTail) - { - if (FFlag::LuauTypeAliasPacks) - { - argTailAnnotation = rehydrate(*argTail); - } - else - { - TypePackId tail = *argTail; - if (const VariadicTypePack* vtp = get(tail)) - { - argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); - } - } - } + argTailAnnotation = rehydrate(*argTail); AstArray> argNames; argNames.size = ftv.argNames.size(); @@ -292,20 +274,7 @@ class TypeRehydrationVisitor AstTypePack* retTailAnnotation = nullptr; if (retTail) - { - if (FFlag::LuauTypeAliasPacks) - { - retTailAnnotation = rehydrate(*retTail); - } - else - { - TypePackId tail = *retTail; - if (const VariadicTypePack* vtp = get(tail)) - { - retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); - } - } - } + retTailAnnotation = rehydrate(*retTail); return allocator->alloc( Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); @@ -518,18 +487,7 @@ class TypeAttacher : public AstVisitor const auto& [v, tail] = flatten(ret); if (tail) - { - if (FFlag::LuauTypeAliasPacks) - { - variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); - } - else - { - TypePackId tailPack = *tail; - if (const VariadicTypePack* vtp = get(tailPack)) - variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); - } - } + variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation}; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b2ae94c72..617bf482c 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -23,22 +23,20 @@ LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) -LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) -LUAU_FASTFLAG(LuauNewRequireTrace2) -LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) +LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) +LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) +LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) namespace Luau { @@ -562,12 +560,6 @@ ErrorVec TypeChecker::canUnify(TypePackId left, TypePackId right, const Location return canUnify_(left, right, location); } -ErrorVec TypeChecker::canUnify(const std::vector>& seen, TypeId superTy, TypeId subTy, const Location& location) -{ - Unifier state = mkUnifier(seen, location); - return state.canUnify(superTy, subTy); -} - template ErrorVec TypeChecker::canUnify_(Id superTy, Id subTy, const Location& location) { @@ -1152,61 +1144,20 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias Location location = scope->typeAliasLocations[name]; reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); - if (FFlag::LuauTypeAliasPacks) - bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; - else - bindingsMap[name] = TypeFun{binding->typeParams, errorRecoveryType(anyType)}; + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; } else { ScopePtr aliasScope = FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); - if (FFlag::LuauTypeAliasPacks) - { - auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); + auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); - TypeId ty = freshType(aliasScope); - FreeTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; - } - else - { - std::vector generics; - for (AstName generic : typealias.generics) - { - Name n = generic.value; - - // These generics are the only thing that will ever be added to aliasScope, so we can be certain that - // a collision can only occur when two generic typevars have the same name. - if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) - { - // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. - reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); - } - - TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypeId& cached = scope->typeAliasTypeParameters[n]; - if (!cached) - cached = addType(GenericTypeVar{aliasScope->level, n}); - g = cached; - } - else - g = addType(GenericTypeVar{aliasScope->level, n}); - generics.push_back(g); - aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; - } - - TypeId ty = freshType(aliasScope); - FreeTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), ty}; - } + TypeId ty = freshType(aliasScope); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; } } else @@ -1223,14 +1174,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; } - if (FFlag::LuauTypeAliasPacks) + for (TypePackId tp : binding->typePackParams) { - for (TypePackId tp : binding->typePackParams) - { - auto generic = get(tp); - LUAU_ASSERT(generic); - aliasScope->privateTypePackBindings[generic->name] = tp; - } + auto generic = get(tp); + LUAU_ASSERT(generic); + aliasScope->privateTypePackBindings[generic->name] = tp; } TypeId ty = resolveType(aliasScope, *typealias.type); @@ -1241,19 +1189,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // Copy can be skipped if this is an identical alias if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || - (FFlag::LuauTypeAliasPacks && ttv->instantiatedTypePackParams != binding->typePackParams)) + ttv->instantiatedTypePackParams != binding->typePackParams) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; - clone.name = name; clone.instantiatedTypeParams = binding->typeParams; - - if (FFlag::LuauTypeAliasPacks) - clone.instantiatedTypePackParams = binding->typePackParams; + clone.instantiatedTypePackParams = binding->typePackParams; ty = addType(std::move(clone)); } @@ -1262,9 +1207,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { ttv->name = name; ttv->instantiatedTypeParams = binding->typeParams; - - if (FFlag::LuauTypeAliasPacks) - ttv->instantiatedTypePackParams = binding->typePackParams; + ttv->instantiatedTypePackParams = binding->typePackParams; } } else if (auto mtv = getMutable(follow(ty))) @@ -1289,7 +1232,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0)); + LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); superTy = lookupType->type; if (!get(follow(*superTy))) @@ -1851,6 +1794,24 @@ TypeId TypeChecker::checkExprTable( if (isNonstrictMode() && !getTableType(exprType) && !get(exprType)) exprType = anyType; + if (FFlag::LuauPropertiesGetExpectedType && expectedTable) + { + auto it = expectedTable->props.find(key->value.data); + if (it != expectedTable->props.end()) + { + Property expectedProp = it->second; + ErrorVec errors = tryUnify(expectedProp.type, exprType, k->location); + if (errors.empty()) + exprType = expectedProp.type; + } + else if (expectedTable->indexer && isString(expectedTable->indexer->indexType)) + { + ErrorVec errors = tryUnify(expectedTable->indexer->indexResultType, exprType, k->location); + if (errors.empty()) + exprType = expectedTable->indexer->indexResultType; + } + } + props[key->value.data] = {exprType, /* deprecated */ false, {}, k->location}; } else @@ -3744,17 +3705,29 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L for (size_t i = 0; i < exprs.size; ++i) { AstExpr* expr = exprs.data[i]; + std::optional expectedType = i < expectedTypes.size() ? expectedTypes[i] : std::nullopt; if (i == lastIndex && (expr->is() || expr->is())) { auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); + if (FFlag::LuauTailArgumentTypeInfo) + { + if (std::optional firstTy = first(typePack)) + { + if (!currentModule->astTypes.find(expr)) + currentModule->astTypes[expr] = follow(*firstTy); + } + + if (expectedType) + currentModule->astExpectedTypes[expr] = *expectedType; + } + tp->tail = typePack; } else { - std::optional expectedType = i < expectedTypes.size() ? expectedTypes[i] : std::nullopt; auto [type, exprPredicates] = checkExpr(scope, *expr, expectedType); insert(exprPredicates); @@ -3797,7 +3770,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); - if (FFlag::LuauNewRequireTrace2 && moduleInfo.name.empty()) + if (moduleInfo.name.empty()) { if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { @@ -3814,7 +3787,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // There are two reasons why we might fail to find the module: // either the file does not exist or there's a cycle. If there's a cycle // we will already have reported the error. - if (!resolver->moduleExists(moduleInfo.name) && (FFlag::LuauTraceRequireLookupChild ? !moduleInfo.optional : true)) + if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) { std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(TypeError{location, UnknownRequire{reportedModulePath}}); @@ -3830,7 +3803,12 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return errorRecoveryType(scope); } - std::optional moduleType = first(module->getModuleScope()->returnType); + TypePackId modulePack = module->getModuleScope()->returnType; + + if (FFlag::LuauModuleRequireErrorPack && get(modulePack)) + return errorRecoveryType(scope); + + std::optional moduleType = first(modulePack); if (!moduleType) { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); @@ -3840,7 +3818,8 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module SeenTypes seenTypes; SeenTypePacks seenTypePacks; - return clone(*moduleType, currentModule->internalTypes, seenTypes, seenTypePacks); + CloneState cloneState; + return clone(*moduleType, currentModule->internalTypes, seenTypes, seenTypePacks, cloneState); } void TypeChecker::tablify(TypeId type) @@ -4326,11 +4305,6 @@ Unifier TypeChecker::mkUnifier(const Location& location) return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState}; } -Unifier TypeChecker::mkUnifier(const std::vector>& seen, const Location& location) -{ - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, unifierState}; -} - TypeId TypeChecker::freshType(const ScopePtr& scope) { return freshType(scope->level); @@ -4477,117 +4451,82 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return errorRecoveryType(scope); } - if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) - { + if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) return tf->type; - } - else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) + + if (!lit->hasParameterList && !tf->typePackParams.empty()) { - reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); if (!FFlag::LuauErrorRecoveryType) return errorRecoveryType(scope); } - if (FFlag::LuauTypeAliasPacks) - { - if (!lit->hasParameterList && !tf->typePackParams.empty()) - { - reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); - } - - std::vector typeParams; - std::vector extraTypes; - std::vector typePackParams; + std::vector typeParams; + std::vector extraTypes; + std::vector typePackParams; - for (size_t i = 0; i < lit->parameters.size; ++i) + for (size_t i = 0; i < lit->parameters.size; ++i) + { + if (AstType* type = lit->parameters.data[i].type) { - if (AstType* type = lit->parameters.data[i].type) - { - TypeId ty = resolveType(scope, *type); + TypeId ty = resolveType(scope, *type); - if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) - typeParams.push_back(ty); - else if (typePackParams.empty()) - extraTypes.push_back(ty); - else - reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); - } - else if (AstTypePack* typePack = lit->parameters.data[i].typePack) - { - TypePackId tp = resolveTypePack(scope, *typePack); - - // If we have collected an implicit type pack, materialize it - if (typePackParams.empty() && !extraTypes.empty()) - typePackParams.push_back(addTypePack(extraTypes)); - - // If we need more regular types, we can use single element type packs to fill those in - if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) - typeParams.push_back(*first(tp)); - else - typePackParams.push_back(tp); - } + if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) + typeParams.push_back(ty); + else if (typePackParams.empty()) + extraTypes.push_back(ty); + else + reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); } - - // If we still haven't meterialized an implicit type pack, do it now - if (typePackParams.empty() && !extraTypes.empty()) - typePackParams.push_back(addTypePack(extraTypes)); - - // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack - if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) - typePackParams.push_back(addTypePack({})); - - if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) + else if (AstTypePack* typePack = lit->parameters.data[i].typePack) { - reportError( - TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + TypePackId tp = resolveTypePack(scope, *typePack); - if (FFlag::LuauErrorRecoveryType) - { - // Pad the types out with error recovery types - while (typeParams.size() < tf->typeParams.size()) - typeParams.push_back(errorRecoveryType(scope)); - while (typePackParams.size() < tf->typePackParams.size()) - typePackParams.push_back(errorRecoveryTypePack(scope)); - } + // If we have collected an implicit type pack, materialize it + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we need more regular types, we can use single element type packs to fill those in + if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) + typeParams.push_back(*first(tp)); else - return errorRecoveryType(scope); + typePackParams.push_back(tp); } + } - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) - { - // If the generic parameters and the type arguments are the same, we are about to - // perform an identity substitution, which we can just short-circuit. - return tf->type; - } + // If we still haven't meterialized an implicit type pack, do it now + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); - return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); - } - else - { - std::vector typeParams; + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack + if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) + typePackParams.push_back(addTypePack({})); - for (const auto& param : lit->parameters) - typeParams.push_back(resolveType(scope, *param.type)); + if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) + { + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); if (FFlag::LuauErrorRecoveryType) { - // If there aren't enough type parameters, pad them out with error recovery types - // (we've already reported the error) - while (typeParams.size() < lit->parameters.size) + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); } + else + return errorRecoveryType(scope); + } - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) - { - // If the generic parameters and the type arguments are the same, we are about to - // perform an identity substitution, which we can just short-circuit. - return tf->type; - } - - return instantiateTypeFun(scope, *tf, typeParams, {}, annotation.location); + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + { + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + return tf->type; } + + return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); } else if (const auto& table = annotation.as()) { @@ -4757,7 +4696,7 @@ bool ApplyTypeFunction::isDirty(TypePackId tp) bool ApplyTypeFunction::ignoreChildren(TypeId ty) { - if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(ty)) + if (get(ty)) return true; else return false; @@ -4765,7 +4704,7 @@ bool ApplyTypeFunction::ignoreChildren(TypeId ty) bool ApplyTypeFunction::ignoreChildren(TypePackId tp) { - if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(tp)) + if (get(tp)) return true; else return false; @@ -4788,36 +4727,26 @@ TypePackId ApplyTypeFunction::clean(TypePackId tp) // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - if (FFlag::LuauTypeAliasPacks) - { - TypePackId& arg = typePackArguments[tp]; - if (arg) - return arg; - else - return addTypePack(FreeTypePack{level}); - } + TypePackId& arg = typePackArguments[tp]; + if (arg) + return arg; else - { return addTypePack(FreeTypePack{level}); - } } TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const std::vector& typePackParams, const Location& location) { - if (tf.typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf.typePackParams.empty())) + if (tf.typeParams.empty() && tf.typePackParams.empty()) return tf.type; applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; - if (FFlag::LuauTypeAliasPacks) - { - applyTypeFunction.typePackArguments.clear(); - for (size_t i = 0; i < tf.typePackParams.size(); ++i) - applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; - } + applyTypeFunction.typePackArguments.clear(); + for (size_t i = 0; i < tf.typePackParams.size(); ++i) + applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; @@ -4866,9 +4795,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (ttv) { ttv->instantiatedTypeParams = typeParams; - - if (FFlag::LuauTypeAliasPacks) - ttv->instantiatedTypePackParams = typePackParams; + ttv->instantiatedTypePackParams = typePackParams; } } else @@ -4884,9 +4811,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } ttv->instantiatedTypeParams = typeParams; - - if (FFlag::LuauTypeAliasPacks) - ttv->instantiatedTypePackParams = typePackParams; + ttv->instantiatedTypePackParams = typePackParams; } } @@ -4914,7 +4839,7 @@ std::pair, std::vector> TypeChecker::createGener } TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + if (FFlag::LuauRecursiveTypeParameterRestriction) { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) @@ -4944,7 +4869,7 @@ std::pair, std::vector> TypeChecker::createGener } TypePackId g; - if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + if (FFlag::LuauRecursiveTypeParameterRestriction) { TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) @@ -5245,7 +5170,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); auto typeFun = globalScope->lookupType(typeguardP.kind); - if (!typeFun || !typeFun->typeParams.empty() || (FFlag::LuauTypeAliasPacks && !typeFun->typePackParams.empty())) + if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); TypeId type = follow(typeFun->type); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 924bf082a..62715af53 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,7 +19,6 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) @@ -739,369 +738,6 @@ void persist(TypePackId tp) } } -namespace -{ - -struct StateDot -{ - StateDot(ToDotOptions opts) - : opts(opts) - { - } - - ToDotOptions opts; - - std::unordered_set seenTy; - std::unordered_set seenTp; - std::unordered_map tyToIndex; - std::unordered_map tpToIndex; - int nextIndex = 1; - std::string result; - - bool canDuplicatePrimitive(TypeId ty); - - void visitChildren(TypeId ty, int index); - void visitChildren(TypePackId ty, int index); - - void visitChild(TypeId ty, int parentIndex, const char* linkName = nullptr); - void visitChild(TypePackId tp, int parentIndex, const char* linkName = nullptr); - - void startNode(int index); - void finishNode(); - - void startNodeLabel(); - void finishNodeLabel(TypeId ty); - void finishNodeLabel(TypePackId tp); -}; - -bool StateDot::canDuplicatePrimitive(TypeId ty) -{ - if (get(ty)) - return false; - - return get(ty) || get(ty); -} - -void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) -{ - if (!tyToIndex.count(ty) || (opts.duplicatePrimitives && canDuplicatePrimitive(ty))) - tyToIndex[ty] = nextIndex++; - - int index = tyToIndex[ty]; - - if (parentIndex != 0) - { - if (linkName) - formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, index, linkName); - else - formatAppend(result, "n%d -> n%d;\n", parentIndex, index); - } - - if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) - { - if (get(ty)) - formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); - else if (get(ty)) - formatAppend(result, "n%d [label=\"any\"];\n", index); - } - else - { - visitChildren(ty, index); - } -} - -void StateDot::visitChild(TypePackId tp, int parentIndex, const char* linkName) -{ - if (!tpToIndex.count(tp)) - tpToIndex[tp] = nextIndex++; - - if (linkName) - formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, tpToIndex[tp], linkName); - else - formatAppend(result, "n%d -> n%d;\n", parentIndex, tpToIndex[tp]); - - visitChildren(tp, tpToIndex[tp]); -} - -void StateDot::startNode(int index) -{ - formatAppend(result, "n%d [", index); -} - -void StateDot::finishNode() -{ - formatAppend(result, "];\n"); -} - -void StateDot::startNodeLabel() -{ - formatAppend(result, "label=\""); -} - -void StateDot::finishNodeLabel(TypeId ty) -{ - if (opts.showPointers) - formatAppend(result, "\n0x%p", ty); - // additional common attributes can be added here as well - result += "\""; -} - -void StateDot::finishNodeLabel(TypePackId tp) -{ - if (opts.showPointers) - formatAppend(result, "\n0x%p", tp); - // additional common attributes can be added here as well - result += "\""; -} - -void StateDot::visitChildren(TypeId ty, int index) -{ - if (seenTy.count(ty)) - return; - seenTy.insert(ty); - - startNode(index); - startNodeLabel(); - - if (const BoundTypeVar* btv = get(ty)) - { - formatAppend(result, "BoundTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(btv->boundTo, index); - } - else if (const FunctionTypeVar* ftv = get(ty)) - { - formatAppend(result, "FunctionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retType, index, "ret"); - } - else if (const TableTypeVar* ttv = get(ty)) - { - if (ttv->name) - formatAppend(result, "TableTypeVar %s", ttv->name->c_str()); - else if (ttv->syntheticName) - formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str()); - else - formatAppend(result, "TableTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - if (ttv->boundTo) - return visitChild(*ttv->boundTo, index, "boundTo"); - - for (const auto& [name, prop] : ttv->props) - visitChild(prop.type, index, name.c_str()); - if (ttv->indexer) - { - visitChild(ttv->indexer->indexType, index, "[index]"); - visitChild(ttv->indexer->indexResultType, index, "[value]"); - } - for (TypeId itp : ttv->instantiatedTypeParams) - visitChild(itp, index, "typeParam"); - - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId itp : ttv->instantiatedTypePackParams) - visitChild(itp, index, "typePackParam"); - } - } - else if (const MetatableTypeVar* mtv = get(ty)) - { - formatAppend(result, "MetatableTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(mtv->table, index, "table"); - visitChild(mtv->metatable, index, "metatable"); - } - else if (const UnionTypeVar* utv = get(ty)) - { - formatAppend(result, "UnionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId opt : utv->options) - visitChild(opt, index); - } - else if (const IntersectionTypeVar* itv = get(ty)) - { - formatAppend(result, "IntersectionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId part : itv->parts) - visitChild(part, index); - } - else if (const GenericTypeVar* gtv = get(ty)) - { - if (gtv->explicitName) - formatAppend(result, "GenericTypeVar %s", gtv->name.c_str()); - else - formatAppend(result, "GenericTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const FreeTypeVar* ftv = get(ty)) - { - formatAppend(result, "FreeTypeVar %d", ftv->index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "AnyTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "ErrorTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const ClassTypeVar* ctv = get(ty)) - { - formatAppend(result, "ClassTypeVar %s", ctv->name.c_str()); - finishNodeLabel(ty); - finishNode(); - - for (const auto& [name, prop] : ctv->props) - visitChild(prop.type, index, name.c_str()); - - if (ctv->parent) - visitChild(*ctv->parent, index, "[parent]"); - - if (ctv->metatable) - visitChild(*ctv->metatable, index, "[metatable]"); - } - else - { - LUAU_ASSERT(!"unknown type kind"); - finishNodeLabel(ty); - finishNode(); - } -} - -void StateDot::visitChildren(TypePackId tp, int index) -{ - if (seenTp.count(tp)) - return; - seenTp.insert(tp); - - startNode(index); - startNodeLabel(); - - if (const BoundTypePack* btp = get(tp)) - { - formatAppend(result, "BoundTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - visitChild(btp->boundTo, index); - } - else if (const TypePack* tpp = get(tp)) - { - formatAppend(result, "TypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - for (TypeId tv : tpp->head) - visitChild(tv, index); - if (tpp->tail) - visitChild(*tpp->tail, index, "tail"); - } - else if (const VariadicTypePack* vtp = get(tp)) - { - formatAppend(result, "VariadicTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - visitChild(vtp->ty, index); - } - else if (const FreeTypePack* ftp = get(tp)) - { - formatAppend(result, "FreeTypePack %d", ftp->index); - finishNodeLabel(tp); - finishNode(); - } - else if (const GenericTypePack* gtp = get(tp)) - { - if (gtp->explicitName) - formatAppend(result, "GenericTypePack %s", gtp->name.c_str()); - else - formatAppend(result, "GenericTypePack %d", gtp->index); - finishNodeLabel(tp); - finishNode(); - } - else if (get(tp)) - { - formatAppend(result, "ErrorTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - } - else - { - LUAU_ASSERT(!"unknown type pack kind"); - finishNodeLabel(tp); - finishNode(); - } -} - -} // namespace - -std::string toDot(TypeId ty, const ToDotOptions& opts) -{ - StateDot state{opts}; - - state.result = "digraph graphname {\n"; - state.visitChild(ty, 0); - state.result += "}"; - - return state.result; -} - -std::string toDot(TypePackId tp, const ToDotOptions& opts) -{ - StateDot state{opts}; - - state.result = "digraph graphname {\n"; - state.visitChild(tp, 0); - state.result += "}"; - - return state.result; -} - -std::string toDot(TypeId ty) -{ - return toDot(ty, {}); -} - -std::string toDot(TypePackId tp) -{ - return toDot(tp, {}); -} - -void dumpDot(TypeId ty) -{ - printf("%s\n", toDot(ty).c_str()); -} - -void dumpDot(TypePackId tp) -{ - printf("%s\n", toDot(tp).c_str()); -} - const TypeLevel* getLevel(TypeId ty) { ty = follow(ty); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index e1a52be4e..d0b188374 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,9 +18,6 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) -LUAU_FASTFLAG(LuauShareTxnSeen); -LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) @@ -136,38 +133,19 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , globalScope(std::move(globalScope)) , location(location) , variance(variance) - , counters(&countersData) - , counters_DEPRECATED(std::make_shared()) - , sharedState(sharedState) -{ - LUAU_ASSERT(sharedState.iceHandler); -} - -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) - : types(types) - , mode(mode) - , globalScope(std::move(globalScope)) - , log(ownedSeen) - , location(location) - , variance(variance) - , counters(counters ? counters : &countersData) - , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) , sharedState(sharedState) { LUAU_ASSERT(sharedState.iceHandler); } Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) + Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) , globalScope(std::move(globalScope)) , log(sharedSeen) , location(location) , variance(variance) - , counters(counters ? counters : &countersData) - , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) , sharedState(sharedState) { LUAU_ASSERT(sharedState.iceHandler); @@ -175,26 +153,18 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector< void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - if (FFlag::LuauTypecheckOpts) - counters->iterationCount = 0; - else - counters_DEPRECATED->iterationCount = 0; + sharedState.counters.iterationCount = 0; tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - if (FFlag::LuauTypecheckOpts) - ++counters->iterationCount; - else - ++counters_DEPRECATED->iterationCount; + ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && - FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -302,7 +272,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (get(subTy) || get(subTy)) return tryUnifyWithAny(subTy, superTy); - bool cacheEnabled = FFlag::LuauCacheUnifyTableResults && !isFunctionCall && !isIntersection; + bool cacheEnabled = !isFunctionCall && !isIntersection; auto& cache = sharedState.cachedUnify; // What if the types are immutable and we proved their relation before @@ -563,8 +533,6 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool void Unifier::cacheResult(TypeId superTy, TypeId subTy) { - LUAU_ASSERT(FFlag::LuauCacheUnifyTableResults); - bool* superTyInfo = sharedState.skipCacheForType.find(superTy); if (superTyInfo && *superTyInfo) @@ -686,10 +654,7 @@ ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunction void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - if (FFlag::LuauTypecheckOpts) - counters->iterationCount = 0; - else - counters_DEPRECATED->iterationCount = 0; + sharedState.counters.iterationCount = 0; tryUnify_(superTp, subTp, isFunctionCall); } @@ -700,16 +665,11 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall */ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - if (FFlag::LuauTypecheckOpts) - ++counters->iterationCount; - else - ++counters_DEPRECATED->iterationCount; + ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && - FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -1727,39 +1687,8 @@ void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& sub tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); } -static void queueTypePack_DEPRECATED( - std::vector& queue, std::unordered_set& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) -{ - LUAU_ASSERT(!FFlag::LuauTypecheckOpts); - - while (true) - { - a = follow(a); - - if (seenTypePacks.count(a)) - break; - seenTypePacks.insert(a); - - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } - } -} - static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { - LUAU_ASSERT(FFlag::LuauTypecheckOpts); - while (true) { a = follow(a); @@ -1837,66 +1766,9 @@ void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool rever } } -static void tryUnifyWithAny_DEPRECATED( - std::vector& queue, Unifier& state, std::unordered_set& seenTypePacks, TypeId anyType, TypePackId anyTypePack) -{ - LUAU_ASSERT(!FFlag::LuauTypecheckOpts); - - std::unordered_set seen; - - while (!queue.empty()) - { - TypeId ty = follow(queue.back()); - queue.pop_back(); - if (seen.count(ty)) - continue; - seen.insert(ty); - - if (get(ty)) - { - state.log(ty); - *asMutable(ty) = BoundTypeVar{anyType}; - } - else if (auto fun = get(ty)) - { - queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->retType, anyTypePack); - } - else if (auto table = get(ty)) - { - for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); - - if (table->indexer) - { - queue.push_back(table->indexer->indexType); - queue.push_back(table->indexer->indexResultType); - } - } - else if (auto mt = get(ty)) - { - queue.push_back(mt->table); - queue.push_back(mt->metatable); - } - else if (get(ty)) - { - // ClassTypeVars never contain free typevars. - } - else if (auto union_ = get(ty)) - queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = get(ty)) - queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); - else - { - } // Primitives, any, errors, and generics are left untouched. - } -} - static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, TypeId anyType, TypePackId anyTypePack) { - LUAU_ASSERT(FFlag::LuauTypecheckOpts); - while (!queue.empty()) { TypeId ty = follow(queue.back()); @@ -1949,43 +1821,20 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { LUAU_ASSERT(get(any) || get(any)); - if (FFlag::LuauTypecheckOpts) - { - // These types are not visited in general loop below - if (get(ty) || get(ty) || get(ty)) - return; - } + // These types are not visited in general loop below + if (get(ty) || get(ty) || get(ty)) + return; const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - if (FFlag::LuauTypecheckOpts) - { - std::vector queue = {ty}; - - if (FFlag::LuauCacheUnifyTableResults) - { - sharedState.tempSeenTy.clear(); - sharedState.tempSeenTp.clear(); - - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); - } - else - { - tempSeenTy_DEPRECATED.clear(); - tempSeenTp_DEPRECATED.clear(); + std::vector queue = {ty}; - Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, singletonTypes.anyType, anyTP); - } - } - else - { - std::unordered_set seenTypePacks; - std::vector queue = {ty}; + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); - } + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) @@ -1994,38 +1843,14 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) const TypeId anyTy = singletonTypes.errorRecoveryType(); - if (FFlag::LuauTypecheckOpts) - { - std::vector queue; - - if (FFlag::LuauCacheUnifyTableResults) - { - sharedState.tempSeenTy.clear(); - sharedState.tempSeenTp.clear(); - - queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); - - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); - } - else - { - tempSeenTy_DEPRECATED.clear(); - tempSeenTp_DEPRECATED.clear(); - - queueTypePack(queue, tempSeenTp_DEPRECATED, *this, ty, any); + std::vector queue; - Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, anyTy, any); - } - } - else - { - std::unordered_set seenTypePacks; - std::vector queue; + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - queueTypePack_DEPRECATED(queue, seenTypePacks, *this, ty, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); - Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, anyTy, any); - } + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) @@ -2035,46 +1860,22 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N void Unifier::occursCheck(TypeId needle, TypeId haystack) { - std::unordered_set seen_DEPRECATED; + sharedState.tempSeenTy.clear(); - if (FFlag::LuauCacheUnifyTableResults) - { - if (FFlag::LuauTypecheckOpts) - sharedState.tempSeenTy.clear(); - - return occursCheck(seen_DEPRECATED, sharedState.tempSeenTy, needle, haystack); - } - else - { - if (FFlag::LuauTypecheckOpts) - tempSeenTy_DEPRECATED.clear(); - - return occursCheck(seen_DEPRECATED, tempSeenTy_DEPRECATED, needle, haystack); - } + return occursCheck(sharedState.tempSeenTy, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) +void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); needle = follow(needle); haystack = follow(haystack); - if (FFlag::LuauTypecheckOpts) - { - if (seen.find(haystack)) - return; - - seen.insert(haystack); - } - else - { - if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) - return; + if (seen.find(haystack)) + return; - seen_DEPRECATED.insert(haystack); - } + seen.insert(haystack); if (get(needle)) return; @@ -2091,7 +1892,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash } auto check = [&](TypeId tv) { - occursCheck(seen_DEPRECATED, seen, needle, tv); + occursCheck(seen, needle, tv); }; if (get(haystack)) @@ -2121,43 +1922,20 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { - std::unordered_set seen_DEPRECATED; - - if (FFlag::LuauCacheUnifyTableResults) - { - if (FFlag::LuauTypecheckOpts) - sharedState.tempSeenTp.clear(); - - return occursCheck(seen_DEPRECATED, sharedState.tempSeenTp, needle, haystack); - } - else - { - if (FFlag::LuauTypecheckOpts) - tempSeenTp_DEPRECATED.clear(); + sharedState.tempSeenTp.clear(); - return occursCheck(seen_DEPRECATED, tempSeenTp_DEPRECATED, needle, haystack); - } + return occursCheck(sharedState.tempSeenTp, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) +void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) { needle = follow(needle); haystack = follow(haystack); - if (FFlag::LuauTypecheckOpts) - { - if (seen.find(haystack)) - return; - - seen.insert(haystack); - } - else - { - if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) - return; + if (seen.find(haystack)) + return; - seen_DEPRECATED.insert(haystack); - } + seen.insert(haystack); if (get(needle)) return; @@ -2165,8 +1943,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense if (!get(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); while (!get(haystack)) { @@ -2186,8 +1963,8 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { if (auto f = get(follow(ty))) { - occursCheck(seen_DEPRECATED, seen, needle, f->argTypes); - occursCheck(seen_DEPRECATED, seen, needle, f->retType); + occursCheck(seen, needle, f->argTypes); + occursCheck(seen, needle, f->retType); } } } @@ -2204,10 +1981,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense Unifier Unifier::makeChildUnifier() { - if (FFlag::LuauShareTxnSeen) - return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; - else - return Unifier{types, mode, globalScope, log.ownedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; + return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index bc63e37dc..3d0d5b7e6 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -13,8 +13,6 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) -LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) -LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) @@ -782,8 +780,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) AstType* type = parseTypeAnnotation(); - return allocator.alloc( - Location(start, type->location), name->name, generics, FFlag::LuauTypeAliasPacks ? genericPacks : AstArray{}, type, exported); + return allocator.alloc(Location(start, type->location), name->name, generics, genericPacks, type, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -1602,30 +1599,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) return {allocator.alloc(Location(begin, end), expr), {}}; } - if (FFlag::LuauParseTypePackTypeParameters) - { - bool hasParameters = false; - AstArray parameters{}; - - if (lexer.current().type == '<') - { - hasParameters = true; - parameters = parseTypeParams(); - } + bool hasParameters = false; + AstArray parameters{}; - Location end = lexer.previousLocation(); - - return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; - } - else + if (lexer.current().type == '<') { - AstArray generics = parseTypeParams(); + hasParameters = true; + parameters = parseTypeParams(); + } - Location end = lexer.previousLocation(); + Location end = lexer.previousLocation(); - // false in 'hasParameterList' as it is not used without FFlagLuauTypeAliasPacks - return {allocator.alloc(Location(begin, end), prefix, name.name, false, generics), {}}; - } + return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; } else if (lexer.current().type == '{') { @@ -2414,37 +2399,24 @@ AstArray Parser::parseTypeParams() while (true) { - if (FFlag::LuauParseTypePackTypeParameters) + if (shouldParseTypePackAnnotation(lexer)) + { + auto typePack = parseTypePackAnnotation(); + + parameters.push_back({{}, typePack}); + } + else if (lexer.current().type == '(') { - if (shouldParseTypePackAnnotation(lexer)) - { - auto typePack = parseTypePackAnnotation(); - - if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them - parameters.push_back({{}, typePack}); - } - else if (lexer.current().type == '(') - { - auto [type, typePack] = parseTypeOrPackAnnotation(); - - if (typePack) - { - if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them - parameters.push_back({{}, typePack}); - } - else - { - parameters.push_back({type, {}}); - } - } - else if (lexer.current().type == '>' && parameters.empty()) - { - break; - } + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (typePack) + parameters.push_back({{}, typePack}); else - { - parameters.push_back({parseTypeAnnotation(), {}}); - } + parameters.push_back({type, {}}); + } + else if (lexer.current().type == '>' && parameters.empty()) + { + break; } else { diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index ebdd78966..9230d80d0 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -121,7 +121,7 @@ struct CliFileResolver : Luau::FileResolver if (Luau::AstExprConstantString* expr = node->as()) { Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau"; - if (!moduleExists(name)) + if (!readFile(name)) { // fall back to .lua if a module with .luau doesn't exist name = std::string(expr->value.data, expr->value.size) + ".lua"; @@ -132,27 +132,6 @@ struct CliFileResolver : Luau::FileResolver return std::nullopt; } - - bool moduleExists(const Luau::ModuleName& name) const override - { - return !!readFile(name); - } - - - std::optional fromAstFragment(Luau::AstExpr* expr) const override - { - return std::nullopt; - } - - Luau::ModuleName concat(const Luau::ModuleName& lhs, std::string_view rhs) const override - { - return lhs + "/" + std::string(rhs); - } - - std::optional getParentModuleName(const Luau::ModuleName& name) const override - { - return std::nullopt; - } }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index b29cd6f9c..2cdd0062f 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -198,11 +198,6 @@ static std::string runCode(lua_State* L, const std::string& source) error += "\nstack backtrace:\n"; error += lua_debugtrace(T); -#ifdef __EMSCRIPTEN__ - // nicer formatting for errors in web repl - error = "Error:" + error; -#endif - fprintf(stdout, "%s", error.c_str()); } @@ -210,39 +205,6 @@ static std::string runCode(lua_State* L, const std::string& source) return std::string(); } -#ifdef __EMSCRIPTEN__ -extern "C" -{ - const char* executeScript(const char* source) - { - // setup flags - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - if (strncmp(flag->name, "Luau", 4) == 0) - flag->value = true; - - // create new state - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - // setup state - setupState(L); - - // sandbox thread - luaL_sandboxthread(L); - - // static string for caching result (prevents dangling ptr on function exit) - static std::string result; - - // run code + collect error - result = runCode(L, source); - - return result.empty() ? NULL : result.c_str(); - } -} -#endif - -// Excluded from emscripten compilation to avoid -Wunused-function errors. -#ifndef __EMSCRIPTEN__ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) { std::string_view lookup = editBuffer + start; @@ -564,6 +526,5 @@ int main(int argc, char** argv) return failed; } } -#endif diff --git a/CLI/Web.cpp b/CLI/Web.cpp new file mode 100644 index 000000000..cf5c831e9 --- /dev/null +++ b/CLI/Web.cpp @@ -0,0 +1,106 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" +#include "luacode.h" + +#include "Luau/Common.h" + +#include + +#include + +static void setupState(lua_State* L) +{ + luaL_openlibs(L); + + luaL_sandbox(L); +} + +static std::string runCode(lua_State* L, const std::string& source) +{ + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.length(), nullptr, &bytecodeSize); + int result = luau_load(L, "=stdin", bytecode, bytecodeSize, 0); + free(bytecode); + + if (result != 0) + { + size_t len; + const char* msg = lua_tolstring(L, -1, &len); + + std::string error(msg, len); + lua_pop(L, 1); + + return error; + } + + lua_State* T = lua_newthread(L); + + lua_pushvalue(L, -2); + lua_remove(L, -3); + lua_xmove(L, T, 1); + + int status = lua_resume(T, NULL, 0); + + if (status == 0) + { + int n = lua_gettop(T); + + if (n) + { + luaL_checkstack(T, LUA_MINSTACK, "too many results to print"); + lua_getglobal(T, "print"); + lua_insert(T, 1); + lua_pcall(T, n, 0, 0); + } + } + else + { + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(T, -1)) + { + error = str; + } + + error += "\nstack backtrace:\n"; + error += lua_debugtrace(T); + + error = "Error:" + error; + + fprintf(stdout, "%s", error.c_str()); + } + + lua_pop(L, 1); + return std::string(); +} + +extern "C" const char* executeScript(const char* source) +{ + // setup flags + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + // create new state + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // setup state + setupState(L); + + // sandbox thread + luaL_sandboxthread(L); + + // static string for caching result (prevents dangling ptr on function exit) + static std::string result; + + // run code + collect error + result = runCode(L, source); + + return result.empty() ? NULL : result.c_str(); +} diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c69521ed..bafc59e59 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,7 @@ project(Luau LANGUAGES CXX) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) +option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) add_library(Luau.Ast STATIC) @@ -18,26 +19,22 @@ add_library(Luau.VM STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) - if(NOT EMSCRIPTEN) - add_executable(Luau.Analyze.CLI) - else() - # add -fexceptions for emscripten to allow exceptions to be caught in C++ - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions") - endif() + add_executable(Luau.Analyze.CLI) # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) - - if(NOT EMSCRIPTEN) - set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) - endif() + set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) endif() -if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) +if(LUAU_BUILD_TESTS) add_executable(Luau.UnitTest) add_executable(Luau.Conformance) endif() +if(LUAU_BUILD_WEB) + add_executable(Luau.Web) +endif() + include(Sources.cmake) target_compile_features(Luau.Ast PUBLIC cxx_std_17) @@ -72,16 +69,18 @@ if(LUAU_WERROR) endif() endif() +if(LUAU_BUILD_WEB) + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + list(APPEND LUAU_OPTIONS -fexceptions) +endif() + target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) - - if(NOT EMSCRIPTEN) - target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) - endif() + target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Repl.CLI PRIVATE extern) target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) @@ -93,20 +92,10 @@ if(LUAU_BUILD_CLI) endif() endif() - if(NOT EMSCRIPTEN) - target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) - endif() - - if(EMSCRIPTEN) - # declare exported functions to emscripten - target_link_options(Luau.Repl.CLI PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -fexceptions) - - # custom output directory for wasm + js file - set_target_properties(Luau.Repl.CLI PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/docs/assets/luau) - endif() + target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) endif() -if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) +if(LUAU_BUILD_TESTS) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.UnitTest PRIVATE extern) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) @@ -115,3 +104,17 @@ if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) target_include_directories(Luau.Conformance PRIVATE extern) target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.VM) endif() + +if(LUAU_BUILD_WEB) + target_compile_options(Luau.Web PRIVATE ${LUAU_OPTIONS}) + target_link_libraries(Luau.Web PRIVATE Luau.Compiler Luau.VM) + + # declare exported functions to emscripten + target_link_options(Luau.Web PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap']) + + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + target_link_options(Luau.Web PRIVATE -fexceptions) + + # the output is a single .js file with an embedded wasm blob + target_link_options(Luau.Web PRIVATE -sSINGLE_FILE=1) +endif() diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 5b93c1dc0..2c1e85ff0 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -321,13 +321,15 @@ struct Compiler compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); } - setDebugLine(expr->func); + setDebugLineEnd(expr->func); if (expr->self) { AstExprIndexName* fi = expr->func->as(); LUAU_ASSERT(fi); + setDebugLine(fi->indexLocation); + BytecodeBuilder::StringRef iname = sref(fi->index); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) @@ -1313,6 +1315,8 @@ struct Compiler RegScope rs(this); uint8_t reg = compileExprAuto(expr->expr, rs); + setDebugLine(expr->indexLocation); + BytecodeBuilder::StringRef iname = sref(expr->index); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) @@ -2710,6 +2714,12 @@ struct Compiler bytecode.setDebugLine(node->location.begin.line + 1); } + void setDebugLine(const Location& location) + { + if (options.debugLevel >= 1) + bytecode.setDebugLine(location.begin.line + 1); + } + void setDebugLineEnd(AstNode* node) { if (options.debugLevel >= 1) @@ -3650,7 +3660,7 @@ struct Compiler { if (options.vectorLib) { - if (builtin.object == options.vectorLib && builtin.method == options.vectorCtor) + if (builtin.isMethod(options.vectorLib, options.vectorCtor)) return LBF_VECTOR; } else diff --git a/Makefile b/Makefile index cab3d43f1..15c7ff7a4 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Analyze.cpp ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o) ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze -FUZZ_SOURCES=$(wildcard fuzz/*.cpp) +FUZZ_SOURCES=$(wildcard fuzz/*.cpp) fuzz/luau.pb.cpp FUZZ_OBJECTS=$(FUZZ_SOURCES:%=$(BUILD)/%.o) TESTS_ARGS= @@ -167,8 +167,8 @@ fuzz/luau.pb.cpp: fuzz/luau.proto build/libprotobuf-mutator cd fuzz && ../build/libprotobuf-mutator/external.protobuf/bin/protoc luau.proto --cpp_out=. mv fuzz/luau.pb.cc fuzz/luau.pb.cpp -$(BUILD)/fuzz/proto.cpp.o: build/libprotobuf-mutator -$(BUILD)/fuzz/protoprint.cpp.o: build/libprotobuf-mutator +$(BUILD)/fuzz/proto.cpp.o: fuzz/luau.pb.cpp +$(BUILD)/fuzz/protoprint.cpp.o: fuzz/luau.pb.cpp build/libprotobuf-mutator: git clone https://github.com/google/libprotobuf-mutator build/libprotobuf-mutator diff --git a/Sources.cmake b/Sources.cmake index 23b931c6b..57df9b91e 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -54,6 +54,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Scope.h Analysis/include/Luau/Substitution.h Analysis/include/Luau/Symbol.h + Analysis/include/Luau/ToDot.h Analysis/include/Luau/TopoSortStatements.h Analysis/include/Luau/ToString.h Analysis/include/Luau/Transpiler.h @@ -86,6 +87,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Scope.cpp Analysis/src/Substitution.cpp Analysis/src/Symbol.cpp + Analysis/src/ToDot.cpp Analysis/src/TopoSortStatements.cpp Analysis/src/ToString.cpp Analysis/src/Transpiler.cpp @@ -118,6 +120,7 @@ target_sources(Luau.VM PRIVATE VM/src/ldo.cpp VM/src/lfunc.cpp VM/src/lgc.cpp + VM/src/lgcdebug.cpp VM/src/linit.cpp VM/src/lmathlib.cpp VM/src/lmem.cpp @@ -194,6 +197,7 @@ if(TARGET Luau.UnitTest) tests/RequireTracer.test.cpp tests/StringUtils.test.cpp tests/Symbol.test.cpp + tests/ToDot.test.cpp tests/TopoSort.test.cpp tests/ToString.test.cpp tests/Transpiler.test.cpp @@ -224,3 +228,9 @@ if(TARGET Luau.Conformance) tests/Conformance.test.cpp tests/main.cpp) endif() + +if(TARGET Luau.Web) + # Luau.Web Sources + target_sources(Luau.Web PRIVATE + CLI/Web.cpp) +endif() diff --git a/VM/include/lua.h b/VM/include/lua.h index 1568d191f..7078acd0f 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -21,6 +21,7 @@ #define LUA_ENVIRONINDEX (-10001) #define LUA_GLOBALSINDEX (-10002) #define lua_upvalueindex(i) (LUA_GLOBALSINDEX - (i)) +#define lua_ispseudo(i) ((i) <= LUA_REGISTRYINDEX) /* thread status; 0 is OK */ enum lua_Status @@ -108,6 +109,7 @@ LUA_API int lua_isthreadreset(lua_State* L); /* ** basic stack manipulation */ +LUA_API int lua_absindex(lua_State* L, int idx); LUA_API int lua_gettop(lua_State* L); LUA_API void lua_settop(lua_State* L, int idx); LUA_API void lua_pushvalue(lua_State* L, int idx); @@ -159,7 +161,11 @@ LUA_API void lua_pushnil(lua_State* L); LUA_API void lua_pushnumber(lua_State* L, double n); LUA_API void lua_pushinteger(lua_State* L, int n); LUA_API void lua_pushunsigned(lua_State* L, unsigned n); +#if LUA_VECTOR_SIZE == 4 +LUA_API void lua_pushvector(lua_State* L, float x, float y, float z, float w); +#else LUA_API void lua_pushvector(lua_State* L, float x, float y, float z); +#endif LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l); LUA_API void lua_pushstring(lua_State* L, const char* s); LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp); @@ -183,7 +189,7 @@ LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); LUA_API int lua_getreadonly(lua_State* L, int idx); LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled); -LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag); +LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); LUA_API int lua_getmetatable(lua_State* L, int objindex); LUA_API void lua_getfenv(lua_State* L, int idx); @@ -227,6 +233,7 @@ enum lua_GCOp LUA_GCRESTART, LUA_GCCOLLECT, LUA_GCCOUNT, + LUA_GCCOUNTB, LUA_GCISRUNNING, // garbage collection is handled by 'assists' that perform some amount of GC work matching pace of allocation @@ -281,6 +288,7 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_pop(L, n) lua_settop(L, -(n)-1) #define lua_newtable(L) lua_createtable(L, 0, 0) +#define lua_newuserdata(L, s) lua_newuserdatatagged(L, s, 0) #define lua_strlen(L, i) lua_objlen(L, (i)) @@ -289,6 +297,7 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_islightuserdata(L, n) (lua_type(L, (n)) == LUA_TLIGHTUSERDATA) #define lua_isnil(L, n) (lua_type(L, (n)) == LUA_TNIL) #define lua_isboolean(L, n) (lua_type(L, (n)) == LUA_TBOOLEAN) +#define lua_isvector(L, n) (lua_type(L, (n)) == LUA_TVECTOR) #define lua_isthread(L, n) (lua_type(L, (n)) == LUA_TTHREAD) #define lua_isnone(L, n) (lua_type(L, (n)) == LUA_TNONE) #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index aa008a240..a01a14819 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -34,7 +34,10 @@ #endif /* Can be used to reconfigure visibility/exports for public APIs */ +#ifndef LUA_API #define LUA_API extern +#endif + #define LUALIB_API LUA_API /* Can be used to reconfigure visibility for internal APIs */ @@ -47,10 +50,14 @@ #endif /* Can be used to reconfigure internal error handling to use longjmp instead of C++ EH */ +#ifndef LUA_USE_LONGJMP #define LUA_USE_LONGJMP 0 +#endif /* LUA_IDSIZE gives the maximum size for the description of the source */ +#ifndef LUA_IDSIZE #define LUA_IDSIZE 256 +#endif /* @@ LUAI_GCGOAL defines the desired top heap size in relation to the live heap @@ -59,7 +66,9 @@ ** mean larger GC pauses which mean slower collection.) You can also change ** this value dynamically. */ +#ifndef LUAI_GCGOAL #define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ +#endif /* @@ LUAI_GCSTEPMUL / LUAI_GCSTEPSIZE define the default speed of garbage collection @@ -69,38 +78,63 @@ ** CHANGE it if you want to change the granularity of the garbage ** collection. */ +#ifndef LUAI_GCSTEPMUL #define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ +#endif + +#ifndef LUAI_GCSTEPSIZE #define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ +#endif /* LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function */ +#ifndef LUA_MINSTACK #define LUA_MINSTACK 20 +#endif /* LUAI_MAXCSTACK limits the number of Lua stack slots that a C function can use */ +#ifndef LUAI_MAXCSTACK #define LUAI_MAXCSTACK 8000 +#endif /* LUAI_MAXCALLS limits the number of nested calls */ +#ifndef LUAI_MAXCALLS #define LUAI_MAXCALLS 20000 +#endif /* LUAI_MAXCCALLS is the maximum depth for nested C calls; this limit depends on native stack size */ +#ifndef LUAI_MAXCCALLS #define LUAI_MAXCCALLS 200 +#endif /* buffer size used for on-stack string operations; this limit depends on native stack size */ +#ifndef LUA_BUFFERSIZE #define LUA_BUFFERSIZE 512 +#endif /* number of valid Lua userdata tags */ +#ifndef LUA_UTAG_LIMIT #define LUA_UTAG_LIMIT 128 +#endif /* upper bound for number of size classes used by page allocator */ +#ifndef LUA_SIZECLASSES #define LUA_SIZECLASSES 32 +#endif /* available number of separate memory categories */ +#ifndef LUA_MEMORY_CATEGORIES #define LUA_MEMORY_CATEGORIES 256 +#endif /* minimum size for the string table (must be power of 2) */ +#ifndef LUA_MINSTRTABSIZE #define LUA_MINSTRTABSIZE 32 +#endif /* maximum number of captures supported by pattern matching */ +#ifndef LUA_MAXCAPTURES #define LUA_MAXCAPTURES 32 +#endif /* }================================================================== */ @@ -122,3 +156,7 @@ void* s; \ long l; \ } + +#define LUA_VECTOR_SIZE 3 /* must be 3 or 4 */ + +#define LUA_EXTRA_SIZE LUA_VECTOR_SIZE - 2 diff --git a/VM/include/lualib.h b/VM/include/lualib.h index fa836955c..baf27b47e 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -25,11 +25,17 @@ LUALIB_API const char* luaL_optlstring(lua_State* L, int numArg, const char* def LUALIB_API double luaL_checknumber(lua_State* L, int numArg); LUALIB_API double luaL_optnumber(lua_State* L, int nArg, double def); +LUALIB_API int luaL_checkboolean(lua_State* L, int narg); +LUALIB_API int luaL_optboolean(lua_State* L, int narg, int def); + LUALIB_API int luaL_checkinteger(lua_State* L, int numArg); LUALIB_API int luaL_optinteger(lua_State* L, int nArg, int def); LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int numArg); LUALIB_API unsigned luaL_optunsigned(lua_State* L, int numArg, unsigned def); +LUALIB_API const float* luaL_checkvector(lua_State* L, int narg); +LUALIB_API const float* luaL_optvector(lua_State* L, int narg, const float* def); + LUALIB_API void luaL_checkstack(lua_State* L, int sz, const char* msg); LUALIB_API void luaL_checktype(lua_State* L, int narg, int t); LUALIB_API void luaL_checkany(lua_State* L, int narg); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index a79ba0d40..76043b9cd 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAG(LuauActivateBeforeExec) + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -170,6 +172,12 @@ lua_State* lua_mainthread(lua_State* L) ** basic stack manipulation */ +int lua_absindex(lua_State* L, int idx) +{ + api_check(L, (idx > 0 && idx <= L->top - L->base) || (idx < 0 && -idx <= L->top - L->base) || lua_ispseudo(idx)); + return idx > 0 || lua_ispseudo(idx) ? idx : cast_int(L->top - L->base) + idx + 1; +} + int lua_gettop(lua_State* L) { return cast_int(L->top - L->base); @@ -550,12 +558,21 @@ void lua_pushunsigned(lua_State* L, unsigned u) return; } +#if LUA_VECTOR_SIZE == 4 +void lua_pushvector(lua_State* L, float x, float y, float z, float w) +{ + setvvalue(L->top, x, y, z, w); + api_incr_top(L); + return; +} +#else void lua_pushvector(lua_State* L, float x, float y, float z) { - setvvalue(L->top, x, y, z); + setvvalue(L->top, x, y, z, 0.0f); api_incr_top(L); return; } +#endif void lua_pushlstring(lua_State* L, const char* s, size_t len) { @@ -922,14 +939,21 @@ void lua_call(lua_State* L, int nargs, int nresults) checkresults(L, nargs, nresults); func = L->top - (nargs + 1); - int wasActive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + if (FFlag::LuauActivateBeforeExec) + { + luaD_call(L, func, nresults); + } + else + { + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); - luaD_call(L, func, nresults); + luaD_call(L, func, nresults); - if (!wasActive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } adjustresults(L, nresults); return; @@ -970,14 +994,21 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) c.func = L->top - (nargs + 1); /* function to be called */ c.nresults = nresults; - int wasActive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + if (FFlag::LuauActivateBeforeExec) + { + status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); + } + else + { + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); + status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); - if (!wasActive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } adjustresults(L, nresults); return status; @@ -1030,6 +1061,11 @@ int lua_gc(lua_State* L, int what, int data) res = cast_int(g->totalbytes >> 10); break; } + case LUA_GCCOUNTB: + { + res = cast_int(g->totalbytes & 1023); + break; + } case LUA_GCISRUNNING: { res = (g->GCthreshold != SIZE_MAX); @@ -1146,7 +1182,7 @@ void lua_concat(lua_State* L, int n) return; } -void* lua_newuserdata(lua_State* L, size_t sz, int tag) +void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) { api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); luaC_checkGC(L); @@ -1231,6 +1267,7 @@ uintptr_t lua_encodepointer(lua_State* L, uintptr_t p) int lua_ref(lua_State* L, int idx) { + api_check(L, idx != LUA_REGISTRYINDEX); /* idx is a stack index for value */ int ref = LUA_REFNIL; global_State* g = L->global; StkId p = index2adr(L, idx); diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 2a684ee4e..7ed2a62ee 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -30,7 +30,7 @@ static const char* currfuncname(lua_State* L) return debugname; } -LUALIB_API l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg) +l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg) { const char* fname = currfuncname(L); @@ -40,7 +40,7 @@ LUALIB_API l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg) luaL_error(L, "invalid argument #%d (%s)", narg, extramsg); } -LUALIB_API l_noret luaL_typeerrorL(lua_State* L, int narg, const char* tname) +l_noret luaL_typeerrorL(lua_State* L, int narg, const char* tname) { const char* fname = currfuncname(L); const TValue* obj = luaA_toobject(L, narg); @@ -66,7 +66,7 @@ static l_noret tag_error(lua_State* L, int narg, int tag) luaL_typeerrorL(L, narg, lua_typename(L, tag)); } -LUALIB_API void luaL_where(lua_State* L, int level) +void luaL_where(lua_State* L, int level) { lua_Debug ar; if (lua_getinfo(L, level, "sl", &ar) && ar.currentline > 0) @@ -77,7 +77,7 @@ LUALIB_API void luaL_where(lua_State* L, int level) lua_pushliteral(L, ""); /* else, no information available... */ } -LUALIB_API l_noret luaL_errorL(lua_State* L, const char* fmt, ...) +l_noret luaL_errorL(lua_State* L, const char* fmt, ...) { va_list argp; va_start(argp, fmt); @@ -90,7 +90,7 @@ LUALIB_API l_noret luaL_errorL(lua_State* L, const char* fmt, ...) /* }====================================================== */ -LUALIB_API int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const lst[]) +int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const lst[]) { const char* name = (def) ? luaL_optstring(L, narg, def) : luaL_checkstring(L, narg); int i; @@ -101,7 +101,7 @@ LUALIB_API int luaL_checkoption(lua_State* L, int narg, const char* def, const c luaL_argerrorL(L, narg, msg); } -LUALIB_API int luaL_newmetatable(lua_State* L, const char* tname) +int luaL_newmetatable(lua_State* L, const char* tname) { lua_getfield(L, LUA_REGISTRYINDEX, tname); /* get registry.name */ if (!lua_isnil(L, -1)) /* name already in use? */ @@ -113,7 +113,7 @@ LUALIB_API int luaL_newmetatable(lua_State* L, const char* tname) return 1; } -LUALIB_API void* luaL_checkudata(lua_State* L, int ud, const char* tname) +void* luaL_checkudata(lua_State* L, int ud, const char* tname) { void* p = lua_touserdata(L, ud); if (p != NULL) @@ -131,25 +131,25 @@ LUALIB_API void* luaL_checkudata(lua_State* L, int ud, const char* tname) luaL_typeerrorL(L, ud, tname); /* else error */ } -LUALIB_API void luaL_checkstack(lua_State* L, int space, const char* mes) +void luaL_checkstack(lua_State* L, int space, const char* mes) { if (!lua_checkstack(L, space)) luaL_error(L, "stack overflow (%s)", mes); } -LUALIB_API void luaL_checktype(lua_State* L, int narg, int t) +void luaL_checktype(lua_State* L, int narg, int t) { if (lua_type(L, narg) != t) tag_error(L, narg, t); } -LUALIB_API void luaL_checkany(lua_State* L, int narg) +void luaL_checkany(lua_State* L, int narg) { if (lua_type(L, narg) == LUA_TNONE) luaL_error(L, "missing argument #%d", narg); } -LUALIB_API const char* luaL_checklstring(lua_State* L, int narg, size_t* len) +const char* luaL_checklstring(lua_State* L, int narg, size_t* len) { const char* s = lua_tolstring(L, narg, len); if (!s) @@ -157,7 +157,7 @@ LUALIB_API const char* luaL_checklstring(lua_State* L, int narg, size_t* len) return s; } -LUALIB_API const char* luaL_optlstring(lua_State* L, int narg, const char* def, size_t* len) +const char* luaL_optlstring(lua_State* L, int narg, const char* def, size_t* len) { if (lua_isnoneornil(L, narg)) { @@ -169,7 +169,7 @@ LUALIB_API const char* luaL_optlstring(lua_State* L, int narg, const char* def, return luaL_checklstring(L, narg, len); } -LUALIB_API double luaL_checknumber(lua_State* L, int narg) +double luaL_checknumber(lua_State* L, int narg) { int isnum; double d = lua_tonumberx(L, narg, &isnum); @@ -178,12 +178,28 @@ LUALIB_API double luaL_checknumber(lua_State* L, int narg) return d; } -LUALIB_API double luaL_optnumber(lua_State* L, int narg, double def) +double luaL_optnumber(lua_State* L, int narg, double def) { return luaL_opt(L, luaL_checknumber, narg, def); } -LUALIB_API int luaL_checkinteger(lua_State* L, int narg) +int luaL_checkboolean(lua_State* L, int narg) +{ + // This checks specifically for boolean values, ignoring + // all other truthy/falsy values. If the desired result + // is true if value is present then lua_toboolean should + // directly be used instead. + if (!lua_isboolean(L, narg)) + tag_error(L, narg, LUA_TBOOLEAN); + return lua_toboolean(L, narg); +} + +int luaL_optboolean(lua_State* L, int narg, int def) +{ + return luaL_opt(L, luaL_checkboolean, narg, def); +} + +int luaL_checkinteger(lua_State* L, int narg) { int isnum; int d = lua_tointegerx(L, narg, &isnum); @@ -192,12 +208,12 @@ LUALIB_API int luaL_checkinteger(lua_State* L, int narg) return d; } -LUALIB_API int luaL_optinteger(lua_State* L, int narg, int def) +int luaL_optinteger(lua_State* L, int narg, int def) { return luaL_opt(L, luaL_checkinteger, narg, def); } -LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int narg) +unsigned luaL_checkunsigned(lua_State* L, int narg) { int isnum; unsigned d = lua_tounsignedx(L, narg, &isnum); @@ -206,12 +222,25 @@ LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int narg) return d; } -LUALIB_API unsigned luaL_optunsigned(lua_State* L, int narg, unsigned def) +unsigned luaL_optunsigned(lua_State* L, int narg, unsigned def) { return luaL_opt(L, luaL_checkunsigned, narg, def); } -LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* event) +const float* luaL_checkvector(lua_State* L, int narg) +{ + const float* v = lua_tovector(L, narg); + if (!v) + tag_error(L, narg, LUA_TVECTOR); + return v; +} + +const float* luaL_optvector(lua_State* L, int narg, const float* def) +{ + return luaL_opt(L, luaL_checkvector, narg, def); +} + +int luaL_getmetafield(lua_State* L, int obj, const char* event) { if (!lua_getmetatable(L, obj)) /* no metatable? */ return 0; @@ -229,7 +258,7 @@ LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* event) } } -LUALIB_API int luaL_callmeta(lua_State* L, int obj, const char* event) +int luaL_callmeta(lua_State* L, int obj, const char* event) { obj = abs_index(L, obj); if (!luaL_getmetafield(L, obj, event)) /* no metafield? */ @@ -247,7 +276,7 @@ static int libsize(const luaL_Reg* l) return size; } -LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l) +void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l) { if (libname) { @@ -273,7 +302,7 @@ LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* } } -LUALIB_API const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) +const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) { const char* e; lua_pushvalue(L, idx); @@ -324,7 +353,7 @@ static size_t getnextbuffersize(lua_State* L, size_t currentsize, size_t desired return newsize; } -LUALIB_API void luaL_buffinit(lua_State* L, luaL_Buffer* B) +void luaL_buffinit(lua_State* L, luaL_Buffer* B) { // start with an internal buffer B->p = B->buffer; @@ -334,14 +363,14 @@ LUALIB_API void luaL_buffinit(lua_State* L, luaL_Buffer* B) B->storage = nullptr; } -LUALIB_API char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size) +char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size) { luaL_buffinit(L, B); luaL_reservebuffer(B, size, -1); return B->p; } -LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc) +char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc) { lua_State* L = B->L; @@ -372,13 +401,13 @@ LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int bo return B->p; } -LUALIB_API void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc) +void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc) { if (size_t(B->end - B->p) < size) luaL_extendbuffer(B, size - (B->end - B->p), boxloc); } -LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) +void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) { if (size_t(B->end - B->p) < len) luaL_extendbuffer(B, len - (B->end - B->p), -1); @@ -387,7 +416,7 @@ LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) B->p += len; } -LUALIB_API void luaL_addvalue(luaL_Buffer* B) +void luaL_addvalue(luaL_Buffer* B) { lua_State* L = B->L; @@ -404,7 +433,7 @@ LUALIB_API void luaL_addvalue(luaL_Buffer* B) } } -LUALIB_API void luaL_pushresult(luaL_Buffer* B) +void luaL_pushresult(luaL_Buffer* B) { lua_State* L = B->L; @@ -428,7 +457,7 @@ LUALIB_API void luaL_pushresult(luaL_Buffer* B) } } -LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size) +void luaL_pushresultsize(luaL_Buffer* B, size_t size) { B->p += size; luaL_pushresult(B); @@ -436,7 +465,7 @@ LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size) /* }====================================================== */ -LUALIB_API const char* luaL_tolstring(lua_State* L, int idx, size_t* len) +const char* luaL_tolstring(lua_State* L, int idx, size_t* len) { if (luaL_callmeta(L, idx, "__tostring")) /* is there a metafield? */ { @@ -462,7 +491,11 @@ LUALIB_API const char* luaL_tolstring(lua_State* L, int idx, size_t* len) case LUA_TVECTOR: { const float* v = lua_tovector(L, idx); +#if LUA_VECTOR_SIZE == 4 + lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); +#else lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); +#endif break; } default: diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 61798e2bc..881c804db 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -401,7 +401,7 @@ static int luaB_newproxy(lua_State* L) bool needsmt = lua_toboolean(L, 1); - lua_newuserdata(L, 0, 0); + lua_newuserdata(L, 0); if (needsmt) { @@ -441,7 +441,7 @@ static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFuncti lua_setfield(L, -2, name); } -LUALIB_API int luaopen_base(lua_State* L) +int luaopen_base(lua_State* L) { /* set global _G */ lua_pushvalue(L, LUA_GLOBALSINDEX); diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 907c43c42..8b511edf0 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -236,7 +236,7 @@ static const luaL_Reg bitlib[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_bit32(lua_State* L) +int luaopen_bit32(lua_State* L) { luaL_register(L, LUA_BITLIBNAME, bitlib); diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 9ab57ac9b..34e9ebc1f 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1018,13 +1018,23 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { +#if LUA_VECTOR_SIZE == 4 + if (nparams >= 4 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1) && ttisnumber(args + 2)) +#else if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) +#endif { double x = nvalue(arg0); double y = nvalue(args); double z = nvalue(args + 1); - setvvalue(res, float(x), float(y), float(z)); +#if LUA_VECTOR_SIZE == 4 + double w = nvalue(args + 2); + setvvalue(res, float(x), float(y), float(z), float(w)); +#else + setvvalue(res, float(x), float(y), float(z), 0.0f); +#endif + return 1; } diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 0178fae84..abcde7796 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -272,7 +272,7 @@ static const luaL_Reg co_funcs[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_coroutine(lua_State* L) +int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); diff --git a/VM/src/ldblib.cpp b/VM/src/ldblib.cpp index 965d2b3de..93d8703a6 100644 --- a/VM/src/ldblib.cpp +++ b/VM/src/ldblib.cpp @@ -160,7 +160,7 @@ static const luaL_Reg dblib[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_debug(lua_State* L) +int luaopen_debug(lua_State* L) { luaL_register(L, LUA_DBLIBNAME, dblib); return 1; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 1259d4619..62bbdb7c9 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,9 +17,9 @@ #include -LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAG(LuauCoroutineClose) +LUAU_FASTFLAGVARIABLE(LuauActivateBeforeExec, false) /* ** {====================================================== @@ -74,35 +74,28 @@ class lua_exception : public std::exception const char* what() const throw() override { - if (FFlag::LuauExceptionMessageFix) + // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. + if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) { - // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. - if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) + // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. + if (const char* str = lua_tostring(L, -1)) { - // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. - if (const char* str = lua_tostring(L, -1)) - { - return str; - } - } - - switch (status) - { - case LUA_ERRRUN: - return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; - case LUA_ERRSYNTAX: - return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; - case LUA_ERRMEM: - return "lua_exception: " LUA_MEMERRMSG; - case LUA_ERRERR: - return "lua_exception: " LUA_ERRERRMSG; - default: - return "lua_exception: unexpected exception status"; + return str; } } - else + + switch (status) { - return lua_tostring(L, -1); + case LUA_ERRRUN: + return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; + case LUA_ERRSYNTAX: + return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; + case LUA_ERRMEM: + return "lua_exception: " LUA_MEMERRMSG; + case LUA_ERRERR: + return "lua_exception: " LUA_ERRERRMSG; + default: + return "lua_exception: unexpected exception status"; } } @@ -234,7 +227,22 @@ void luaD_call(lua_State* L, StkId func, int nResults) if (luau_precall(L, func, nResults) == PCRLUA) { /* is a Lua function? */ L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */ - luau_execute(L); /* call it */ + + if (FFlag::LuauActivateBeforeExec) + { + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); + + luau_execute(L); /* call it */ + + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } + else + { + luau_execute(L); /* call it */ + } } L->nCcalls--; luaC_checkGC(L); @@ -527,10 +535,10 @@ static void restore_stack_limit(lua_State* L) int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t ef) { - int status; unsigned short oldnCcalls = L->nCcalls; ptrdiff_t old_ci = saveci(L, L->ci); - status = luaD_rawrunprotected(L, func, u); + int oldactive = luaC_threadactive(L); + int status = luaD_rawrunprotected(L, func, u); if (status != 0) { // call user-defined error function (used in xpcall) @@ -541,6 +549,13 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e status = LUA_ERRERR; } + if (FFlag::LuauActivateBeforeExec) + { + // since the call failed with an error, we might have to reset the 'active' thread state + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } + if (FFlag::LuauCcallRestoreFix) { // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 11f79d1a3..ab416041e 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -10,9 +10,7 @@ #include "ldo.h" #include -#include -LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) LUAU_FASTFLAG(LuauArrayBoundary) @@ -988,7 +986,7 @@ void luaC_barriertable(lua_State* L, Table* t, GCObject* v) GCObject* o = obj2gco(t); // in the second propagation stage, table assignment barrier works as a forward barrier - if (FFlag::LuauRescanGrayAgainForwardBarrier && g->gcstate == GCSpropagateagain) + if (g->gcstate == GCSpropagateagain) { LUAU_ASSERT(isblack(o) && iswhite(v) && !isdead(g, v) && !isdead(g, o)); reallymarkobject(g, v); @@ -1044,550 +1042,6 @@ void luaC_linkupval(lua_State* L, UpVal* uv) } } -static void validateobjref(global_State* g, GCObject* f, GCObject* t) -{ - LUAU_ASSERT(!isdead(g, t)); - - if (keepinvariant(g)) - { - /* basic incremental invariant: black can't point to white */ - LUAU_ASSERT(!(isblack(f) && iswhite(t))); - } -} - -static void validateref(global_State* g, GCObject* f, TValue* v) -{ - if (iscollectable(v)) - { - LUAU_ASSERT(ttype(v) == gcvalue(v)->gch.tt); - validateobjref(g, f, gcvalue(v)); - } -} - -static void validatetable(global_State* g, Table* h) -{ - int sizenode = 1 << h->lsizenode; - - if (FFlag::LuauArrayBoundary) - LUAU_ASSERT(h->lastfree <= sizenode); - else - LUAU_ASSERT(h->lastfree >= 0 && h->lastfree <= sizenode); - - if (h->metatable) - validateobjref(g, obj2gco(h), obj2gco(h->metatable)); - - for (int i = 0; i < h->sizearray; ++i) - validateref(g, obj2gco(h), &h->array[i]); - - for (int i = 0; i < sizenode; ++i) - { - LuaNode* n = &h->node[i]; - - LUAU_ASSERT(ttype(gkey(n)) != LUA_TDEADKEY || ttisnil(gval(n))); - LUAU_ASSERT(i + gnext(n) >= 0 && i + gnext(n) < sizenode); - - if (!ttisnil(gval(n))) - { - TValue k = {}; - k.tt = gkey(n)->tt; - k.value = gkey(n)->value; - - validateref(g, obj2gco(h), &k); - validateref(g, obj2gco(h), gval(n)); - } - } -} - -static void validateclosure(global_State* g, Closure* cl) -{ - validateobjref(g, obj2gco(cl), obj2gco(cl->env)); - - if (cl->isC) - { - for (int i = 0; i < cl->nupvalues; ++i) - validateref(g, obj2gco(cl), &cl->c.upvals[i]); - } - else - { - LUAU_ASSERT(cl->nupvalues == cl->l.p->nups); - - validateobjref(g, obj2gco(cl), obj2gco(cl->l.p)); - - for (int i = 0; i < cl->nupvalues; ++i) - validateref(g, obj2gco(cl), &cl->l.uprefs[i]); - } -} - -static void validatestack(global_State* g, lua_State* l) -{ - validateref(g, obj2gco(l), gt(l)); - - for (CallInfo* ci = l->base_ci; ci <= l->ci; ++ci) - { - LUAU_ASSERT(l->stack <= ci->base); - LUAU_ASSERT(ci->func <= ci->base && ci->base <= ci->top); - LUAU_ASSERT(ci->top <= l->stack_last); - } - - // note: stack refs can violate gc invariant so we only check for liveness - for (StkId o = l->stack; o < l->top; ++o) - checkliveness(g, o); - - if (l->namecall) - validateobjref(g, obj2gco(l), obj2gco(l->namecall)); - - for (GCObject* uv = l->openupval; uv; uv = uv->gch.next) - { - LUAU_ASSERT(uv->gch.tt == LUA_TUPVAL); - LUAU_ASSERT(gco2uv(uv)->v != &gco2uv(uv)->u.value); - } -} - -static void validateproto(global_State* g, Proto* f) -{ - if (f->source) - validateobjref(g, obj2gco(f), obj2gco(f->source)); - - if (f->debugname) - validateobjref(g, obj2gco(f), obj2gco(f->debugname)); - - for (int i = 0; i < f->sizek; ++i) - validateref(g, obj2gco(f), &f->k[i]); - - for (int i = 0; i < f->sizeupvalues; ++i) - if (f->upvalues[i]) - validateobjref(g, obj2gco(f), obj2gco(f->upvalues[i])); - - for (int i = 0; i < f->sizep; ++i) - if (f->p[i]) - validateobjref(g, obj2gco(f), obj2gco(f->p[i])); - - for (int i = 0; i < f->sizelocvars; i++) - if (f->locvars[i].varname) - validateobjref(g, obj2gco(f), obj2gco(f->locvars[i].varname)); -} - -static void validateobj(global_State* g, GCObject* o) -{ - /* dead objects can only occur during sweep */ - if (isdead(g, o)) - { - LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); - return; - } - - switch (o->gch.tt) - { - case LUA_TSTRING: - break; - - case LUA_TTABLE: - validatetable(g, gco2h(o)); - break; - - case LUA_TFUNCTION: - validateclosure(g, gco2cl(o)); - break; - - case LUA_TUSERDATA: - if (gco2u(o)->metatable) - validateobjref(g, o, obj2gco(gco2u(o)->metatable)); - break; - - case LUA_TTHREAD: - validatestack(g, gco2th(o)); - break; - - case LUA_TPROTO: - validateproto(g, gco2p(o)); - break; - - case LUA_TUPVAL: - validateref(g, o, gco2uv(o)->v); - break; - - default: - LUAU_ASSERT(!"unexpected object type"); - } -} - -static void validatelist(global_State* g, GCObject* o) -{ - while (o) - { - validateobj(g, o); - - o = o->gch.next; - } -} - -static void validategraylist(global_State* g, GCObject* o) -{ - if (!keepinvariant(g)) - return; - - while (o) - { - LUAU_ASSERT(isgray(o)); - - switch (o->gch.tt) - { - case LUA_TTABLE: - o = gco2h(o)->gclist; - break; - case LUA_TFUNCTION: - o = gco2cl(o)->gclist; - break; - case LUA_TTHREAD: - o = gco2th(o)->gclist; - break; - case LUA_TPROTO: - o = gco2p(o)->gclist; - break; - default: - LUAU_ASSERT(!"unknown object in gray list"); - return; - } - } -} - -void luaC_validate(lua_State* L) -{ - global_State* g = L->global; - - LUAU_ASSERT(!isdead(g, obj2gco(g->mainthread))); - checkliveness(g, &g->registry); - - for (int i = 0; i < LUA_T_COUNT; ++i) - if (g->mt[i]) - LUAU_ASSERT(!isdead(g, obj2gco(g->mt[i]))); - - validategraylist(g, g->weak); - validategraylist(g, g->gray); - validategraylist(g, g->grayagain); - - for (int i = 0; i < g->strt.size; ++i) - validatelist(g, g->strt.hash[i]); - - validatelist(g, g->rootgc); - validatelist(g, g->strbufgc); - - for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) - { - LUAU_ASSERT(uv->tt == LUA_TUPVAL); - LUAU_ASSERT(uv->v != &uv->u.value); - LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); - } -} - -inline bool safejson(char ch) -{ - return unsigned(ch) < 128 && ch >= 32 && ch != '\\' && ch != '\"'; -} - -static void dumpref(FILE* f, GCObject* o) -{ - fprintf(f, "\"%p\"", o); -} - -static void dumprefs(FILE* f, TValue* data, size_t size) -{ - bool first = true; - - for (size_t i = 0; i < size; ++i) - { - if (iscollectable(&data[i])) - { - if (!first) - fputc(',', f); - first = false; - - dumpref(f, gcvalue(&data[i])); - } - } -} - -static void dumpstringdata(FILE* f, const char* data, size_t len) -{ - for (size_t i = 0; i < len; ++i) - fputc(safejson(data[i]) ? data[i] : '?', f); -} - -static void dumpstring(FILE* f, TString* ts) -{ - fprintf(f, "{\"type\":\"string\",\"cat\":%d,\"size\":%d,\"data\":\"", ts->memcat, int(sizestring(ts->len))); - dumpstringdata(f, ts->data, ts->len); - fprintf(f, "\"}"); -} - -static void dumptable(FILE* f, Table* h) -{ - size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); - - fprintf(f, "{\"type\":\"table\",\"cat\":%d,\"size\":%d", h->memcat, int(size)); - - if (h->node != &luaH_dummynode) - { - fprintf(f, ",\"pairs\":["); - - bool first = true; - - for (int i = 0; i < sizenode(h); ++i) - { - const LuaNode& n = h->node[i]; - - if (!ttisnil(&n.val) && (iscollectable(&n.key) || iscollectable(&n.val))) - { - if (!first) - fputc(',', f); - first = false; - - if (iscollectable(&n.key)) - dumpref(f, gcvalue(&n.key)); - else - fprintf(f, "null"); - - fputc(',', f); - - if (iscollectable(&n.val)) - dumpref(f, gcvalue(&n.val)); - else - fprintf(f, "null"); - } - } - - fprintf(f, "]"); - } - if (h->sizearray) - { - fprintf(f, ",\"array\":["); - dumprefs(f, h->array, h->sizearray); - fprintf(f, "]"); - } - if (h->metatable) - { - fprintf(f, ",\"metatable\":"); - dumpref(f, obj2gco(h->metatable)); - } - fprintf(f, "}"); -} - -static void dumpclosure(FILE* f, Closure* cl) -{ - fprintf(f, "{\"type\":\"function\",\"cat\":%d,\"size\":%d", cl->memcat, - cl->isC ? int(sizeCclosure(cl->nupvalues)) : int(sizeLclosure(cl->nupvalues))); - - fprintf(f, ",\"env\":"); - dumpref(f, obj2gco(cl->env)); - if (cl->isC) - { - if (cl->nupvalues) - { - fprintf(f, ",\"upvalues\":["); - dumprefs(f, cl->c.upvals, cl->nupvalues); - fprintf(f, "]"); - } - } - else - { - fprintf(f, ",\"proto\":"); - dumpref(f, obj2gco(cl->l.p)); - if (cl->nupvalues) - { - fprintf(f, ",\"upvalues\":["); - dumprefs(f, cl->l.uprefs, cl->nupvalues); - fprintf(f, "]"); - } - } - fprintf(f, "}"); -} - -static void dumpudata(FILE* f, Udata* u) -{ - fprintf(f, "{\"type\":\"userdata\",\"cat\":%d,\"size\":%d,\"tag\":%d", u->memcat, int(sizeudata(u->len)), u->tag); - - if (u->metatable) - { - fprintf(f, ",\"metatable\":"); - dumpref(f, obj2gco(u->metatable)); - } - fprintf(f, "}"); -} - -static void dumpthread(FILE* f, lua_State* th) -{ - size_t size = sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; - - fprintf(f, "{\"type\":\"thread\",\"cat\":%d,\"size\":%d", th->memcat, int(size)); - - if (iscollectable(&th->l_gt)) - { - fprintf(f, ",\"env\":"); - dumpref(f, gcvalue(&th->l_gt)); - } - - Closure* tcl = 0; - for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) - { - if (ttisfunction(ci->func)) - { - tcl = clvalue(ci->func); - break; - } - } - - if (tcl && !tcl->isC && tcl->l.p->source) - { - Proto* p = tcl->l.p; - - fprintf(f, ",\"source\":\""); - dumpstringdata(f, p->source->data, p->source->len); - fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); - } - - if (th->top > th->stack) - { - fprintf(f, ",\"stack\":["); - dumprefs(f, th->stack, th->top - th->stack); - fprintf(f, "]"); - } - fprintf(f, "}"); -} - -static void dumpproto(FILE* f, Proto* p) -{ - size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + - sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; - - fprintf(f, "{\"type\":\"proto\",\"cat\":%d,\"size\":%d", p->memcat, int(size)); - - if (p->source) - { - fprintf(f, ",\"source\":\""); - dumpstringdata(f, p->source->data, p->source->len); - fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); - } - - if (p->sizek) - { - fprintf(f, ",\"constants\":["); - dumprefs(f, p->k, p->sizek); - fprintf(f, "]"); - } - - if (p->sizep) - { - fprintf(f, ",\"protos\":["); - for (int i = 0; i < p->sizep; ++i) - { - if (i != 0) - fputc(',', f); - dumpref(f, obj2gco(p->p[i])); - } - fprintf(f, "]"); - } - - fprintf(f, "}"); -} - -static void dumpupval(FILE* f, UpVal* uv) -{ - fprintf(f, "{\"type\":\"upvalue\",\"cat\":%d,\"size\":%d", uv->memcat, int(sizeof(UpVal))); - - if (iscollectable(uv->v)) - { - fprintf(f, ",\"object\":"); - dumpref(f, gcvalue(uv->v)); - } - fprintf(f, "}"); -} - -static void dumpobj(FILE* f, GCObject* o) -{ - switch (o->gch.tt) - { - case LUA_TSTRING: - return dumpstring(f, gco2ts(o)); - - case LUA_TTABLE: - return dumptable(f, gco2h(o)); - - case LUA_TFUNCTION: - return dumpclosure(f, gco2cl(o)); - - case LUA_TUSERDATA: - return dumpudata(f, gco2u(o)); - - case LUA_TTHREAD: - return dumpthread(f, gco2th(o)); - - case LUA_TPROTO: - return dumpproto(f, gco2p(o)); - - case LUA_TUPVAL: - return dumpupval(f, gco2uv(o)); - - default: - LUAU_ASSERT(0); - } -} - -static void dumplist(FILE* f, GCObject* o) -{ - while (o) - { - dumpref(f, o); - fputc(':', f); - dumpobj(f, o); - fputc(',', f); - fputc('\n', f); - - // thread has additional list containing collectable objects that are not present in rootgc - if (o->gch.tt == LUA_TTHREAD) - dumplist(f, gco2th(o)->openupval); - - o = o->gch.next; - } -} - -void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)) -{ - global_State* g = L->global; - FILE* f = static_cast(file); - - fprintf(f, "{\"objects\":{\n"); - dumplist(f, g->rootgc); - dumplist(f, g->strbufgc); - for (int i = 0; i < g->strt.size; ++i) - dumplist(f, g->strt.hash[i]); - - fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , - fprintf(f, "},\"roots\":{\n"); - fprintf(f, "\"mainthread\":"); - dumpref(f, obj2gco(g->mainthread)); - fprintf(f, ",\"registry\":"); - dumpref(f, gcvalue(&g->registry)); - - fprintf(f, "},\"stats\":{\n"); - - fprintf(f, "\"size\":%d,\n", int(g->totalbytes)); - - fprintf(f, "\"categories\":{\n"); - for (int i = 0; i < LUA_MEMORY_CATEGORIES; i++) - { - if (size_t bytes = g->memcatbytes[i]) - { - if (categoryName) - fprintf(f, "\"%d\":{\"name\":\"%s\", \"size\":%d},\n", i, categoryName(L, i), int(bytes)); - else - fprintf(f, "\"%d\":{\"size\":%d},\n", i, int(bytes)); - } - } - fprintf(f, "\"none\":{}\n"); // to avoid issues with trailing , - fprintf(f, "}\n"); - fprintf(f, "}}\n"); -} - // measure the allocation rate in bytes/sec // returns -1 if allocation rate cannot be measured int64_t luaC_allocationrate(lua_State* L) diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp new file mode 100644 index 000000000..a79e7b953 --- /dev/null +++ b/VM/src/lgcdebug.cpp @@ -0,0 +1,558 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lgc.h" + +#include "lobject.h" +#include "lstate.h" +#include "ltable.h" +#include "lfunc.h" +#include "lstring.h" + +#include +#include + +LUAU_FASTFLAG(LuauArrayBoundary) + +static void validateobjref(global_State* g, GCObject* f, GCObject* t) +{ + LUAU_ASSERT(!isdead(g, t)); + + if (keepinvariant(g)) + { + /* basic incremental invariant: black can't point to white */ + LUAU_ASSERT(!(isblack(f) && iswhite(t))); + } +} + +static void validateref(global_State* g, GCObject* f, TValue* v) +{ + if (iscollectable(v)) + { + LUAU_ASSERT(ttype(v) == gcvalue(v)->gch.tt); + validateobjref(g, f, gcvalue(v)); + } +} + +static void validatetable(global_State* g, Table* h) +{ + int sizenode = 1 << h->lsizenode; + + if (FFlag::LuauArrayBoundary) + LUAU_ASSERT(h->lastfree <= sizenode); + else + LUAU_ASSERT(h->lastfree >= 0 && h->lastfree <= sizenode); + + if (h->metatable) + validateobjref(g, obj2gco(h), obj2gco(h->metatable)); + + for (int i = 0; i < h->sizearray; ++i) + validateref(g, obj2gco(h), &h->array[i]); + + for (int i = 0; i < sizenode; ++i) + { + LuaNode* n = &h->node[i]; + + LUAU_ASSERT(ttype(gkey(n)) != LUA_TDEADKEY || ttisnil(gval(n))); + LUAU_ASSERT(i + gnext(n) >= 0 && i + gnext(n) < sizenode); + + if (!ttisnil(gval(n))) + { + TValue k = {}; + k.tt = gkey(n)->tt; + k.value = gkey(n)->value; + + validateref(g, obj2gco(h), &k); + validateref(g, obj2gco(h), gval(n)); + } + } +} + +static void validateclosure(global_State* g, Closure* cl) +{ + validateobjref(g, obj2gco(cl), obj2gco(cl->env)); + + if (cl->isC) + { + for (int i = 0; i < cl->nupvalues; ++i) + validateref(g, obj2gco(cl), &cl->c.upvals[i]); + } + else + { + LUAU_ASSERT(cl->nupvalues == cl->l.p->nups); + + validateobjref(g, obj2gco(cl), obj2gco(cl->l.p)); + + for (int i = 0; i < cl->nupvalues; ++i) + validateref(g, obj2gco(cl), &cl->l.uprefs[i]); + } +} + +static void validatestack(global_State* g, lua_State* l) +{ + validateref(g, obj2gco(l), gt(l)); + + for (CallInfo* ci = l->base_ci; ci <= l->ci; ++ci) + { + LUAU_ASSERT(l->stack <= ci->base); + LUAU_ASSERT(ci->func <= ci->base && ci->base <= ci->top); + LUAU_ASSERT(ci->top <= l->stack_last); + } + + // note: stack refs can violate gc invariant so we only check for liveness + for (StkId o = l->stack; o < l->top; ++o) + checkliveness(g, o); + + if (l->namecall) + validateobjref(g, obj2gco(l), obj2gco(l->namecall)); + + for (GCObject* uv = l->openupval; uv; uv = uv->gch.next) + { + LUAU_ASSERT(uv->gch.tt == LUA_TUPVAL); + LUAU_ASSERT(gco2uv(uv)->v != &gco2uv(uv)->u.value); + } +} + +static void validateproto(global_State* g, Proto* f) +{ + if (f->source) + validateobjref(g, obj2gco(f), obj2gco(f->source)); + + if (f->debugname) + validateobjref(g, obj2gco(f), obj2gco(f->debugname)); + + for (int i = 0; i < f->sizek; ++i) + validateref(g, obj2gco(f), &f->k[i]); + + for (int i = 0; i < f->sizeupvalues; ++i) + if (f->upvalues[i]) + validateobjref(g, obj2gco(f), obj2gco(f->upvalues[i])); + + for (int i = 0; i < f->sizep; ++i) + if (f->p[i]) + validateobjref(g, obj2gco(f), obj2gco(f->p[i])); + + for (int i = 0; i < f->sizelocvars; i++) + if (f->locvars[i].varname) + validateobjref(g, obj2gco(f), obj2gco(f->locvars[i].varname)); +} + +static void validateobj(global_State* g, GCObject* o) +{ + /* dead objects can only occur during sweep */ + if (isdead(g, o)) + { + LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + return; + } + + switch (o->gch.tt) + { + case LUA_TSTRING: + break; + + case LUA_TTABLE: + validatetable(g, gco2h(o)); + break; + + case LUA_TFUNCTION: + validateclosure(g, gco2cl(o)); + break; + + case LUA_TUSERDATA: + if (gco2u(o)->metatable) + validateobjref(g, o, obj2gco(gco2u(o)->metatable)); + break; + + case LUA_TTHREAD: + validatestack(g, gco2th(o)); + break; + + case LUA_TPROTO: + validateproto(g, gco2p(o)); + break; + + case LUA_TUPVAL: + validateref(g, o, gco2uv(o)->v); + break; + + default: + LUAU_ASSERT(!"unexpected object type"); + } +} + +static void validatelist(global_State* g, GCObject* o) +{ + while (o) + { + validateobj(g, o); + + o = o->gch.next; + } +} + +static void validategraylist(global_State* g, GCObject* o) +{ + if (!keepinvariant(g)) + return; + + while (o) + { + LUAU_ASSERT(isgray(o)); + + switch (o->gch.tt) + { + case LUA_TTABLE: + o = gco2h(o)->gclist; + break; + case LUA_TFUNCTION: + o = gco2cl(o)->gclist; + break; + case LUA_TTHREAD: + o = gco2th(o)->gclist; + break; + case LUA_TPROTO: + o = gco2p(o)->gclist; + break; + default: + LUAU_ASSERT(!"unknown object in gray list"); + return; + } + } +} + +void luaC_validate(lua_State* L) +{ + global_State* g = L->global; + + LUAU_ASSERT(!isdead(g, obj2gco(g->mainthread))); + checkliveness(g, &g->registry); + + for (int i = 0; i < LUA_T_COUNT; ++i) + if (g->mt[i]) + LUAU_ASSERT(!isdead(g, obj2gco(g->mt[i]))); + + validategraylist(g, g->weak); + validategraylist(g, g->gray); + validategraylist(g, g->grayagain); + + for (int i = 0; i < g->strt.size; ++i) + validatelist(g, g->strt.hash[i]); + + validatelist(g, g->rootgc); + validatelist(g, g->strbufgc); + + for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + { + LUAU_ASSERT(uv->tt == LUA_TUPVAL); + LUAU_ASSERT(uv->v != &uv->u.value); + LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); + } +} + +inline bool safejson(char ch) +{ + return unsigned(ch) < 128 && ch >= 32 && ch != '\\' && ch != '\"'; +} + +static void dumpref(FILE* f, GCObject* o) +{ + fprintf(f, "\"%p\"", o); +} + +static void dumprefs(FILE* f, TValue* data, size_t size) +{ + bool first = true; + + for (size_t i = 0; i < size; ++i) + { + if (iscollectable(&data[i])) + { + if (!first) + fputc(',', f); + first = false; + + dumpref(f, gcvalue(&data[i])); + } + } +} + +static void dumpstringdata(FILE* f, const char* data, size_t len) +{ + for (size_t i = 0; i < len; ++i) + fputc(safejson(data[i]) ? data[i] : '?', f); +} + +static void dumpstring(FILE* f, TString* ts) +{ + fprintf(f, "{\"type\":\"string\",\"cat\":%d,\"size\":%d,\"data\":\"", ts->memcat, int(sizestring(ts->len))); + dumpstringdata(f, ts->data, ts->len); + fprintf(f, "\"}"); +} + +static void dumptable(FILE* f, Table* h) +{ + size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); + + fprintf(f, "{\"type\":\"table\",\"cat\":%d,\"size\":%d", h->memcat, int(size)); + + if (h->node != &luaH_dummynode) + { + fprintf(f, ",\"pairs\":["); + + bool first = true; + + for (int i = 0; i < sizenode(h); ++i) + { + const LuaNode& n = h->node[i]; + + if (!ttisnil(&n.val) && (iscollectable(&n.key) || iscollectable(&n.val))) + { + if (!first) + fputc(',', f); + first = false; + + if (iscollectable(&n.key)) + dumpref(f, gcvalue(&n.key)); + else + fprintf(f, "null"); + + fputc(',', f); + + if (iscollectable(&n.val)) + dumpref(f, gcvalue(&n.val)); + else + fprintf(f, "null"); + } + } + + fprintf(f, "]"); + } + if (h->sizearray) + { + fprintf(f, ",\"array\":["); + dumprefs(f, h->array, h->sizearray); + fprintf(f, "]"); + } + if (h->metatable) + { + fprintf(f, ",\"metatable\":"); + dumpref(f, obj2gco(h->metatable)); + } + fprintf(f, "}"); +} + +static void dumpclosure(FILE* f, Closure* cl) +{ + fprintf(f, "{\"type\":\"function\",\"cat\":%d,\"size\":%d", cl->memcat, + cl->isC ? int(sizeCclosure(cl->nupvalues)) : int(sizeLclosure(cl->nupvalues))); + + fprintf(f, ",\"env\":"); + dumpref(f, obj2gco(cl->env)); + if (cl->isC) + { + if (cl->nupvalues) + { + fprintf(f, ",\"upvalues\":["); + dumprefs(f, cl->c.upvals, cl->nupvalues); + fprintf(f, "]"); + } + } + else + { + fprintf(f, ",\"proto\":"); + dumpref(f, obj2gco(cl->l.p)); + if (cl->nupvalues) + { + fprintf(f, ",\"upvalues\":["); + dumprefs(f, cl->l.uprefs, cl->nupvalues); + fprintf(f, "]"); + } + } + fprintf(f, "}"); +} + +static void dumpudata(FILE* f, Udata* u) +{ + fprintf(f, "{\"type\":\"userdata\",\"cat\":%d,\"size\":%d,\"tag\":%d", u->memcat, int(sizeudata(u->len)), u->tag); + + if (u->metatable) + { + fprintf(f, ",\"metatable\":"); + dumpref(f, obj2gco(u->metatable)); + } + fprintf(f, "}"); +} + +static void dumpthread(FILE* f, lua_State* th) +{ + size_t size = sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; + + fprintf(f, "{\"type\":\"thread\",\"cat\":%d,\"size\":%d", th->memcat, int(size)); + + if (iscollectable(&th->l_gt)) + { + fprintf(f, ",\"env\":"); + dumpref(f, gcvalue(&th->l_gt)); + } + + Closure* tcl = 0; + for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) + { + if (ttisfunction(ci->func)) + { + tcl = clvalue(ci->func); + break; + } + } + + if (tcl && !tcl->isC && tcl->l.p->source) + { + Proto* p = tcl->l.p; + + fprintf(f, ",\"source\":\""); + dumpstringdata(f, p->source->data, p->source->len); + fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); + } + + if (th->top > th->stack) + { + fprintf(f, ",\"stack\":["); + dumprefs(f, th->stack, th->top - th->stack); + fprintf(f, "]"); + } + fprintf(f, "}"); +} + +static void dumpproto(FILE* f, Proto* p) +{ + size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + + sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; + + fprintf(f, "{\"type\":\"proto\",\"cat\":%d,\"size\":%d", p->memcat, int(size)); + + if (p->source) + { + fprintf(f, ",\"source\":\""); + dumpstringdata(f, p->source->data, p->source->len); + fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); + } + + if (p->sizek) + { + fprintf(f, ",\"constants\":["); + dumprefs(f, p->k, p->sizek); + fprintf(f, "]"); + } + + if (p->sizep) + { + fprintf(f, ",\"protos\":["); + for (int i = 0; i < p->sizep; ++i) + { + if (i != 0) + fputc(',', f); + dumpref(f, obj2gco(p->p[i])); + } + fprintf(f, "]"); + } + + fprintf(f, "}"); +} + +static void dumpupval(FILE* f, UpVal* uv) +{ + fprintf(f, "{\"type\":\"upvalue\",\"cat\":%d,\"size\":%d", uv->memcat, int(sizeof(UpVal))); + + if (iscollectable(uv->v)) + { + fprintf(f, ",\"object\":"); + dumpref(f, gcvalue(uv->v)); + } + fprintf(f, "}"); +} + +static void dumpobj(FILE* f, GCObject* o) +{ + switch (o->gch.tt) + { + case LUA_TSTRING: + return dumpstring(f, gco2ts(o)); + + case LUA_TTABLE: + return dumptable(f, gco2h(o)); + + case LUA_TFUNCTION: + return dumpclosure(f, gco2cl(o)); + + case LUA_TUSERDATA: + return dumpudata(f, gco2u(o)); + + case LUA_TTHREAD: + return dumpthread(f, gco2th(o)); + + case LUA_TPROTO: + return dumpproto(f, gco2p(o)); + + case LUA_TUPVAL: + return dumpupval(f, gco2uv(o)); + + default: + LUAU_ASSERT(0); + } +} + +static void dumplist(FILE* f, GCObject* o) +{ + while (o) + { + dumpref(f, o); + fputc(':', f); + dumpobj(f, o); + fputc(',', f); + fputc('\n', f); + + // thread has additional list containing collectable objects that are not present in rootgc + if (o->gch.tt == LUA_TTHREAD) + dumplist(f, gco2th(o)->openupval); + + o = o->gch.next; + } +} + +void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)) +{ + global_State* g = L->global; + FILE* f = static_cast(file); + + fprintf(f, "{\"objects\":{\n"); + dumplist(f, g->rootgc); + dumplist(f, g->strbufgc); + for (int i = 0; i < g->strt.size; ++i) + dumplist(f, g->strt.hash[i]); + + fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , + fprintf(f, "},\"roots\":{\n"); + fprintf(f, "\"mainthread\":"); + dumpref(f, obj2gco(g->mainthread)); + fprintf(f, ",\"registry\":"); + dumpref(f, gcvalue(&g->registry)); + + fprintf(f, "},\"stats\":{\n"); + + fprintf(f, "\"size\":%d,\n", int(g->totalbytes)); + + fprintf(f, "\"categories\":{\n"); + for (int i = 0; i < LUA_MEMORY_CATEGORIES; i++) + { + if (size_t bytes = g->memcatbytes[i]) + { + if (categoryName) + fprintf(f, "\"%d\":{\"name\":\"%s\", \"size\":%d},\n", i, categoryName(L, i), int(bytes)); + else + fprintf(f, "\"%d\":{\"size\":%d},\n", i, int(bytes)); + } + } + fprintf(f, "\"none\":{}\n"); // to avoid issues with trailing , + fprintf(f, "}\n"); + fprintf(f, "}}\n"); +} diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index 4e40165ab..c93f431f1 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -17,7 +17,7 @@ static const luaL_Reg lualibs[] = { {NULL, NULL}, }; -LUALIB_API void luaL_openlibs(lua_State* L) +void luaL_openlibs(lua_State* L) { const luaL_Reg* lib = lualibs; for (; lib->func; lib++) @@ -28,7 +28,7 @@ LUALIB_API void luaL_openlibs(lua_State* L) } } -LUALIB_API void luaL_sandbox(lua_State* L) +void luaL_sandbox(lua_State* L) { // set all libraries to read-only lua_pushnil(L); @@ -44,14 +44,14 @@ LUALIB_API void luaL_sandbox(lua_State* L) lua_pushliteral(L, ""); lua_getmetatable(L, -1); lua_setreadonly(L, -1, true); - lua_pop(L, 1); + lua_pop(L, 2); // set globals to readonly and activate safeenv since the env is immutable lua_setreadonly(L, LUA_GLOBALSINDEX, true); lua_setsafeenv(L, LUA_GLOBALSINDEX, true); } -LUALIB_API void luaL_sandboxthread(lua_State* L) +void luaL_sandboxthread(lua_State* L) { // create new global table that proxies reads to original table lua_newtable(L); @@ -81,7 +81,7 @@ static void* l_alloc(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsi return realloc(ptr, nsize); } -LUALIB_API lua_State* luaL_newstate(void) +lua_State* luaL_newstate(void) { return lua_newstate(l_alloc, NULL); } diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 8e476a526..a6e7b4940 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -385,8 +385,7 @@ static int math_sign(lua_State* L) static int math_round(lua_State* L) { - double v = luaL_checknumber(L, 1); - lua_pushnumber(L, round(v)); + lua_pushnumber(L, round(luaL_checknumber(L, 1))); return 1; } @@ -429,7 +428,7 @@ static const luaL_Reg mathlib[] = { /* ** Open math library */ -LUALIB_API int luaopen_math(lua_State* L) +int luaopen_math(lua_State* L) { uint64_t seed = uintptr_t(L); seed ^= time(NULL); diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index d8b265cba..9f9d4a98f 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -33,11 +33,17 @@ #define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : ms32) #endif +#if LUA_VECTOR_SIZE == 4 +static_assert(sizeof(TValue) == ABISWITCH(24, 24, 24), "size mismatch for value"); +static_assert(sizeof(LuaNode) == ABISWITCH(48, 48, 48), "size mismatch for table entry"); +#else static_assert(sizeof(TValue) == ABISWITCH(16, 16, 16), "size mismatch for value"); +static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table entry"); +#endif + static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); static_assert(offsetof(Udata, data) == ABISWITCH(24, 16, 16), "size mismatch for userdata header"); static_assert(sizeof(Table) == ABISWITCH(56, 36, 36), "size mismatch for table header"); -static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table entry"); const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index 43f8014b4..67f832dc5 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -18,12 +18,20 @@ inline bool luai_veceq(const float* a, const float* b) { +#if LUA_VECTOR_SIZE == 4 + return a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && a[3] == b[3]; +#else return a[0] == b[0] && a[1] == b[1] && a[2] == b[2]; +#endif } inline bool luai_vecisnan(const float* a) { +#if LUA_VECTOR_SIZE == 4 + return a[0] != a[0] || a[1] != a[1] || a[2] != a[2] || a[3] != a[3]; +#else return a[0] != a[0] || a[1] != a[1] || a[2] != a[2]; +#endif } LUAU_FASTMATH_BEGIN diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index bf13e6e97..370c7b283 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -15,7 +15,7 @@ -const TValue luaO_nilobject_ = {{NULL}, LUA_TNIL}; +const TValue luaO_nilobject_ = {{NULL}, {0}, LUA_TNIL}; int luaO_log2(unsigned int x) { diff --git a/VM/src/lobject.h b/VM/src/lobject.h index c5f2e2f4f..ba040af6c 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -47,7 +47,7 @@ typedef union typedef struct lua_TValue { Value value; - int extra; + int extra[LUA_EXTRA_SIZE]; int tt; } TValue; @@ -105,15 +105,28 @@ typedef struct lua_TValue i_o->tt = LUA_TNUMBER; \ } -#define setvvalue(obj, x, y, z) \ +#if LUA_VECTOR_SIZE == 4 +#define setvvalue(obj, x, y, z, w) \ { \ TValue* i_o = (obj); \ float* i_v = i_o->value.v; \ i_v[0] = (x); \ i_v[1] = (y); \ i_v[2] = (z); \ + i_v[3] = (w); \ i_o->tt = LUA_TVECTOR; \ } +#else +#define setvvalue(obj, x, y, z, w) \ + { \ + TValue* i_o = (obj); \ + float* i_v = i_o->value.v; \ + i_v[0] = (x); \ + i_v[1] = (y); \ + i_v[2] = (z); \ + i_o->tt = LUA_TVECTOR; \ + } +#endif #define setpvalue(obj, x) \ { \ @@ -364,7 +377,7 @@ typedef struct Closure typedef struct TKey { ::Value value; - int extra; + int extra[LUA_EXTRA_SIZE]; unsigned tt : 4; int next : 28; /* for chaining */ } TKey; @@ -381,7 +394,7 @@ typedef struct LuaNode LuaNode* n_ = (node); \ const TValue* i_o = (obj); \ n_->key.value = i_o->value; \ - n_->key.extra = i_o->extra; \ + memcpy(n_->key.extra, i_o->extra, sizeof(n_->key.extra)); \ n_->key.tt = i_o->tt; \ checkliveness(L->global, i_o); \ } @@ -392,7 +405,7 @@ typedef struct LuaNode TValue* i_o = (obj); \ const LuaNode* n_ = (node); \ i_o->value = n_->key.value; \ - i_o->extra = n_->key.extra; \ + memcpy(i_o->extra, n_->key.extra, sizeof(i_o->extra)); \ i_o->tt = n_->key.tt; \ checkliveness(L->global, i_o); \ } diff --git a/VM/src/loslib.cpp b/VM/src/loslib.cpp index 8eaef60cb..b5901865c 100644 --- a/VM/src/loslib.cpp +++ b/VM/src/loslib.cpp @@ -186,7 +186,7 @@ static const luaL_Reg syslib[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_os(lua_State* L) +int luaopen_os(lua_State* L) { luaL_register(L, LUA_OSLIBNAME, syslib); return 1; diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index b576f8093..0b3054ae4 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -1657,7 +1657,7 @@ static void createmetatable(lua_State* L) /* ** Open string library */ -LUALIB_API int luaopen_string(lua_State* L) +int luaopen_string(lua_State* L) { luaL_register(L, LUA_STRLIBNAME, strlib); createmetatable(L); diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 07d22d596..0b55fceac 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -31,18 +31,19 @@ LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false) #define MAXSIZE (1 << MAXBITS) static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast in gval2slot is incorrect"); + // TKey is bitpacked for memory efficiency so we need to validate bit counts for worst case -static_assert(TKey{{NULL}, 0, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt"); -static_assert(TKey{{NULL}, 0, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); -static_assert(TKey{{NULL}, 0, LUA_TNIL, -(MAXSIZE - 1)}.next == -(MAXSIZE - 1), "not enough bits for next"); +static_assert(TKey{{NULL}, {0}, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt"); +static_assert(TKey{{NULL}, {0}, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); +static_assert(TKey{{NULL}, {0}, LUA_TNIL, -(MAXSIZE - 1)}.next == -(MAXSIZE - 1), "not enough bits for next"); // reset cache of absent metamethods, cache is updated in luaT_gettm #define invalidateTMcache(t) t->flags = 0 // empty hash data points to dummynode so that we can always dereference it const LuaNode luaH_dummynode = { - {{NULL}, 0, LUA_TNIL}, /* value */ - {{NULL}, 0, LUA_TNIL, 0} /* key */ + {{NULL}, {0}, LUA_TNIL}, /* value */ + {{NULL}, {0}, LUA_TNIL, 0} /* key */ }; #define dummynode (&luaH_dummynode) @@ -96,7 +97,7 @@ static LuaNode* hashnum(const Table* t, double n) static LuaNode* hashvec(const Table* t, const float* v) { - unsigned int i[3]; + unsigned int i[LUA_VECTOR_SIZE]; memcpy(i, v, sizeof(i)); // convert -0 to 0 to make sure they hash to the same value @@ -112,6 +113,12 @@ static LuaNode* hashvec(const Table* t, const float* v) // Optimized Spatial Hashing for Collision Detection of Deformable Objects unsigned int h = (i[0] * 73856093) ^ (i[1] * 19349663) ^ (i[2] * 83492791); +#if LUA_VECTOR_SIZE == 4 + i[3] = (i[3] == 0x8000000) ? 0 : i[3]; + i[3] ^= i[3] >> 17; + h ^= i[3] * 39916801; +#endif + return hashpow2(t, h); } diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 370258189..0d3374efa 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -527,7 +527,7 @@ static const luaL_Reg tab_funcs[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_table(lua_State* L) +int luaopen_table(lua_State* L) { luaL_register(L, LUA_TABLIBNAME, tab_funcs); diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp index 378de3d0d..8bc8200a5 100644 --- a/VM/src/lutf8lib.cpp +++ b/VM/src/lutf8lib.cpp @@ -283,7 +283,7 @@ static const luaL_Reg funcs[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_utf8(lua_State* L) +int luaopen_utf8(lua_State* L) { luaL_register(L, LUA_UTF8LIBNAME, funcs); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index eed2862b1..bf8d493eb 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -601,7 +601,13 @@ static void luau_execute(lua_State* L) const char* name = getstr(tsvalue(kv)); int ic = (name[0] | ' ') - 'x'; - if (unsigned(ic) < 3 && name[1] == '\0') +#if LUA_VECTOR_SIZE == 4 + // 'w' is before 'x' in ascii, so ic is -1 when indexing with 'w' + if (ic == -1) + ic = 3; +#endif + + if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0') { setnvalue(ra, rb->value.v[ic]); VM_NEXT(); @@ -1526,7 +1532,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2]); + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); VM_NEXT(); } else @@ -1572,7 +1578,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2]); + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); VM_NEXT(); } else @@ -1618,21 +1624,21 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(rc)); - setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc); + setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc, vb[3] * vc); VM_NEXT(); } else if (ttisvector(rb) && ttisvector(rc)) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2]); + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); VM_NEXT(); } else if (ttisnumber(rb) && ttisvector(rc)) { float vb = cast_to(float, nvalue(rb)); const float* vc = rc->value.v; - setvvalue(ra, vb * vc[0], vb * vc[1], vb * vc[2]); + setvvalue(ra, vb * vc[0], vb * vc[1], vb * vc[2], vb * vc[3]); VM_NEXT(); } else @@ -1679,21 +1685,21 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(rc)); - setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc); + setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc, vb[3] / vc); VM_NEXT(); } else if (ttisvector(rb) && ttisvector(rc)) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2]); + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); VM_NEXT(); } else if (ttisnumber(rb) && ttisvector(rc)) { float vb = cast_to(float, nvalue(rb)); const float* vc = rc->value.v; - setvvalue(ra, vb / vc[0], vb / vc[1], vb / vc[2]); + setvvalue(ra, vb / vc[0], vb / vc[1], vb / vc[2], vb / vc[3]); VM_NEXT(); } else @@ -1826,7 +1832,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(kv)); - setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc); + setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc, vb[3] * vc); VM_NEXT(); } else @@ -1872,7 +1878,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(kv)); - setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc); + setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc, vb[3] / vc); VM_NEXT(); } else @@ -2037,7 +2043,7 @@ static void luau_execute(lua_State* L) else if (ttisvector(rb)) { const float* vb = rb->value.v; - setvvalue(ra, -vb[0], -vb[1], -vb[2]); + setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); VM_NEXT(); } else diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index a168b6522..add3588d4 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -9,6 +9,7 @@ #include "lgc.h" #include "lmem.h" #include "lbytecode.h" +#include "lapi.h" #include @@ -162,9 +163,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size size_t GCthreshold = L->global->GCthreshold; L->global->GCthreshold = SIZE_MAX; - // env is 0 for current environment and a stack relative index otherwise - LUAU_ASSERT(env <= 0 && L->top - L->base >= -env); - Table* envt = (env == 0) ? hvalue(gt(L)) : hvalue(L->top + env); + // env is 0 for current environment and a stack index otherwise + Table* envt = (env == 0) ? hvalue(gt(L)) : hvalue(luaA_toobject(L, env)); TString* source = luaS_new(L, chunkname); diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index f52e8e745..740a4cfd2 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -401,19 +401,19 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM switch (op) { case TM_ADD: - setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2]); + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); return; case TM_SUB: - setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2]); + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); return; case TM_MUL: - setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2]); + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); return; case TM_DIV: - setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2]); + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); return; case TM_UNM: - setvvalue(ra, -vb[0], -vb[1], -vb[2]); + setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); return; default: break; @@ -430,10 +430,10 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM switch (op) { case TM_MUL: - setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc); + setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc, vb[3] * nc); return; case TM_DIV: - setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc); + setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc, vb[3] / nc); return; default: break; @@ -451,10 +451,10 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM switch (op) { case TM_MUL: - setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2]); + setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2], nb * vc[3]); return; case TM_DIV: - setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2]); + setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2], nb / vc[3]); return; default: break; diff --git a/bench/gc/test_LB_mandel.lua b/bench/gc/test_LB_mandel.lua index 4be785022..fe5b4eb2f 100644 --- a/bench/gc/test_LB_mandel.lua +++ b/bench/gc/test_LB_mandel.lua @@ -88,7 +88,7 @@ for i=1,N do local y=ymin+(j-1)*dy S = S + level(x,y) end - -- if i % 10 == 0 then print(collectgarbage"count") end + -- if i % 10 == 0 then print(collectgarbage("count")) end end print(S) diff --git a/bench/tests/shootout/mandel.lua b/bench/tests/shootout/mandel.lua index 4be785022..fe5b4eb2f 100644 --- a/bench/tests/shootout/mandel.lua +++ b/bench/tests/shootout/mandel.lua @@ -88,7 +88,7 @@ for i=1,N do local y=ymin+(j-1)*dy S = S + level(x,y) end - -- if i % 10 == 0 then print(collectgarbage"count") end + -- if i % 10 == 0 then print(collectgarbage("count")) end end print(S) diff --git a/bench/tests/shootout/qt.lua b/bench/tests/shootout/qt.lua index de962a74f..79cbe38ba 100644 --- a/bench/tests/shootout/qt.lua +++ b/bench/tests/shootout/qt.lua @@ -275,7 +275,7 @@ local function memory(s) local t=os.clock() --local dt=string.format("%f",t-t0) local dt=t-t0 - --io.stdout:write(s,"\t",dt," sec\t",t," sec\t",math.floor(collectgarbage"count"/1024),"M\n") + --io.stdout:write(s,"\t",dt," sec\t",t," sec\t",math.floor(collectgarbage("count")/1024),"M\n") t0=t end @@ -286,7 +286,7 @@ local function do_(f,s) end local function julia(l,a,b) -memory"begin" +memory("begin") cx=a cy=b root=newcell() exterior=newcell() exterior.color=white @@ -297,14 +297,14 @@ memory"begin" do_(update,"update") repeat N=0 color(root,Rxmin,Rxmax,Rymin,Rymax) --print("color",N) - until N==0 memory"color" + until N==0 memory("color") repeat N=0 prewhite(root,Rxmin,Rxmax,Rymin,Rymax) --print("prewhite",N) - until N==0 memory"prewhite" + until N==0 memory("prewhite") do_(recolor,"recolor") do_(colorup,"colorup") --print("colorup",N) local g,b=do_(area,"area") --print("area",g,b,g+b) - show(i) memory"output" + show(i) memory("output") --print("edges",nE) end end diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index ae2399e49..27e534927 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -23,9 +23,11 @@ const bool kFuzzCompiler = true; const bool kFuzzLinter = true; const bool kFuzzTypeck = true; const bool kFuzzVM = true; -const bool kFuzzTypes = true; const bool kFuzzTranspile = true; +// Should we generate type annotations? +const bool kFuzzTypes = true; + static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); std::string protoprint(const luau::StatBlock& stat, bool types); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index aa53a92b2..2090b0148 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -78,3 +78,26 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("AstQuery"); + +TEST_CASE_FIXTURE(Fixture, "last_argument_function_call_type") +{ + ScopedFastFlag luauTailArgumentTypeInfo{"LuauTailArgumentTypeInfo", true}; + + check(R"( +local function foo() return 2 end +local function bar(a: number) return -a end +bar(foo()) + )"); + + auto oty = findTypeAtPosition(Position(3, 7)); + REQUIRE(oty); + CHECK_EQ("number", toString(*oty)); + + auto expectedOty = findExpectedTypeAtPosition(Position(3, 7)); + REQUIRE(expectedOty); + CHECK_EQ("number", toString(*expectedOty)); +} + +TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 5a7c86023..3b74a99e4 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1935,6 +1935,39 @@ return target(b@1 CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::None); } +TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); + + check(R"( +local function bar(a: number) return -a end +local abc = b@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("bar")); + CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside); +} + +TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); + + check(R"( +local function foo() return 1 end +local function bar(a: number) return -a end +local abc = bar(@1) + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("foo")); + CHECK(ac.entryMap["foo"].parens == ParenthesesRecommendation::CursorAfter); +} + TEST_CASE_FIXTURE(ACFixture, "type_correct_sealed_table") { check(R"( @@ -2210,8 +2243,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_require") { - ScopedFastFlag luauResolveModuleNameWithoutACurrentModule("LuauResolveModuleNameWithoutACurrentModule", true); - std::string_view source = R"( local a = require(w -- Line 1 -- | Column 27 @@ -2287,8 +2318,6 @@ until TEST_CASE_FIXTURE(ACFixture, "if_then_else_elseif_completions") { - ScopedFastFlag sff{"ElseElseIfCompletionImprovements", true}; - check(R"( local elsewhere = false @@ -2585,9 +2614,6 @@ a = if temp then even elseif true then temp else e@9 TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - check(R"( type A = () -> T... local a: A<(number, s@1> diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 4ce8d08ae..6ba39adab 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -1057,6 +1057,18 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("return if false then 10 else 20"), R"( LOADN R0 20 RETURN R0 1 +)"); + + // codegen for a true constant condition with non-constant expressions + CHECK_EQ("\n" + compileFunction0("return if true then {} else error()"), R"( +NEWTABLE R0 0 0 +RETURN R0 1 +)"); + + // codegen for a false constant condition with non-constant expressions + CHECK_EQ("\n" + compileFunction0("return if false then error() else {}"), R"( +NEWTABLE R0 0 0 +RETURN R0 1 )"); // codegen for a false (in this case 'nil') constant condition @@ -2360,6 +2372,58 @@ Foo:Bar( )"); } +TEST_CASE("DebugLineInfoCallChain") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local Foo = ... + +Foo +:Bar(1) +:Baz(2) +.Qux(3) +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: GETVARARGS R0 1 +5: LOADN R4 1 +5: NAMECALL R2 R0 K0 +5: CALL R2 2 1 +6: LOADN R4 2 +6: NAMECALL R2 R2 K1 +6: CALL R2 2 1 +7: GETTABLEKS R1 R2 K2 +7: LOADN R2 3 +7: CALL R1 1 0 +8: RETURN R0 0 +)"); +} + +TEST_CASE("DebugLineInfoFastCall") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local Foo, Bar = ... + +return + math.max( + Foo, + Bar) +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: GETVARARGS R0 2 +5: FASTCALL2 18 R0 R1 +5 +5: MOVE R3 R0 +5: MOVE R4 R1 +5: GETIMPORT R2 2 +5: CALL R2 2 -1 +5: RETURN R2 -1 +)"); +} + TEST_CASE("DebugSource") { const char* source = R"( @@ -3742,4 +3806,108 @@ RETURN R0 0 )"); } +TEST_CASE("ConstantsNoFolding") +{ + const char* source = "return nil, true, 42, 'hello'"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.optimizationLevel = 0; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +LOADNIL R0 +LOADB R1 1 +LOADK R2 K0 +LOADK R3 K1 +RETURN R0 4 +)"); +} + +TEST_CASE("VectorFastCall") +{ + const char* source = "return Vector3.new(1, 2, 3)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.vectorLib = "Vector3"; + options.vectorCtor = "new"; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +LOADN R1 1 +LOADN R2 2 +LOADN R3 3 +FASTCALL 54 +2 +GETIMPORT R0 2 +CALL R0 3 -1 +RETURN R0 -1 +)"); +} + +TEST_CASE("TypeAssertion") +{ + // validate that type assertions work with the compiler and that the code inside type assertion isn't evaluated + CHECK_EQ("\n" + compileFunction0(R"( +print(foo() :: typeof(error("compile time"))) +)"), + R"( +GETIMPORT R0 1 +GETIMPORT R1 3 +CALL R1 0 1 +CALL R0 1 0 +RETURN R0 0 +)"); + + // note that above, foo() is treated as single-arg function; removing type assertion changes the bytecode + CHECK_EQ("\n" + compileFunction0(R"( +print(foo()) +)"), + R"( +GETIMPORT R0 1 +GETIMPORT R1 3 +CALL R1 0 -1 +CALL R0 -1 0 +RETURN R0 0 +)"); +} + +TEST_CASE("Arithmetics") +{ + // basic arithmetics codegen with non-constants + CHECK_EQ("\n" + compileFunction0(R"( +local a, b = ... +return a + b, a - b, a / b, a * b, a % b, a ^ b +)"), + R"( +GETVARARGS R0 2 +ADD R2 R0 R1 +SUB R3 R0 R1 +DIV R4 R0 R1 +MUL R5 R0 R1 +MOD R6 R0 R1 +POW R7 R0 R1 +RETURN R2 6 +)"); + + // basic arithmetics codegen with constants on the right side + // note that we don't simplify these expressions as we don't know the type of a + CHECK_EQ("\n" + compileFunction0(R"( +local a = ... +return a + 1, a - 1, a / 1, a * 1, a % 1, a ^ 1 +)"), + R"( +GETVARARGS R0 1 +ADDK R1 R0 K0 +SUBK R2 R0 K0 +DIVK R3 R0 K0 +MULK R4 R0 K0 +MODK R5 R0 K0 +POWK R6 R0 K0 +RETURN R1 6 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index e495a2136..b2aad3163 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -67,44 +67,42 @@ static int lua_vector(lua_State* L) double y = luaL_checknumber(L, 2); double z = luaL_checknumber(L, 3); +#if LUA_VECTOR_SIZE == 4 + double w = luaL_optnumber(L, 4, 0.0); + lua_pushvector(L, float(x), float(y), float(z), float(w)); +#else lua_pushvector(L, float(x), float(y), float(z)); +#endif return 1; } static int lua_vector_dot(lua_State* L) { - const float* a = lua_tovector(L, 1); - const float* b = lua_tovector(L, 2); - - if (a && b) - { - lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); - return 1; - } + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); - throw std::runtime_error("invalid arguments to vector:Dot"); + lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); + return 1; } static int lua_vector_index(lua_State* L) { + const float* v = luaL_checkvector(L, 1); const char* name = luaL_checkstring(L, 2); - if (const float* v = lua_tovector(L, 1)) + if (strcmp(name, "Magnitude") == 0) { - if (strcmp(name, "Magnitude") == 0) - { - lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); - return 1; - } + lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); + return 1; + } - if (strcmp(name, "Dot") == 0) - { - lua_pushcfunction(L, lua_vector_dot, "Dot"); - return 1; - } + if (strcmp(name, "Dot") == 0) + { + lua_pushcfunction(L, lua_vector_dot, "Dot"); + return 1; } - throw std::runtime_error(Luau::format("%s is not a valid member of vector", name)); + luaL_error(L, "%s is not a valid member of vector", name); } static int lua_vector_namecall(lua_State* L) @@ -115,7 +113,7 @@ static int lua_vector_namecall(lua_State* L) return lua_vector_dot(L); } - throw std::runtime_error(Luau::format("%s is not a valid method of vector", luaL_checkstring(L, 1))); + luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); } int lua_silence(lua_State* L) @@ -373,11 +371,17 @@ TEST_CASE("Pack") TEST_CASE("Vector") { + ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + runConformance("vector.lua", [](lua_State* L) { lua_pushcfunction(L, lua_vector, "vector"); lua_setglobal(L, "vector"); +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); +#else lua_pushvector(L, 0.0f, 0.0f, 0.0f); +#endif luaL_newmetatable(L, "vector"); lua_pushstring(L, "__index"); @@ -504,6 +508,9 @@ TEST_CASE("Debugger") cb->debugbreak = [](lua_State* L, lua_Debug* ar) { breakhits++; + // make sure we can trace the stack for every breakpoint we hit + lua_debugtrace(L); + // for every breakpoint, we break on the first invocation and continue on second // this allows us to easily step off breakpoints // (real implementaiton may require singlestepping) @@ -524,7 +531,7 @@ TEST_CASE("Debugger") L, [](lua_State* L) -> int { int line = luaL_checkinteger(L, 1); - bool enabled = lua_isboolean(L, 2) ? lua_toboolean(L, 2) : true; + bool enabled = luaL_optboolean(L, 2, true); lua_Debug ar = {}; lua_getinfo(L, 1, "f", &ar); @@ -699,21 +706,52 @@ TEST_CASE("ApiFunctionCalls") StateRef globalState = runConformance("apicalls.lua"); lua_State* L = globalState.get(); - lua_getfield(L, LUA_GLOBALSINDEX, "add"); - lua_pushnumber(L, 40); - lua_pushnumber(L, 2); - lua_call(L, 2, 1); - CHECK(lua_isnumber(L, -1)); - CHECK(lua_tonumber(L, -1) == 42); - lua_pop(L, 1); + // lua_call + { + lua_getfield(L, LUA_GLOBALSINDEX, "add"); + lua_pushnumber(L, 40); + lua_pushnumber(L, 2); + lua_call(L, 2, 1); + CHECK(lua_isnumber(L, -1)); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + } - lua_getfield(L, LUA_GLOBALSINDEX, "add"); - lua_pushnumber(L, 40); - lua_pushnumber(L, 2); - lua_pcall(L, 2, 1, 0); - CHECK(lua_isnumber(L, -1)); - CHECK(lua_tonumber(L, -1) == 42); - lua_pop(L, 1); + // lua_pcall + { + lua_getfield(L, LUA_GLOBALSINDEX, "add"); + lua_pushnumber(L, 40); + lua_pushnumber(L, 2); + lua_pcall(L, 2, 1, 0); + CHECK(lua_isnumber(L, -1)); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + } + + // lua_equal with a sleeping thread wake up + { + ScopedFastFlag luauActivateBeforeExec("LuauActivateBeforeExec", true); + + lua_State* L2 = lua_newthread(L); + + lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); + lua_pushnumber(L2, 42); + lua_pcall(L2, 1, 1, 0); + + lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); + lua_pushnumber(L2, 42); + lua_pcall(L2, 1, 1, 0); + + // Reset GC + lua_gc(L2, LUA_GCCOLLECT, 0); + + // Try to mark 'L2' as sleeping + // Can't control GC precisely, even in tests + lua_gc(L2, LUA_GCSTEP, 8); + + CHECK(lua_equal(L2, -1, -2) == 1); + lua_pop(L2, 2); + } } static bool endsWith(const std::string& str, const std::string& suffix) @@ -727,8 +765,6 @@ static bool endsWith(const std::string& str, const std::string& suffix) #if !LUA_USE_LONGJMP TEST_CASE("ExceptionObject") { - ScopedFastFlag sff("LuauExceptionMessageFix", true); - struct ExceptionResult { bool exceptionGenerated; diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 29c33f7c1..36d6f5612 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -19,19 +19,6 @@ static const char* mainModuleName = "MainModule"; namespace Luau { -std::optional TestFileResolver::fromAstFragment(AstExpr* expr) const -{ - auto g = expr->as(); - if (!g) - return std::nullopt; - - std::string_view value = g->name.value; - if (value == "game" || value == "Game" || value == "workspace" || value == "Workspace" || value == "script" || value == "Script") - return ModuleName(value); - - return std::nullopt; -} - std::optional TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr) { if (AstExprGlobal* g = expr->as()) @@ -81,24 +68,6 @@ std::optional TestFileResolver::resolveModule(const ModuleInfo* cont return std::nullopt; } -ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const -{ - return lhs + "/" + ModuleName(rhs); -} - -std::optional TestFileResolver::getParentModuleName(const ModuleName& name) const -{ - std::string_view view = name; - const size_t lastSeparatorIndex = view.find_last_of('/'); - - if (lastSeparatorIndex != std::string_view::npos) - { - return ModuleName(view.substr(0, lastSeparatorIndex)); - } - - return std::nullopt; -} - std::string TestFileResolver::getHumanReadableModuleName(const ModuleName& name) const { return name; @@ -324,6 +293,13 @@ std::optional Fixture::findTypeAtPosition(Position position) return Luau::findTypeAtPosition(*module, *sourceModule, position); } +std::optional Fixture::findExpectedTypeAtPosition(Position position) +{ + ModulePtr module = getMainModule(); + SourceModule* sourceModule = getMainSourceModule(); + return Luau::findExpectedTypeAtPosition(*module, *sourceModule, position); +} + TypeId Fixture::requireTypeAtPosition(Position position) { auto ty = findTypeAtPosition(position); diff --git a/tests/Fixture.h b/tests/Fixture.h index 1480a7f6a..de2b7381e 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -64,12 +64,8 @@ struct TestFileResolver return SourceCode{it->second, sourceType}; } - std::optional fromAstFragment(AstExpr* expr) const override; std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override; - std::optional getParentModuleName(const ModuleName& name) const override; - std::string getHumanReadableModuleName(const ModuleName& name) const override; std::optional getEnvironmentForModule(const ModuleName& name) const override; @@ -126,6 +122,7 @@ struct Fixture std::optional findTypeAtPosition(Position position); TypeId requireTypeAtPosition(Position position); + std::optional findExpectedTypeAtPosition(Position position); std::optional lookupType(const std::string& name); std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index fbfec6367..51fcd3d6c 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -46,18 +46,6 @@ NaiveModuleResolver naiveModuleResolver; struct NaiveFileResolver : NullFileResolver { - std::optional fromAstFragment(AstExpr* expr) const override - { - AstExprGlobal* g = expr->as(); - if (g && g->name == "Modules") - return "Modules"; - - if (g && g->name == "game") - return "game"; - - return std::nullopt; - } - std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override { if (AstExprGlobal* g = expr->as()) @@ -86,11 +74,6 @@ struct NaiveFileResolver : NullFileResolver return std::nullopt; } - - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override - { - return lhs + "/" + ModuleName(rhs); - } }; } // namespace diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 7ba40c503..1d13df289 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1469,6 +1469,22 @@ _ = true and true or false -- no warning since this is is a common pattern used CHECK_EQ(result.warnings[6].location.begin.line + 1, 19); } +TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsExpr") +{ + LintResult result = lint(R"( +local correct, opaque = ... + +if correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls")}) then +elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls")}) then +elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls", false)}) then +end +)"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 4"); + CHECK_EQ(result.warnings[0].location.begin.line + 1, 5); +} + TEST_CASE_FIXTURE(Fixture, "DuplicateLocal") { LintResult result = lint(R"( diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 7a3543c7c..2800d2fe6 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -44,9 +44,10 @@ TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks); + TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks, cloneState); CHECK_EQ(newNumber, typeChecker.numberType); } @@ -56,12 +57,13 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; // Create a new number type that isn't persistent unfreeze(typeChecker.globalTypes); TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number}); freeze(typeChecker.globalTypes); - TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks); + TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks, cloneState); CHECK_NE(newNumber, oldNumber); CHECK_EQ(*oldNumber, *newNumber); @@ -89,9 +91,10 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; TypeArena dest; - TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks); + TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks, cloneState); TableTypeVar* ttv = getMutable(counterCopy); REQUIRE(ttv != nullptr); @@ -142,11 +145,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks); + TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks, cloneState); CHECK_NE(newUnion, oldUnion); CHECK_EQ("number | string", toString(newUnion)); @@ -159,11 +163,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks); + TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks, cloneState); CHECK_NE(newIntersection, oldIntersection); CHECK_EQ("number & string", toString(newIntersection)); @@ -188,8 +193,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; - TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks); + TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks, cloneState); const ClassTypeVar* ctv = get(cloned); REQUIRE(ctv != nullptr); @@ -211,16 +217,16 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") TypeArena dest; SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; - bool encounteredFreeType = false; - TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); + TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, cloneState); CHECK_EQ("any", toString(clonedTy)); - CHECK(encounteredFreeType); + CHECK(cloneState.encounteredFreeType); - encounteredFreeType = false; - TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); + cloneState = {}; + TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, cloneState); CHECK_EQ("...any", toString(clonedTp)); - CHECK(encounteredFreeType); + CHECK(cloneState.encounteredFreeType); } TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") @@ -232,12 +238,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") TypeArena dest; SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; - bool encounteredFreeType = false; - TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); + TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, cloneState); const TableTypeVar* clonedTtv = get(cloned); CHECK_EQ(clonedTtv->state, TableState::Sealed); - CHECK(encounteredFreeType); + CHECK(cloneState.encounteredFreeType); } TEST_CASE_FIXTURE(Fixture, "clone_self_property") @@ -267,4 +273,34 @@ TEST_CASE_FIXTURE(Fixture, "clone_self_property") "dot or pass 1 extra nil to suppress this warning"); } +TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") +{ +#if defined(_DEBUG) || defined(_NOOPT) + int limit = 250; +#else + int limit = 500; +#endif + ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; + + TypeArena src; + + TypeId table = src.addType(TableTypeVar{}); + TypeId nested = table; + + for (unsigned i = 0; i < limit + 100; i++) + { + TableTypeVar* ttv = getMutable(nested); + + ttv->props["a"].type = src.addType(TableTypeVar{}); + nested = ttv->props["a"].type; + } + + TypeArena dest; + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + CloneState cloneState; + + CHECK_THROWS_AS(clone(table, dest, seenTypes, seenTypePacks, cloneState), std::runtime_error); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index e3e6ce6d8..72d3a9a64 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2518,8 +2518,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") { - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - AstStat* stat = parse(R"( type Packed = () -> T... diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp new file mode 100644 index 000000000..0ca9c9949 --- /dev/null +++ b/tests/ToDot.test.cpp @@ -0,0 +1,366 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" +#include "Luau/ToDot.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +struct ToDotClassFixture : Fixture +{ + ToDotClassFixture() + { + TypeArena& arena = typeChecker.globalTypes; + + unfreeze(arena); + + TypeId baseClassMetaType = arena.addType(TableTypeVar{}); + + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}}); + getMutable(baseClassInstanceType)->props = { + {"BaseField", {typeChecker.numberType}}, + }; + typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}}); + getMutable(childClassInstanceType)->props = { + {"ChildField", {typeChecker.stringType}}, + }; + typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + + freeze(arena); + } +}; + +TEST_SUITE_BEGIN("ToDot"); + +TEST_CASE_FIXTURE(Fixture, "primitive") +{ + CheckResult result = check(R"( +local a: nil +local b: number +local c: any +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_NE("nil", toDot(requireType("a"))); + + CHECK_EQ(R"(digraph graphname { +n1 [label="number"]; +})", + toDot(requireType("b"))); + + CHECK_EQ(R"(digraph graphname { +n1 [label="any"]; +})", + toDot(requireType("c"))); + + ToDotOptions opts; + opts.showPointers = false; + opts.duplicatePrimitives = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="PrimitiveTypeVar number"]; +})", + toDot(requireType("b"), opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="AnyTypeVar 1"]; +})", + toDot(requireType("c"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound") +{ + CheckResult result = check(R"( +local a = 444 +local b = a +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = getType("b"); + REQUIRE(bool(ty)); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="BoundTypeVar 1"]; +n1 -> n2; +n2 [label="number"]; +})", + toDot(*ty, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "function") +{ + ScopedFastFlag luauQuantifyInPlace2{"LuauQuantifyInPlace2", true}; + + CheckResult result = check(R"( +local function f(a, ...: string) return a end +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FunctionTypeVar 1"]; +n1 -> n2 [label="arg"]; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="GenericTypeVar 3"]; +n2 -> n4 [label="tail"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n1 -> n6 [label="ret"]; +n6 [label="BoundTypePack 6"]; +n6 -> n7; +n7 [label="TypePack 7"]; +n7 -> n3; +})", + toDot(requireType("f"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "union") +{ + CheckResult result = check(R"( +local a: string | number +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="UnionTypeVar 1"]; +n1 -> n2; +n2 [label="string"]; +n1 -> n3; +n3 [label="number"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "intersection") +{ + CheckResult result = check(R"( +local a: string & number -- uninhabited +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="IntersectionTypeVar 1"]; +n1 -> n2; +n2 [label="string"]; +n1 -> n3; +n3 [label="number"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "table") +{ + CheckResult result = check(R"( +type A = { x: T, y: (U...) -> (), [string]: any } +local a: A +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="TableTypeVar A"]; +n1 -> n2 [label="x"]; +n2 [label="number"]; +n1 -> n3 [label="y"]; +n3 [label="FunctionTypeVar 3"]; +n3 -> n4 [label="arg"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n3 -> n6 [label="ret"]; +n6 [label="TypePack 6"]; +n1 -> n7 [label="[index]"]; +n7 [label="string"]; +n1 -> n8 [label="[value]"]; +n8 [label="any"]; +n1 -> n9 [label="typeParam"]; +n9 [label="number"]; +n1 -> n4 [label="typePackParam"]; +})", + toDot(requireType("a"), opts)); + + // Extra coverage with pointers (unstable values) + (void)toDot(requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "metatable") +{ + CheckResult result = check(R"( +local a: typeof(setmetatable({}, {})) +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="MetatableTypeVar 1"]; +n1 -> n2 [label="table"]; +n2 [label="TableTypeVar 2"]; +n1 -> n3 [label="metatable"]; +n3 [label="TableTypeVar 3"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "free") +{ + TypeVar type{TypeVariant{FreeTypeVar{TypeLevel{0, 0}}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FreeTypeVar 1"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "error") +{ + TypeVar type{TypeVariant{ErrorTypeVar{}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ErrorTypeVar 1"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "generic") +{ + TypeVar type{TypeVariant{GenericTypeVar{"T"}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypeVar T"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(ToDotClassFixture, "class") +{ + CheckResult result = check(R"( +local a: ChildClass +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ClassTypeVar ChildClass"]; +n1 -> n2 [label="ChildField"]; +n2 [label="string"]; +n1 -> n3 [label="[parent]"]; +n3 [label="ClassTypeVar BaseClass"]; +n3 -> n4 [label="BaseField"]; +n4 [label="number"]; +n3 -> n5 [label="[metatable]"]; +n5 [label="TableTypeVar 5"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "free_pack") +{ + TypePackVar pack{TypePackVariant{FreeTypePack{TypeLevel{0, 0}}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FreeTypePack 1"]; +})", + toDot(&pack, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "error_pack") +{ + TypePackVar pack{TypePackVariant{Unifiable::Error{}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ErrorTypePack 1"]; +})", + toDot(&pack, opts)); + + // Extra coverage with pointers (unstable values) + (void)toDot(&pack); +} + +TEST_CASE_FIXTURE(Fixture, "generic_pack") +{ + TypePackVar pack1{TypePackVariant{GenericTypePack{}}}; + TypePackVar pack2{TypePackVariant{GenericTypePack{"T"}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypePack 1"]; +})", + toDot(&pack1, opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypePack T"]; +})", + toDot(&pack2, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound_pack") +{ + TypePackVar pack{TypePackVariant{TypePack{{typeChecker.numberType}, {}}}}; + TypePackVar bound{TypePackVariant{BoundTypePack{&pack}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="BoundTypePack 1"]; +n1 -> n2; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="number"]; +})", + toDot(&bound, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound_table") +{ + CheckResult result = check(R"( +local a = {x=2} +local b +b.x = 2 +b = a +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = getType("b"); + REQUIRE(bool(ty)); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="TableTypeVar 1"]; +n1 -> n2 [label="boundTo"]; +n2 [label="TableTypeVar a"]; +n2 -> n3 [label="x"]; +n3 [label="number"]; +})", + toDot(*ty, opts)); +} + +TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 928c03a31..327fa0bbd 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -445,9 +445,6 @@ local a: Import.Type TEST_CASE_FIXTURE(Fixture, "transpile_type_packs") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - std::string code = R"( type Packed = (T...)->(T...) local a: Packed<> diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 74ce155c2..822bd727e 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -537,8 +537,6 @@ TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") { - ScopedFastFlag sff1{"LuauSubstitutionDontReplaceIgnoredTypes", true}; - CheckResult result = check(R"( type Array = { [number]: T } type Tuple = Array diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 88c2dc85d..aba508918 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -609,8 +609,6 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") { - ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true}; - CheckResult result = check(R"( local exports = {} local nested = {} @@ -627,4 +625,23 @@ return exports LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names") +{ + ScopedFastFlag luauFunctionArgumentNameSize{"LuauFunctionArgumentNameSize", true}; + + CheckResult result = check(R"( +local function f(a: T, ...: U...) end + +f(1, 2, 3) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto ty = findTypeAtPosition(Position(3, 0)); + REQUIRE(ty); + ToStringOptions opts; + opts.functionTypeArguments = true; + CHECK_EQ(toString(*ty, opts), "(a: number, number, number) -> ()"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index e5c14dde8..e6d3d4d47 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -31,8 +31,6 @@ TEST_SUITE_BEGIN("ProvisionalTests"); */ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - const std::string code = R"( function f(a) if type(a) == "boolean" then diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index c3694be79..cb72faaf4 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2022,4 +2022,74 @@ caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); } +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") +{ + ScopedFastFlag sffs[] { + {"LuauPropertiesGetExpectedType", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauTableSubtypingVariance", true}, + }; + + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { p : Super } +type HasSub = { p : Sub } +local a: HasSuper = { p = { x = 5, y = 7 }} +a.p = { x = 9 } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") +{ + ScopedFastFlag sffs[] { + {"LuauPropertiesGetExpectedType", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauTableSubtypingVariance", true}, + {"LuauExtendedTypeMismatchError", true}, + }; + + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { p : Super } +type HasSub = { p : Sub } +local tmp = { p = { x = 5, y = 7 }} +local a: HasSuper = tmp +a.p = { x = 9 } +-- needs to be an error because +local y: number = tmp.p.y + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' +caused by: + Property 'p' is not compatible. Table type '{| x: number, y: number |}' not compatible with type 'Super' because the former has extra field 'y')"); +} + +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") +{ + ScopedFastFlag sffs[] { + {"LuauPropertiesGetExpectedType", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauTableSubtypingVariance", true}, + }; + + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { [string] : Super } +type HasSub = { [string] : Sub } +local a: HasSuper = { p = { x = 5, y = 7 }} +a.p = { x = 9 } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 99fd8339c..e3222a410 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4779,4 +4779,24 @@ local bar = foo.nutrition + 100 // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); } +TEST_CASE_FIXTURE(Fixture, "require_failed_module") +{ + ScopedFastFlag luauModuleRequireErrorPack{"LuauModuleRequireErrorPack", true}; + + fileResolver.source["game/A"] = R"( +return unfortunately() + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(aResult); + + CheckResult result = check(R"( +local ModuleA = require(game.A) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional oty = requireType("ModuleA"); + CHECK_EQ("*unknown*", toString(*oty)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index c6de0abf5..3f4420cda 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -296,9 +296,6 @@ end TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Packed = (T...) -> T... local a: Packed<> @@ -360,9 +357,6 @@ local c: Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } return {} @@ -393,9 +387,6 @@ local d: { a: typeof(c) } TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } return {} @@ -431,9 +422,6 @@ type C = Import.Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Packed1 = (T...) -> (T...) type Packed2 = (Packed1, T...) -> (Packed1, T...) @@ -452,9 +440,6 @@ type Packed4 = (Packed3, T...) -> (Packed3, T...) TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type X = (T...) -> (string, T...) @@ -470,9 +455,6 @@ type E = X<(number, ...string)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Y = (T...) -> (U...) type A = Y @@ -501,9 +483,6 @@ type I = W TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type X = (T...) -> (T...) @@ -527,9 +506,6 @@ type F = X<(string, ...number)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Y = (T...) -> (U...) @@ -549,9 +525,6 @@ type D = Y TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Y = { f: (T...) -> (U...) } @@ -567,9 +540,6 @@ local b: Y<(), ()> TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type X = () -> T type Y = (T) -> U @@ -588,9 +558,6 @@ type C = Y<(number), boolean> TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Packed = (T, U) -> (V...) local b: Packed diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 91efa8188..13db923ed 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -3,6 +3,7 @@ #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -323,4 +324,48 @@ TEST_CASE("tagging_props") CHECK(Luau::hasTag(prop, "foo")); } +struct VisitCountTracker +{ + std::unordered_map tyVisits; + std::unordered_map tpVisits; + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + template + bool operator()(TypeId ty, const T& t) + { + tyVisits[ty]++; + return true; + } + + template + bool operator()(TypePackId tp, const T&) + { + tpVisits[tp]++; + return true; + } +}; + +TEST_CASE_FIXTURE(Fixture, "visit_once") +{ + CheckResult result = check(R"( +type T = { a: number, b: () -> () } +local b: (T, T, T) -> T +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId bType = requireType("b"); + + VisitCountTracker tester; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(bType, tester, seen); + + for (auto [_, count] : tester.tyVisits) + CHECK_EQ(count, 1); + + for (auto [_, count] : tester.tpVisits) + CHECK_EQ(count, 1); +} + TEST_SUITE_END(); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 5e03b0559..7a4058b52 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -2,7 +2,13 @@ print('testing function calls through API') function add(a, b) - return a + b + return a + b +end + +local m = { __eq = function(a, b) return a.a == b.a end } + +function create_with_tm(x) + return setmetatable({ a = x }, m) end return('OK') diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 687fff1ee..188b8ebc4 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -441,7 +441,8 @@ assert((function() a = {} b = {} mt = { __eq = function(l, r) return #l == #r en assert((function() a = {} b = {} function eq(l, r) return #l == #r end setmetatable(a, {__eq = eq}) setmetatable(b, {__eq = eq}) return concat(a == b, a ~= b) end)() == "true,false") assert((function() a = {} b = {} setmetatable(a, {__eq = function(l, r) return #l == #r end}) setmetatable(b, {__eq = function(l, r) return #l == #r end}) return concat(a == b, a ~= b) end)() == "false,true") --- userdata, reference equality (no mt) +-- userdata, reference equality (no mt or mt.__eq) +assert((function() a = newproxy() return concat(a == newproxy(),a ~= newproxy()) end)() == "false,true") assert((function() a = newproxy(true) return concat(a == newproxy(true),a ~= newproxy(true)) end)() == "false,true") -- rawequal @@ -876,4 +877,4 @@ assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number testgetfenv() -- DONT MOVE THIS LINE -return'OK' +return 'OK' diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index aac42c56f..f32d5bdcb 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -419,11 +419,5 @@ co = coroutine.create(function () return loadstring("return a")() end) -a = {a = 15} --- debug.setfenv(co, a) --- assert(debug.getfenv(co) == a) --- assert(select(2, coroutine.resume(co)) == a) --- assert(select(2, coroutine.resume(co)) == a.a) - -return'OK' +return 'OK' diff --git a/tests/conformance/constructs.lua b/tests/conformance/constructs.lua index 16c63b00f..f133501f1 100644 --- a/tests/conformance/constructs.lua +++ b/tests/conformance/constructs.lua @@ -237,4 +237,4 @@ repeat i = i+1 until i==c -return'OK' +return 'OK' diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 4d9b12953..f2ecc96bb 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -373,4 +373,4 @@ do assert(f() == 42) end -return'OK' +return 'OK' diff --git a/tests/conformance/datetime.lua b/tests/conformance/datetime.lua index 21ef60d74..ca35cf2f1 100644 --- a/tests/conformance/datetime.lua +++ b/tests/conformance/datetime.lua @@ -74,4 +74,4 @@ assert(os.difftime(t1,t2) == 60*2-19) assert(os.time({ year = 1970, day = 1, month = 1, hour = 0}) == 0) -return'OK' +return 'OK' diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index ee79a14fd..9cf3c7423 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -98,4 +98,4 @@ assert(quuz(function(...) end) == "0 true") assert(quuz(function(a, b) end) == "2 false") assert(quuz(function(a, b, ...) end) == "2 true") -return'OK' +return 'OK' diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index eded14e9d..d5ff215b4 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -34,15 +34,15 @@ assert(doit("error('hi', 0)") == 'hi') assert(doit("unpack({}, 1, n=2^30)")) assert(doit("a=math.sin()")) assert(not doit("tostring(1)") and doit("tostring()")) -assert(doit"tonumber()") -assert(doit"repeat until 1; a") +assert(doit("tonumber()")) +assert(doit("repeat until 1; a")) checksyntax("break label", "", "label", 1) -assert(doit";") -assert(doit"a=1;;") -assert(doit"return;;") -assert(doit"assert(false)") -assert(doit"assert(nil)") -assert(doit"a=math.sin\n(3)") +assert(doit(";")) +assert(doit("a=1;;")) +assert(doit("return;;")) +assert(doit("assert(false)")) +assert(doit("assert(nil)")) +assert(doit("a=math.sin\n(3)")) assert(doit("function a (... , ...) end")) assert(doit("function a (, ...) end")) @@ -59,7 +59,7 @@ checkmessage("a=1; local a,bbbb=2,3; a = math.sin(1) and bbbb(3)", "local 'bbbb'") checkmessage("a={}; do local a=1 end a:bbbb(3)", "method 'bbbb'") checkmessage("local a={}; a.bbbb(3)", "field 'bbbb'") -assert(not string.find(doit"a={13}; local bbbb=1; a[bbbb](3)", "'bbbb'")) +assert(not string.find(doit("a={13}; local bbbb=1; a[bbbb](3)"), "'bbbb'")) checkmessage("a={13}; local bbbb=1; a[bbbb](3)", "number") aaa = nil @@ -67,14 +67,14 @@ checkmessage("aaa.bbb:ddd(9)", "global 'aaa'") checkmessage("local aaa={bbb=1}; aaa.bbb:ddd(9)", "field 'bbb'") checkmessage("local aaa={bbb={}}; aaa.bbb:ddd(9)", "method 'ddd'") checkmessage("local a,b,c; (function () a = b+1 end)()", "upvalue 'b'") -assert(not doit"local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)") +assert(not doit("local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)")) checkmessage("b=1; local aaa='a'; x=aaa+b", "local 'aaa'") checkmessage("aaa={}; x=3/aaa", "global 'aaa'") checkmessage("aaa='2'; b=nil;x=aaa*b", "global 'b'") checkmessage("aaa={}; x=-aaa", "global 'aaa'") -assert(not string.find(doit"aaa={}; x=(aaa or aaa)+(aaa and aaa)", "'aaa'")) -assert(not string.find(doit"aaa={}; (aaa or aaa)()", "'aaa'")) +assert(not string.find(doit("aaa={}; x=(aaa or aaa)+(aaa and aaa)"), "'aaa'")) +assert(not string.find(doit("aaa={}; (aaa or aaa)()"), "'aaa'")) checkmessage([[aaa=9 repeat until 3==3 @@ -122,10 +122,10 @@ function lineerror (s) return line and line+0 end -assert(lineerror"local a\n for i=1,'a' do \n print(i) \n end" == 2) --- assert(lineerror"\n local a \n for k,v in 3 \n do \n print(k) \n end" == 3) --- assert(lineerror"\n\n for k,v in \n 3 \n do \n print(k) \n end" == 4) -assert(lineerror"function a.x.y ()\na=a+1\nend" == 1) +assert(lineerror("local a\n for i=1,'a' do \n print(i) \n end") == 2) +-- assert(lineerror("\n local a \n for k,v in 3 \n do \n print(k) \n end") == 3) +-- assert(lineerror("\n\n for k,v in \n 3 \n do \n print(k) \n end") == 4) +assert(lineerror("function a.x.y ()\na=a+1\nend") == 1) local p = [[ function g() f() end diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index 4263dfda7..6d9eb8544 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -77,7 +77,7 @@ end local function dosteps (siz) collectgarbage() - collectgarbage"stop" + collectgarbage("stop") local a = {} for i=1,100 do a[i] = {{}}; local b = {} end local x = gcinfo() @@ -99,11 +99,11 @@ assert(dosteps(10000) == 1) do local x = gcinfo() collectgarbage() - collectgarbage"stop" + collectgarbage("stop") repeat local a = {} until gcinfo() > 1000 - collectgarbage"restart" + collectgarbage("restart") repeat local a = {} until gcinfo() < 1000 @@ -123,7 +123,7 @@ for n in pairs(b) do end b = nil collectgarbage() -for n in pairs(a) do error'cannot be here' end +for n in pairs(a) do error("cannot be here") end for i=1,lim do a[i] = i end for i=1,lim do assert(a[i] == i) end diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index 7f9b75962..94ba5ccfb 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -368,9 +368,9 @@ assert(next(a,nil) == 1000 and next(a,1000) == nil) assert(next({}) == nil) assert(next({}, nil) == nil) -for a,b in pairs{} do error"not here" end -for i=1,0 do error'not here' end -for i=0,1,-1 do error'not here' end +for a,b in pairs{} do error("not here") end +for i=1,0 do error("not here") end +for i=0,1,-1 do error("not here") end a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index a2072d2c8..84ac2ba19 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -144,4 +144,4 @@ coroutine.resume(co) resumeerror(co, "fail") checkresults({ true, false, "fail" }, coroutine.resume(co)) -return'OK' +return 'OK' diff --git a/tests/conformance/utf8.lua b/tests/conformance/utf8.lua index 024cb16d7..bfd7a1ac8 100644 --- a/tests/conformance/utf8.lua +++ b/tests/conformance/utf8.lua @@ -205,4 +205,4 @@ for p, c in string.gmatch(x, "()(" .. utf8.charpattern .. ")") do end end -return'OK' +return 'OK' diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 620f646aa..7d18bda33 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -1,6 +1,9 @@ -- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details print('testing vectors') +-- detect vector size +local vector_size = if pcall(function() return vector(0, 0, 0).w end) then 4 else 3 + -- equality assert(vector(1, 2, 3) == vector(1, 2, 3)) assert(vector(0, 1, 2) == vector(-0, 1, 2)) @@ -13,8 +16,14 @@ assert(not rawequal(vector(1, 2, 3), vector(1, 2, 4))) -- type & tostring assert(type(vector(1, 2, 3)) == "vector") -assert(tostring(vector(1, 2, 3)) == "1, 2, 3") -assert(tostring(vector(-1, 2, 0.5)) == "-1, 2, 0.5") + +if vector_size == 4 then + assert(tostring(vector(1, 2, 3, 4)) == "1, 2, 3, 4") + assert(tostring(vector(-1, 2, 0.5, 0)) == "-1, 2, 0.5, 0") +else + assert(tostring(vector(1, 2, 3)) == "1, 2, 3") + assert(tostring(vector(-1, 2, 0.5)) == "-1, 2, 0.5") +end local t = {} @@ -42,12 +51,19 @@ assert(8 * vector(8, 16, 24) == vector(64, 128, 192)); assert(vector(1, 2, 4) * '8' == vector(8, 16, 32)); assert('8' * vector(8, 16, 24) == vector(64, 128, 192)); -assert(vector(1, 2, 4) / vector(8, 16, 24) == vector(1/8, 2/16, 4/24)); +if vector_size == 4 then + assert(vector(1, 2, 4, 8) / vector(8, 16, 24, 32) == vector(1/8, 2/16, 4/24, 8/32)); + assert(8 / vector(8, 16, 24, 32) == vector(1, 1/2, 1/3, 1/4)); + assert('8' / vector(8, 16, 24, 32) == vector(1, 1/2, 1/3, 1/4)); +else + assert(vector(1, 2, 4) / vector(8, 16, 24, 1) == vector(1/8, 2/16, 4/24)); + assert(8 / vector(8, 16, 24) == vector(1, 1/2, 1/3)); + assert('8' / vector(8, 16, 24) == vector(1, 1/2, 1/3)); +end + assert(vector(1, 2, 4) / 8 == vector(1/8, 1/4, 1/2)); assert(vector(1, 2, 4) / (1 / val) == vector(1/8, 2/8, 4/8)); -assert(8 / vector(8, 16, 24) == vector(1, 1/2, 1/3)); assert(vector(1, 2, 4) / '8' == vector(1/8, 1/4, 1/2)); -assert('8' / vector(8, 16, 24) == vector(1, 1/2, 1/3)); assert(-vector(1, 2, 4) == vector(-1, -2, -4)); @@ -71,4 +87,9 @@ assert(pcall(function() local t = {} rawset(t, vector(0/0, 2, 3), 1) end) == fal -- make sure we cover both builtin and C impl assert(vector(1, 2, 4) == vector("1", "2", "4")) +-- additional checks for 4-component vectors +if vector_size == 4 then + assert(vector(1, 2, 3, 4).w == 4) +end + return 'OK' diff --git a/tools/svg.py b/tools/svg.py index 3b3bb28c4..99853fb6e 100644 --- a/tools/svg.py +++ b/tools/svg.py @@ -458,13 +458,16 @@ def display(root, title, colors, flip = False): framewidth = 1200 - 20 + def pixels(x): + return float(x) / root.width * framewidth if root.width > 0 else 0 + for n in root.subtree(): - if n.width / root.width * framewidth < 0.1: + if pixels(n.width) < 0.1: continue - x = 10 + n.offset / root.width * framewidth + x = 10 + pixels(n.offset) y = (maxdepth - 1 - n.depth if flip else n.depth) * 16 + 3 * 16 - width = n.width / root.width * framewidth + width = pixels(n.width) height = 15 if colors == "cold": From e440729e2bb98aba0bdb41109e257de522b857e4 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 2 Dec 2021 15:46:33 -0800 Subject: [PATCH 09/14] Fix signed/unsigned mismatch warning + lower limit to match upstream --- tests/Module.test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 2800d2fe6..e3993cc53 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -278,7 +278,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") #if defined(_DEBUG) || defined(_NOOPT) int limit = 250; #else - int limit = 500; + int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; @@ -287,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TypeId table = src.addType(TableTypeVar{}); TypeId nested = table; - for (unsigned i = 0; i < limit + 100; i++) + for (int i = 0; i < limit + 100; i++) { TableTypeVar* ttv = getMutable(nested); From a8673f0f99885da92597d6efb4feaf0f14d59991 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 10 Dec 2021 13:17:10 -0800 Subject: [PATCH 10/14] Sync to upstream/release/507-pre This doesn't contain all changes for 507 yet but we might want to do the Luau 0.507 release a bit earlier to end the year sooner. --- Analysis/include/Luau/Error.h | 11 +- Analysis/include/Luau/IostreamHelpers.h | 4 + Analysis/include/Luau/ToString.h | 4 +- Analysis/include/Luau/TypeInfer.h | 7 +- Analysis/include/Luau/TypeVar.h | 8 +- Analysis/include/Luau/Unifiable.h | 11 +- Analysis/include/Luau/Unifier.h | 3 + Analysis/src/Autocomplete.cpp | 115 +++++++-- Analysis/src/BuiltinDefinitions.cpp | 9 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 12 +- Analysis/src/Error.cpp | 17 +- Analysis/src/IostreamHelpers.cpp | 6 + Analysis/src/JsonEncoder.cpp | 4 +- Analysis/src/Module.cpp | 19 +- Analysis/src/Predicate.cpp | 2 +- Analysis/src/ToString.cpp | 38 ++- Analysis/src/TypeInfer.cpp | 189 +++++++++----- Analysis/src/TypeUtils.cpp | 6 +- Analysis/src/TypeVar.cpp | 220 ++++------------ Analysis/src/Unifier.cpp | 266 +++++++++++++++++--- Ast/src/Parser.cpp | 5 +- CLI/Analyze.cpp | 20 +- CLI/Coverage.cpp | 88 +++++++ CLI/Coverage.h | 10 + CLI/FileUtils.cpp | 4 + CLI/Profiler.cpp | 8 +- CLI/Profiler.h | 2 +- CLI/Repl.cpp | 35 ++- Compiler/src/Compiler.cpp | 22 +- Makefile | 6 +- Sources.cmake | 4 + VM/include/lua.h | 4 + VM/src/lapi.cpp | 162 ++++++------ VM/src/ldebug.cpp | 63 +++++ VM/src/lgc.cpp | 49 +--- VM/src/lgcdebug.cpp | 1 + VM/src/lobject.h | 10 +- VM/src/lstring.cpp | 29 --- VM/src/lstring.h | 7 - VM/src/ludata.cpp | 37 +++ VM/src/ludata.h | 13 + VM/src/lvmexecute.cpp | 14 +- VM/src/lvmutils.cpp | 4 +- bench/tests/sunspider/3d-raytrace.lua | 15 +- tests/Autocomplete.test.cpp | 55 +++- tests/Compiler.test.cpp | 49 +--- tests/Conformance.test.cpp | 100 ++++++-- tests/Module.test.cpp | 4 +- tests/Parser.test.cpp | 4 - tests/TypeInfer.annotations.test.cpp | 30 +++ tests/TypeInfer.builtins.test.cpp | 26 ++ tests/TypeInfer.generics.test.cpp | 40 +++ tests/TypeInfer.refinements.test.cpp | 17 +- tests/TypeInfer.singletons.test.cpp | 50 ++++ tests/TypeInfer.tables.test.cpp | 61 +++-- tests/TypeInfer.test.cpp | 239 +++++++++++++++++- tests/TypeInfer.tryUnify.test.cpp | 28 +++ tests/TypeInfer.unionTypes.test.cpp | 16 ++ tests/conformance/coverage.lua | 64 +++++ 59 files changed, 1703 insertions(+), 643 deletions(-) create mode 100644 CLI/Coverage.cpp create mode 100644 CLI/Coverage.h create mode 100644 VM/src/ludata.cpp create mode 100644 VM/src/ludata.h create mode 100644 tests/conformance/coverage.lua diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 9ee750043..aff3c4d9e 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -277,11 +277,20 @@ struct MissingUnionProperty bool operator==(const MissingUnionProperty& rhs) const; }; +struct TypesAreUnrelated +{ + TypeId left; + TypeId right; + + bool operator==(const TypesAreUnrelated& rhs) const; +}; + using TypeErrorData = Variant; + DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, + TypesAreUnrelated>; struct TypeError { diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h index f9e9cd48c..ee994296c 100644 --- a/Analysis/include/Luau/IostreamHelpers.h +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -36,6 +36,10 @@ std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error); std::ostream& operator<<(std::ostream& lhs, const ModuleHasCyclicDependency& error); std::ostream& operator<<(std::ostream& lhs, const DuplicateGenericParameter& error); std::ostream& operator<<(std::ostream& lhs, const CannotInferBinaryOperation& error); +std::ostream& operator<<(std::ostream& lhs, const SwappedGenericTypeParameter& error); +std::ostream& operator<<(std::ostream& lhs, const OptionalValueAccess& error); +std::ostream& operator<<(std::ostream& lhs, const MissingUnionProperty& error); +std::ostream& operator<<(std::ostream& lhs, const TypesAreUnrelated& error); std::ostream& operator<<(std::ostream& lhs, const TableState& tv); std::ostream& operator<<(std::ostream& lhs, const TypeVar& tv); diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 50379c1cd..a97bf6d6b 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -69,8 +69,8 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression -void dump(TypeId ty); -void dump(TypePackId ty); +std::string dump(TypeId ty); +std::string dump(TypePackId ty); std::string generateName(size_t n); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 9f553bc14..451976e48 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -156,13 +156,14 @@ struct TypeChecker // Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding). // Note: the binding may be null. + // TODO: remove second return value with FFlagLuauUpdateFunctionNameBinding std::pair checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); - TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName); + TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::optional originalNameLoc, std::optional expectedType); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); @@ -174,7 +175,7 @@ struct TypeChecker ExprResult checkExprPack(const ScopePtr& scope, const AstExprCall& expr); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); @@ -277,7 +278,7 @@ struct TypeChecker [[noreturn]] void ice(const std::string& message); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); - ScopePtr childScope(const ScopePtr& parent, const Location& location, int subLevel = 0); + ScopePtr childScope(const ScopePtr& parent, const Location& location); // Wrapper for merge(l, r, toUnion) but without the lambda junk. void merge(RefinementMap& l, const RefinementMap& r); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 8c4c2f34f..f6829ec3e 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -499,6 +499,7 @@ struct SingletonTypes const TypePackId anyTypePack; SingletonTypes(); + ~SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete; @@ -509,10 +510,12 @@ struct SingletonTypes private: std::unique_ptr arena; + bool debugFreezeArena = false; + TypeId makeStringMetatable(); }; -extern SingletonTypes singletonTypes; +SingletonTypes& getSingletonTypes(); void persist(TypeId ty); void persist(TypePackId tp); @@ -523,9 +526,6 @@ TypeLevel* getMutableLevel(TypeId ty); const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); -bool hasGeneric(TypeId ty); -bool hasGeneric(TypePackId tp); - TypeVar* asMutable(TypeId ty); template diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index b47610fca..e8eafe688 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -24,7 +24,7 @@ struct TypeLevel int level = 0; int subLevel = 0; - // Returns true if the typelevel "this" is "bigger" than rhs + // Returns true if the level of "this" belongs to an equal or larger scope than that of rhs bool subsumes(const TypeLevel& rhs) const { if (level < rhs.level) @@ -38,6 +38,15 @@ struct TypeLevel return false; } + // Returns true if the level of "this" belongs to a larger (not equal) scope than that of rhs + bool subsumesStrict(const TypeLevel& rhs) const + { + if (level == rhs.level && subLevel == rhs.subLevel) + return false; + else + return subsumes(rhs); + } + TypeLevel incr() const { TypeLevel result; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 4588cdd8c..7681b9662 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -91,6 +91,9 @@ struct Unifier [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + + // Available after regular type pack unification errors + std::optional firstPackErrorPos; }; } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index db2d1d0e5..4b583792c 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -190,7 +191,48 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve return ParenthesesRecommendation::None; } -static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, TypeId ty) +static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) +{ + LUAU_ASSERT(FFlag::LuauAutocompleteFirstArg); + + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + // Extra care for first function call argument location + // When we don't have anything inside () yet, we also don't have an AST node to base our lookup + if (AstExprCall* exprCall = expr->as()) + { + if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) + { + auto it = module.astTypes.find(exprCall->func); + + if (!it) + return std::nullopt; + + const FunctionTypeVar* ftv = get(follow(*it)); + + if (!ftv) + return std::nullopt; + + auto [head, tail] = flatten(ftv->argTypes); + unsigned index = exprCall->self ? 1 : 0; + + if (index < head.size()) + return head[index]; + + return std::nullopt; + } + } + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + return *it; +} + +static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, Position position, TypeId ty) { ty = follow(ty); @@ -220,15 +262,29 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } }; - auto expr = node->asExpr(); - if (!expr) - return TypeCorrectKind::None; + TypeId expectedType; - auto it = module.astExpectedTypes.find(expr); - if (!it) - return TypeCorrectKind::None; + if (FFlag::LuauAutocompleteFirstArg) + { + auto typeAtPosition = findExpectedTypeAt(module, node, position); + + if (!typeAtPosition) + return TypeCorrectKind::None; + + expectedType = follow(*typeAtPosition); + } + else + { + auto expr = node->asExpr(); + if (!expr) + return TypeCorrectKind::None; + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return TypeCorrectKind::None; - TypeId expectedType = follow(*it); + expectedType = follow(*it); + } if (FFlag::LuauAutocompletePreferToCallFunctions) { @@ -333,8 +389,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId if (result.count(name) == 0 && name != Parser::errorName) { Luau::TypeId type = Luau::follow(prop.type); - TypeCorrectKind typeCorrect = - indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, nodes.back(), type); + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct + : checkTypeCorrectKind(module, typeArena, nodes.back(), {{}, {}}, type); ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); @@ -692,17 +748,31 @@ std::optional returnFirstNonnullOptionOfType(const UnionTypeVar* utv) return ret; } -static std::optional functionIsExpectedAt(const Module& module, AstNode* node) +static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) { - auto expr = node->asExpr(); - if (!expr) - return std::nullopt; + TypeId expectedType; - auto it = module.astExpectedTypes.find(expr); - if (!it) - return std::nullopt; + if (FFlag::LuauAutocompleteFirstArg) + { + auto typeAtPosition = findExpectedTypeAt(module, node, position); - TypeId expectedType = follow(*it); + if (!typeAtPosition) + return std::nullopt; + + expectedType = follow(*typeAtPosition); + } + else + { + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + expectedType = follow(*it); + } if (get(expectedType)) return true; @@ -1171,7 +1241,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul std::string n = toString(name); if (!result.count(n)) { - TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, binding.typeId); + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, position, binding.typeId); result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; @@ -1181,9 +1251,10 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul scope = scope->parent; } - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, typeChecker.nilType); - TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, typeChecker.booleanType); - TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); + TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; if (FFlag::LuauIfElseExpressionAnalysisSupport) result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index bac94a2bd..d527414a1 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -217,9 +217,9 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); - TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); + TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); - std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); + std::optional stringMetatableTy = getMetatable(getSingletonTypes().stringType); LUAU_ASSERT(stringMetatableTy); const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); LUAU_ASSERT(stringMetatableTable); @@ -271,7 +271,10 @@ void registerBuiltinTypes(TypeChecker& typeChecker) persist(pair.second.typeId); if (TableTypeVar* ttv = getMutable(pair.second.typeId)) - ttv->name = toString(pair.first); + { + if (!ttv->name) + ttv->name = toString(pair.first); + } } attachMagicFunction(getGlobalBinding(typeChecker, "assert"), magicFunctionAssert); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 9f5c82500..d0afa7424 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAGVARIABLE(LuauFixTonumberReturnType, false) + namespace Luau { @@ -113,7 +115,6 @@ declare function gcinfo(): number declare function error(message: T, level: number?) declare function tostring(value: T): string - declare function tonumber(value: T, radix: number?): number declare function rawequal(a: T1, b: T2): boolean declare function rawget(tab: {[K]: V}, k: K): V @@ -204,7 +205,14 @@ declare function gcinfo(): number std::string getBuiltinDefinitionSource() { - return kBuiltinDefinitionLuaSrc; + std::string result = kBuiltinDefinitionLuaSrc; + + if (FFlag::LuauFixTonumberReturnType) + result += "declare function tonumber(value: T, radix: number?): number?\n"; + else + result += "declare function tonumber(value: T, radix: number?): number\n"; + + return result; } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 8334bd626..ce832c6b3 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -58,7 +58,7 @@ struct ErrorConverter result += "\ncaused by:\n "; if (!tm.reason.empty()) - result += tm.reason + ". "; + result += tm.reason + " "; result += Luau::toString(*tm.error); } @@ -410,6 +410,11 @@ struct ErrorConverter return ss + " in the type '" + toString(e.type) + "'"; } + + std::string operator()(const TypesAreUnrelated& e) const + { + return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated"; + } }; struct InvalidNameChecker @@ -658,6 +663,11 @@ bool MissingUnionProperty::operator==(const MissingUnionProperty& rhs) const return *type == *rhs.type && key == rhs.key; } +bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const +{ + return left == rhs.left && right == rhs.right; +} + std::string toString(const TypeError& error) { ErrorConverter converter; @@ -793,6 +803,11 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& for (auto& ty : e.missing) ty = clone(ty); } + else if constexpr (std::is_same_v) + { + e.left = clone(e.left); + e.right = clone(e.right); + } else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index ac46b5a49..5bc76ade5 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -262,6 +262,12 @@ std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error return stream << " }, key = '" + error.key + "' }"; } +std::ostream& operator<<(std::ostream& stream, const TypesAreUnrelated& error) +{ + stream << "TypesAreUnrelated { left = '" + toString(error.left) + "', right = '" + toString(error.right) + "' }"; + return stream; +} + std::ostream& operator<<(std::ostream& stream, const TableState& tv) { return stream << static_cast::type>(tv); diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index c7f623eea..23491a5a1 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -262,7 +262,7 @@ struct AstJsonEncoder : public AstVisitor if (comma) writeRaw(","); else - comma = false; + comma = true; write(a); } @@ -379,7 +379,7 @@ struct AstJsonEncoder : public AstVisitor if (comma) writeRaw(","); else - comma = false; + comma = true; write(prop); } }); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index b4b6eb425..e1e53c971 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -13,7 +13,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 0) namespace Luau @@ -23,7 +22,7 @@ static bool contains(Position pos, Comment comment) { if (comment.location.contains(pos)) return true; - else if (FFlag::LuauCaptureBrokenCommentSpans && comment.type == Lexeme::BrokenComment && + else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end return true; else if (comment.type == Lexeme::Comment && comment.location.end == pos) @@ -194,7 +193,7 @@ struct TypePackCloner { cloneState.encounteredFreeType = true; - TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); + TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); TypePackId cloned = dest.addTypePack(*err); seenTypePacks[typePackId] = cloned; } @@ -247,7 +246,7 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { cloneState.encounteredFreeType = true; - TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); + TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); TypeId cloned = dest.addType(*err); seenTypes[typeId] = cloned; } @@ -421,9 +420,6 @@ TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypeP Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. } - if (FFlag::DebugLuauTrackOwningArena) - asMutable(res)->owningArena = &dest; - return res; } @@ -440,12 +436,11 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks { TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. + + // TODO: Make this work when the arena of 'res' might be frozen asMutable(res)->documentationSymbol = typeId->documentationSymbol; } - if (FFlag::DebugLuauTrackOwningArena) - asMutable(res)->owningArena = &dest; - return res; } @@ -508,8 +503,8 @@ bool Module::clonePublicInterface() if (moduleScope->varargPack) moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); - for (auto& pair : moduleScope->exportedTypeBindings) - pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, cloneState); + for (auto& [name, tf] : moduleScope->exportedTypeBindings) + tf = clone(tf, interfaceTypes, seenTypes, seenTypePacks, cloneState); for (TypeId ty : moduleScope->returnType) if (get(follow(ty))) diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/Predicate.cpp index 848627cf8..7bd8001e3 100644 --- a/Analysis/src/Predicate.cpp +++ b/Analysis/src/Predicate.cpp @@ -24,7 +24,7 @@ std::optional tryGetLValue(const AstExpr& node) else if (auto indexexpr = expr->as()) { if (auto lvalue = tryGetLValue(*indexexpr->expr)) - if (auto string = indexexpr->expr->as()) + if (auto string = indexexpr->index->as()) return Field{std::make_shared(*lvalue), std::string(string->value.data, string->value.size)}; } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 6322096c4..a6be53482 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -13,6 +13,13 @@ LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauFunctionArgumentNameSize, false) +/* + * Prefix generic typenames with gen- + * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 + * Fair warning: Setting this will break a lot of Luau unit tests. + */ +LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) + namespace Luau { @@ -290,7 +297,15 @@ struct TypeVarStringifier void operator()(TypeId ty, const Unifiable::Free& ftv) { state.result.invalid = true; + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("free-"); state.emit(state.getName(ty)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + state.emit(std::to_string(ftv.level.level)); + } } void operator()(TypeId, const BoundTypeVar& btv) @@ -802,6 +817,8 @@ struct TypePackStringifier void operator()(TypePackId tp, const GenericTypePack& pack) { + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("gen-"); if (pack.explicitName) { state.result.nameMap.typePacks[tp] = pack.name; @@ -817,7 +834,16 @@ struct TypePackStringifier void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("free-"); state.emit(state.getName(tp)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + state.emit(std::to_string(pack.level.level)); + } + state.emit("..."); } @@ -1181,20 +1207,24 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV return s; } -void dump(TypeId ty) +std::string dump(TypeId ty) { ToStringOptions opts; opts.exhaustive = true; opts.functionTypeArguments = true; - printf("%s\n", toString(ty, opts).c_str()); + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; } -void dump(TypePackId ty) +std::string dump(TypePackId ty) { ToStringOptions opts; opts.exhaustive = true; opts.functionTypeArguments = true; - printf("%s\n", toString(ty, opts).c_str()); + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; } std::string generateName(size_t i) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 617bf482c..abbc2901b 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -9,9 +9,9 @@ #include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" -#include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" +#include "Luau/ToString.h" #include "Luau/TypeVar.h" #include "Luau/TimeTrace.h" @@ -29,7 +29,6 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) -LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) @@ -37,6 +36,12 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) +LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) +LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) +LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) +LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) +LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) +LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau { @@ -206,14 +211,14 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan : resolver(resolver) , iceHandler(iceHandler) , unifierState(iceHandler) - , nilType(singletonTypes.nilType) - , numberType(singletonTypes.numberType) - , stringType(singletonTypes.stringType) - , booleanType(singletonTypes.booleanType) - , threadType(singletonTypes.threadType) - , anyType(singletonTypes.anyType) - , optionalNumberType(singletonTypes.optionalNumberType) - , anyTypePack(singletonTypes.anyTypePack) + , nilType(getSingletonTypes().nilType) + , numberType(getSingletonTypes().numberType) + , stringType(getSingletonTypes().stringType) + , booleanType(getSingletonTypes().booleanType) + , threadType(getSingletonTypes().threadType) + , anyType(getSingletonTypes().anyType) + , optionalNumberType(getSingletonTypes().optionalNumberType) + , anyTypePack(getSingletonTypes().anyTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -443,7 +448,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) functionDecls[*protoIter] = pair; ++subLevel; - TypeId leftType = checkFunctionName(scope, *fun->name); + TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); unify(leftType, funTy, fun->location); } else if (auto fun = (*protoIter)->as()) @@ -711,14 +716,15 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } else if (auto tail = valueIter.tail()) { - if (get(*tail)) + TypePackId tailPack = follow(*tail); + if (get(tailPack)) right = errorRecoveryType(scope); - else if (auto vtp = get(*tail)) + else if (auto vtp = get(tailPack)) right = vtp->ty; - else if (get(*tail)) + else if (get(tailPack)) { - *asMutable(*tail) = TypePack{{left}}; - growingPack = getMutable(*tail); + *asMutable(tailPack) = TypePack{{left}}; + growingPack = getMutable(tailPack); } } @@ -1107,8 +1113,27 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco unify(leftType, ty, function.location); - if (leftTypeBinding) - *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); + if (FFlag::LuauUpdateFunctionNameBinding) + { + LUAU_ASSERT(function.name->is() || function.name->is()); + + if (auto exprIndexName = function.name->as()) + { + if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) + { + if (auto ttv = getMutableTableType(*typeIt)) + { + if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) + it->second.type = follow(quantify(funScope, leftType, function.name->location)); + } + } + } + } + else + { + if (leftTypeBinding) + *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); + } } } @@ -1148,8 +1173,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { - ScopePtr aliasScope = - FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); + ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); + if (FFlag::LuauProperTypeLevels) + aliasScope->level.subLevel = subLevel; auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); @@ -1166,6 +1193,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ice("Not predeclared"); ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); for (TypeId ty : binding->typeParams) { @@ -1505,9 +1533,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) - ice("Unexpected abstract type pack!"); + ice("Unexpected abstract type pack!", expr.location); else - ice("Unknown TypePack type!"); + ice("Unknown TypePack type!", expr.location); } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) @@ -1574,7 +1602,7 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (tableType->state == TableState::Free) { - TypeId result = freshType(scope); + TypeId result = FFlag::LuauAscribeCorrectLevelToInferredProperitesOfFreeTables ? freshType(tableType->level) : freshType(scope); tableType->props[name] = {result}; return result; } @@ -1738,7 +1766,16 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { - return {checkLValue(scope, expr)}; + TypeId ty = checkLValue(scope, expr); + + if (FFlag::LuauRefiLookupFromIndexExpr) + { + if (std::optional lvalue = tryGetLValue(expr)) + if (std::optional refiTy = resolveLValue(scope, *lvalue)) + return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; + } + + return {ty}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) @@ -2421,12 +2458,27 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy TypeId annotationType = resolveType(scope, *expr.annotation); ExprResult result = checkExpr(scope, *expr.expr, annotationType); - ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); - reportErrors(errorVec); - if (!errorVec.empty()) - annotationType = errorRecoveryType(annotationType); + if (FFlag::LuauBidirectionalAsExpr) + { + // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. + if (canUnify(result.type, annotationType, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; + + if (canUnify(annotationType, result.type, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; - return {annotationType, std::move(result.predicates)}; + reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); + return {errorRecoveryType(annotationType), std::move(result.predicates)}; + } + else + { + ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + reportErrors(errorVec); + if (!errorVec.empty()) + annotationType = errorRecoveryType(annotationType); + + return {annotationType, std::move(result.predicates)}; + } } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) @@ -2674,8 +2726,15 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope // Answers the question: "Can I define another function with this name?" // Primarily about detecting duplicates. -TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) +TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level) { + auto freshTy = [&]() { + if (FFlag::LuauProperTypeLevels) + return freshType(level); + else + return freshType(scope); + }; + if (auto globalName = funName.as()) { const ScopePtr& globalScope = currentModule->getModuleScope(); @@ -2689,7 +2748,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) } else { - TypeId ty = freshType(scope); + TypeId ty = freshTy(); globalScope->bindings[name] = {ty, funName.location}; return ty; } @@ -2699,7 +2758,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Symbol name = localName->local; Binding& binding = scope->bindings[name]; if (binding.typeId == nullptr) - binding = {freshType(scope), funName.location}; + binding = {freshTy(), funName.location}; return binding.typeId; } @@ -2730,7 +2789,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Property& property = ttv->props[name]; - property.type = freshType(scope); + property.type = freshTy(); property.location = indexName->indexLocation; ttv->methodDefinitionLocations[name] = funName.location; return property.type; @@ -3327,7 +3386,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A fn = follow(fn); if (auto ret = checkCallOverload( - scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) + scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) return *ret; } @@ -3402,9 +3461,11 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { + LUAU_ASSERT(argLocations); + fn = stripFromNilAndReport(fn, expr.func->location); if (get(fn)) @@ -3428,31 +3489,44 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {{retPack}}; } - const FunctionTypeVar* ftv = get(fn); - if (!ftv) + std::vector metaArgLocations; + + // Might be a callable table + if (const MetatableTypeVar* mttv = get(fn)) { - // Might be a callable table - if (const MetatableTypeVar* mttv = get(fn)) + if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) - { - // Construct arguments with 'self' added in front - TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); + // Construct arguments with 'self' added in front + TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); - TypePack* metaCallArgs = getMutable(metaCallArgPack); - metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); + TypePack* metaCallArgs = getMutable(metaCallArgPack); + metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); - std::vector metaArgLocations = argLocations; - metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + metaArgLocations = *argLocations; + metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + if (FFlag::LuauFixRecursiveMetatableCall) + { + fn = instantiate(scope, *ty, expr.func->location); + + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; + } + else + { TypeId fn = *ty; fn = instantiate(scope, fn, expr.func->location); - return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, + return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, &metaArgLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors); } } + } + const FunctionTypeVar* ftv = get(fn); + if (!ftv) + { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); unify(retPack, errorRecoveryTypePack(scope), expr.func->location); return {{errorRecoveryTypePack(retPack)}}; @@ -3477,7 +3551,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {}; } - checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); + checkArgumentList(scope, state, argPack, ftv->argTypes, *argLocations); if (!state.errors.empty()) { @@ -3772,7 +3846,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (moduleInfo.name.empty()) { - if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) + if (currentModule->mode == Mode::Strict) { reportError(TypeError{location, UnknownRequire{}}); return errorRecoveryType(anyType); @@ -4268,9 +4342,11 @@ ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& } // Creates a new Scope and carries forward the varargs from the parent. -ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location, int subLevel) +ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location) { - ScopePtr scope = std::make_shared(parent, subLevel); + ScopePtr scope = std::make_shared(parent); + if (FFlag::LuauProperTypeLevels) + scope->level = parent->level; scope->varargPack = parent->varargPack; currentModule->scopes.push_back(std::make_pair(location, scope)); @@ -4329,22 +4405,22 @@ TypeId TypeChecker::singletonType(std::string value) TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) { - return singletonTypes.errorRecoveryType(); + return getSingletonTypes().errorRecoveryType(); } TypeId TypeChecker::errorRecoveryType(TypeId guess) { - return singletonTypes.errorRecoveryType(guess); + return getSingletonTypes().errorRecoveryType(guess); } TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) { - return singletonTypes.errorRecoveryTypePack(); + return getSingletonTypes().errorRecoveryTypePack(); } TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) { - return singletonTypes.errorRecoveryTypePack(guess); + return getSingletonTypes().errorRecoveryTypePack(guess); } std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) @@ -4547,6 +4623,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (const auto& func = annotation.as()) { ScopePtr funcScope = childScope(scope, func->location); + funcScope->level = scope->level.incr(); auto [generics, genericPacks] = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 0d9d91e0c..8c6d5e49f 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -19,7 +19,7 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa TypeId unwrapped = follow(*metatable); if (get(unwrapped)) - return singletonTypes.anyType; + return getSingletonTypes().anyType; const TableTypeVar* mtt = getTableType(unwrapped); if (!mtt) @@ -61,12 +61,12 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc { std::optional r = first(follow(itf->retType)); if (!r) - return singletonTypes.nilType; + return getSingletonTypes().nilType; else return *r; } else if (get(index)) - return singletonTypes.anyType; + return getSingletonTypes().anyType; else errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 62715af53..571b13ca0 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) +LUAU_FASTFLAG(DebugLuauFreezeArena) namespace Luau { @@ -579,11 +580,25 @@ SingletonTypes::SingletonTypes() , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); - stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, makeStringMetatable()}; + stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, stringMetatable}; persist(stringMetatable); + + debugFreezeArena = FFlag::DebugLuauFreezeArena; freeze(*arena); } +SingletonTypes::~SingletonTypes() +{ + // Destroy the arena with the same memory management flags it was created with + bool prevFlag = FFlag::DebugLuauFreezeArena; + FFlag::DebugLuauFreezeArena.value = debugFreezeArena; + + unfreeze(*arena); + arena.reset(nullptr); + + FFlag::DebugLuauFreezeArena.value = prevFlag; +} + TypeId SingletonTypes::makeStringMetatable() { const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); @@ -641,6 +656,9 @@ TypeId SingletonTypes::makeStringMetatable() TypeId tableType = arena->addType(TableTypeVar{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + if (TableTypeVar* ttv = getMutable(tableType)) + ttv->name = "string"; + return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } @@ -670,7 +688,11 @@ TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) return &errorTypePack_; } -SingletonTypes singletonTypes; +SingletonTypes& getSingletonTypes() +{ + static SingletonTypes singletonTypes; + return singletonTypes; +} void persist(TypeId ty) { @@ -719,6 +741,18 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } + else if (auto mtv = get(t)) + { + queue.push_back(mtv->table); + queue.push_back(mtv->metatable); + } + else if (get(t) || get(t) || get(t) || get(t) || get(t)) + { + } + else + { + LUAU_ASSERT(!"TypeId is not supported in a persist call"); + } } } @@ -736,6 +770,17 @@ void persist(TypePackId tp) if (p->tail) persist(*p->tail); } + else if (auto vtp = get(tp)) + { + persist(vtp->ty); + } + else if (get(tp)) + { + } + else + { + LUAU_ASSERT(!"TypePackId is not supported in a persist call"); + } } const TypeLevel* getLevel(TypeId ty) @@ -757,167 +802,6 @@ TypeLevel* getMutableLevel(TypeId ty) return const_cast(getLevel(ty)); } -struct QVarFinder -{ - mutable DenseHashSet seen; - - QVarFinder() - : seen(nullptr) - { - } - - bool hasSeen(const void* tv) const - { - if (seen.contains(tv)) - return true; - - seen.insert(tv); - return false; - } - - bool hasGeneric(TypeId tid) const - { - if (hasSeen(&tid->ty)) - return false; - - return Luau::visit(*this, tid->ty); - } - - bool hasGeneric(TypePackId tp) const - { - if (hasSeen(&tp->ty)) - return false; - - return Luau::visit(*this, tp->ty); - } - - bool operator()(const Unifiable::Free&) const - { - return false; - } - - bool operator()(const Unifiable::Bound& bound) const - { - return hasGeneric(bound.boundTo); - } - - bool operator()(const Unifiable::Generic&) const - { - return true; - } - bool operator()(const Unifiable::Error&) const - { - return false; - } - bool operator()(const PrimitiveTypeVar&) const - { - return false; - } - - bool operator()(const SingletonTypeVar&) const - { - return false; - } - - bool operator()(const FunctionTypeVar& ftv) const - { - if (hasGeneric(ftv.argTypes)) - return true; - return hasGeneric(ftv.retType); - } - - bool operator()(const TableTypeVar& ttv) const - { - if (ttv.state == TableState::Generic) - return true; - - if (ttv.indexer) - { - if (hasGeneric(ttv.indexer->indexType)) - return true; - if (hasGeneric(ttv.indexer->indexResultType)) - return true; - } - - for (const auto& [_name, prop] : ttv.props) - { - if (hasGeneric(prop.type)) - return true; - } - - return false; - } - - bool operator()(const MetatableTypeVar& mtv) const - { - return hasGeneric(mtv.table) || hasGeneric(mtv.metatable); - } - - bool operator()(const ClassTypeVar& ctv) const - { - for (const auto& [name, prop] : ctv.props) - { - if (hasGeneric(prop.type)) - return true; - } - - if (ctv.parent) - return hasGeneric(*ctv.parent); - - return false; - } - - bool operator()(const AnyTypeVar&) const - { - return false; - } - - bool operator()(const UnionTypeVar& utv) const - { - for (TypeId tid : utv.options) - if (hasGeneric(tid)) - return true; - - return false; - } - - bool operator()(const IntersectionTypeVar& utv) const - { - for (TypeId tid : utv.parts) - if (hasGeneric(tid)) - return true; - - return false; - } - - bool operator()(const LazyTypeVar&) const - { - return false; - } - - bool operator()(const Unifiable::Bound& bound) const - { - return hasGeneric(bound.boundTo); - } - - bool operator()(const TypePack& pack) const - { - for (TypeId ty : pack.head) - if (hasGeneric(ty)) - return true; - - if (pack.tail) - return hasGeneric(*pack.tail); - - return false; - } - - bool operator()(const VariadicTypePack& pack) const - { - return hasGeneric(pack.ty); - } -}; - const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) { while (cls) @@ -953,16 +837,6 @@ bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) return false; } -bool hasGeneric(TypeId ty) -{ - return Luau::visit(QVarFinder{}, ty->ty); -} - -bool hasGeneric(TypePackId tp) -{ - return Luau::visit(QVarFinder{}, tp->ty); -} - UnionTypeVarIterator::UnionTypeVarIterator(const UnionTypeVar* utv) { LUAU_ASSERT(utv); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index d0b188374..c5aab8562 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,7 +14,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); -LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); +LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) @@ -22,9 +22,82 @@ LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) LUAU_FASTFLAG(LuauErrorRecoveryType); +LUAU_FASTFLAG(LuauProperTypeLevels); +LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) +LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) namespace Luau { + +struct PromoteTypeLevels +{ + TxnLog& log; + TypeLevel minLevel; + + explicit PromoteTypeLevels(TxnLog& log, TypeLevel minLevel) + : log(log) + , minLevel(minLevel) + {} + + template + void promote(TID ty, T* t) + { + LUAU_ASSERT(t); + if (minLevel.subsumesStrict(t->level)) + { + log(ty); + t->level = minLevel; + } + } + + template + void cycle(TID) {} + + template + bool operator()(TID, const T&) + { + return true; + } + + bool operator()(TypeId ty, const FreeTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypeId ty, const FunctionTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack&) + { + promote(tp, getMutable(tp)); + return true; + } +}; + +void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypeId ty) +{ + PromoteTypeLevels ptl{log, minLevel}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, ptl, seen); +} + +void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypePackId tp) +{ + PromoteTypeLevels ptl{log, minLevel}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(tp, ptl, seen); +} + struct SkipCacheForType { SkipCacheForType(const DenseHashMap& skipCacheForType) @@ -127,6 +200,29 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } +// Used for tagged union matching heuristic, returns first singleton type field +static std::optional> getTableMatchTag(TypeId type) +{ + LUAU_ASSERT(FFlag::LuauExtendedUnionMismatchError); + + type = follow(type); + + if (auto ttv = get(type)) + { + for (auto&& [name, prop] : ttv->props) + { + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; + } + } + else if (auto mttv = get(type)) + { + return getTableMatchTag(mttv->table); + } + + return std::nullopt; +} + Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) @@ -214,9 +310,11 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { occursCheck(superTy, subTy); + TypeLevel superLevel = l->level; + // Unification can't change the level of a generic. auto rightGeneric = get(subTy); - if (rightGeneric && !rightGeneric->level.subsumes(l->level)) + if (rightGeneric && !rightGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -226,7 +324,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // The occurrence check might have caused superTy no longer to be a free type if (!get(superTy)) { - if (auto rightLevel = getMutableLevel(subTy)) + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(log, superLevel, subTy); + else if (auto rightLevel = getMutableLevel(subTy)) { if (!rightLevel->subsumes(l->level)) *rightLevel = l->level; @@ -240,6 +340,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } else if (r) { + TypeLevel subLevel = r->level; + occursCheck(subTy, superTy); // Unification can't change the level of a generic. @@ -253,10 +355,16 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (!get(subTy)) { - if (auto leftLevel = getMutableLevel(superTy)) + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(log, subLevel, superTy); + + if (auto superLevel = getMutableLevel(superTy)) { - if (!leftLevel->subsumes(r->level)) - *leftLevel = r->level; + if (!superLevel->subsumes(r->level)) + { + log(superTy); + *superLevel = r->level; + } } log(subTy); @@ -327,7 +435,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (failed) { if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible", *firstFailedOption}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } @@ -338,28 +446,46 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool bool found = false; std::optional unificationTooComplex; + size_t failedOptionCount = 0; + std::optional failedOption; + + bool foundHeuristic = false; size_t startIndex = 0; if (FFlag::LuauUnionHeuristic) { - bool found = false; - - const std::string* subName = getName(subTy); - if (subName) + if (const std::string* subName = getName(subTy)) { for (size_t i = 0; i < uv->options.size(); ++i) { const std::string* optionName = getName(uv->options[i]); if (optionName && *optionName == *subName) { - found = true; + foundHeuristic = true; startIndex = i; break; } } } - if (!found && cacheEnabled) + if (FFlag::LuauExtendedUnionMismatchError) + { + if (auto subMatchTag = getTableMatchTag(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + auto optionMatchTag = getTableMatchTag(uv->options[i]); + if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + } + + if (!foundHeuristic && cacheEnabled) { for (size_t i = 0; i < uv->options.size(); ++i) { @@ -390,15 +516,27 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { unificationTooComplex = e; } + else if (FFlag::LuauExtendedUnionMismatchError && !isNil(type)) + { + failedOptionCount++; + + if (!failedOption) + failedOption = {innerState.errors.front()}; + } innerState.log.rollback(); } if (unificationTooComplex) + { errors.push_back(*unificationTooComplex); + } else if (!found) { - if (FFlag::LuauExtendedTypeMismatchError) + if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); + else if (FFlag::LuauExtendedTypeMismatchError) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); @@ -431,7 +569,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible", *firstFailedOption}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } else { @@ -771,6 +909,10 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal if (superIter.good() && subIter.good()) { tryUnify_(*superIter, *subIter); + + if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; + superIter.advance(); subIter.advance(); continue; @@ -853,13 +995,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(singletonTypes.errorRecoveryType(), *superIter); + tryUnify_(getSingletonTypes().errorRecoveryType(), *superIter); superIter.advance(); } while (subIter.good()) { - tryUnify_(singletonTypes.errorRecoveryType(), *subIter); + tryUnify_(getSingletonTypes().errorRecoveryType(), *subIter); subIter.advance(); } @@ -917,14 +1059,22 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal if (numGenerics != rf->generics.size()) { numGenerics = std::min(lf->generics.size(), rf->generics.size()); - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + + if (FFlag::LuauExtendedFunctionMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } size_t numGenericPacks = lf->genericPacks.size(); if (numGenericPacks != rf->genericPacks.size()) { numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + + if (FFlag::LuauExtendedFunctionMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } for (size_t i = 0; i < numGenerics; i++) @@ -936,13 +1086,49 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal { Unifier innerState = makeChildUnifier(); - ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + if (FFlag::LuauExtendedFunctionMismatchError) + { + innerState.ctx = CountMismatch::Arg; + innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); - ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + bool reported = !innerState.errors.empty(); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + + innerState.ctx = CountMismatch::Result; + innerState.tryUnify_(lf->retType, rf->retType); + + if (!reported) + { + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty() && size(lf->retType) == 1 && finite(lf->retType)) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + } + } + else + { + ctx = CountMismatch::Arg; + innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + ctx = CountMismatch::Result; + innerState.tryUnify_(lf->retType, rf->retType); + + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + } log.concat(std::move(innerState.log)); } @@ -994,7 +1180,7 @@ struct Resetter void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - if (!FFlag::LuauTableSubtypingVariance) + if (!FFlag::LuauTableSubtypingVariance2) return DEPRECATED_tryUnifyTables(left, right, isIntersection); TableTypeVar* lt = getMutable(left); @@ -1133,7 +1319,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // TODO: hopefully readonly/writeonly properties will fix this. Property clone = prop; clone.type = deeplyOptional(clone.type); - log(lt); + log(left); lt->props[name] = clone; } else if (variance == Covariant) @@ -1146,7 +1332,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) } else if (lt->state == TableState::Free) { - log(lt); + log(left); lt->props[name] = prop; } else @@ -1176,7 +1362,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. // TODO: we only need to do this if the supertype's indexer is read/write // since that can add indexed elements. - log(rt); + log(right); rt->indexer = lt->indexer; } } @@ -1185,7 +1371,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // Symmetric if we are invariant if (lt->state == TableState::Unsealed || lt->state == TableState::Free) { - log(lt); + log(left); lt->indexer = rt->indexer; } } @@ -1241,15 +1427,15 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see TableTypeVar* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) prop.type = deeplyOptional(prop.type, seen); - return types->addType(UnionTypeVar{{singletonTypes.nilType, result}}); + return types->addType(UnionTypeVar{{getSingletonTypes().nilType, result}}); } else - return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}}); + return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); } void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance); + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); Resetter resetter{&variance}; variance = Invariant; @@ -1467,7 +1653,7 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio } else if (lt->indexer) { - innerState.tryUnify_(lt->indexer->indexType, singletonTypes.stringType); + innerState.tryUnify_(lt->indexer->indexType, getSingletonTypes().stringType); // We already try to unify properties in both tables. // Skip those and just look for the ones remaining and see if they fit into the indexer. for (const auto& [name, type] : rt->props) @@ -1636,7 +1822,7 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, singletonTypes.errorRecoveryType()); + tryUnify_(prop.type, getSingletonTypes().errorRecoveryType()); } else { @@ -1825,7 +2011,7 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) if (get(ty) || get(ty) || get(ty)) return; - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); @@ -1834,14 +2020,14 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, getSingletonTypes().anyType, anyTP); } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { LUAU_ASSERT(get(any)); - const TypeId anyTy = singletonTypes.errorRecoveryType(); + const TypeId anyTy = getSingletonTypes().errorRecoveryType(); std::vector queue; @@ -1887,7 +2073,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = *singletonTypes.errorRecoveryType(); + *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); return; } @@ -1951,7 +2137,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = *singletonTypes.errorRecoveryTypePack(); + *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); return; } @@ -2005,7 +2191,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const s errors.push_back(*e); else if (!innerErrors.empty()) errors.push_back( - TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible", prop.c_str()), innerErrors.front()}}); + TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front()}}); } void Unifier::ice(const std::string& message, const Location& location) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 3d0d5b7e6..dd24f27cb 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,7 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) @@ -159,7 +158,7 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n { std::vector hotcomments; - while (isComment(p.lexer.current()) || (FFlag::LuauCaptureBrokenCommentSpans && p.lexer.current().type == Lexeme::BrokenComment)) + while (isComment(p.lexer.current()) || p.lexer.current().type == Lexeme::BrokenComment) { const char* text = p.lexer.current().data; unsigned int length = p.lexer.current().length; @@ -2780,7 +2779,7 @@ const Lexeme& Parser::nextLexeme() const Lexeme& lexeme = lexer.next(/*skipComments*/ false); // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. // The parser will turn this into a proper syntax error. - if (FFlag::LuauCaptureBrokenCommentSpans && lexeme.type == Lexeme::BrokenComment) + if (lexeme.type == Lexeme::BrokenComment) commentLocations.push_back(Comment{lexeme.type, lexeme.location}); if (isComment(lexeme)) commentLocations.push_back(Comment{lexeme.type, lexeme.location}); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 9230d80d0..aecb619a1 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -11,26 +11,33 @@ enum class ReportFormat { Default, - Luacheck + Luacheck, + Gnu, }; -static void report(ReportFormat format, const char* name, const Luau::Location& location, const char* type, const char* message) +static void report(ReportFormat format, const char* name, const Luau::Location& loc, const char* type, const char* message) { switch (format) { case ReportFormat::Default: - fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, type, message); break; case ReportFormat::Luacheck: { // Note: luacheck's end column is inclusive but our end column is exclusive // In addition, luacheck doesn't support multi-line messages, so if the error is multiline we'll fake end column as 100 and hope for the best - int columnEnd = (location.begin.line == location.end.line) ? location.end.column : 100; + int columnEnd = (loc.begin.line == loc.end.line) ? loc.end.column : 100; - fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, columnEnd, type, message); + // Use stdout to match luacheck behavior + fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, columnEnd, type, message); break; } + + case ReportFormat::Gnu: + // Note: GNU end column is inclusive but our end column is exclusive + fprintf(stderr, "%s:%d.%d-%d.%d: %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, loc.end.line + 1, loc.end.column, type, message); + break; } } @@ -97,6 +104,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); + printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); } static int assertionHandler(const char* expr, const char* file, int line) @@ -201,6 +209,8 @@ int main(int argc, char** argv) if (strcmp(argv[i], "--formatter=plain") == 0) format = ReportFormat::Luacheck; + else if (strcmp(argv[i], "--formatter=gnu") == 0) + format = ReportFormat::Gnu; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; } diff --git a/CLI/Coverage.cpp b/CLI/Coverage.cpp new file mode 100644 index 000000000..254df3f03 --- /dev/null +++ b/CLI/Coverage.cpp @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Coverage.h" + +#include "lua.h" + +#include +#include + +struct Coverage +{ + lua_State* L = nullptr; + std::vector functions; +} gCoverage; + +void coverageInit(lua_State* L) +{ + gCoverage.L = lua_mainthread(L); +} + +bool coverageActive() +{ + return gCoverage.L != nullptr; +} + +void coverageTrack(lua_State* L, int funcindex) +{ + int ref = lua_ref(L, funcindex); + gCoverage.functions.push_back(ref); +} + +static void coverageCallback(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) +{ + FILE* f = static_cast(context); + + std::string name; + + if (depth == 0) + name = "
"; + else if (function) + name = std::string(function) + ":" + std::to_string(linedefined); + else + name = ":" + std::to_string(linedefined); + + fprintf(f, "FN:%d,%s\n", linedefined, name.c_str()); + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + fprintf(f, "FNDA:%d,%s\n", hits[i], name.c_str()); + break; + } + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + fprintf(f, "DA:%d,%d\n", int(i), hits[i]); +} + +void coverageDump(const char* path) +{ + lua_State* L = gCoverage.L; + + FILE* f = fopen(path, "w"); + if (!f) + { + fprintf(stderr, "Error opening coverage %s\n", path); + return; + } + + fprintf(f, "TN:\n"); + + for (int fref: gCoverage.functions) + { + lua_getref(L, fref); + + lua_Debug ar = {}; + lua_getinfo(L, -1, "s", &ar); + + fprintf(f, "SF:%s\n", ar.short_src); + lua_getcoverage(L, -1, f, coverageCallback); + fprintf(f, "end_of_record\n"); + + lua_pop(L, 1); + } + + fclose(f); + + printf("Coverage dump written to %s (%d functions)\n", path, int(gCoverage.functions.size())); +} diff --git a/CLI/Coverage.h b/CLI/Coverage.h new file mode 100644 index 000000000..74be4e5c9 --- /dev/null +++ b/CLI/Coverage.h @@ -0,0 +1,10 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +struct lua_State; + +void coverageInit(lua_State* L); +bool coverageActive(); + +void coverageTrack(lua_State* L, int funcindex); +void coverageDump(const char* path); diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index b3c9557bb..cb993dfee 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -67,6 +67,10 @@ std::optional readFile(const std::string& name) if (read != size_t(length)) return std::nullopt; + // Skip first line if it's a shebang + if (length > 2 && result[0] == '#' && result[1] == '!') + result.erase(0, result.find('\n')); + return result; } diff --git a/CLI/Profiler.cpp b/CLI/Profiler.cpp index c6d15a7f2..30a171f0f 100644 --- a/CLI/Profiler.cpp +++ b/CLI/Profiler.cpp @@ -110,12 +110,12 @@ void profilerStop() gProfiler.thread.join(); } -void profilerDump(const char* name) +void profilerDump(const char* path) { - FILE* f = fopen(name, "wb"); + FILE* f = fopen(path, "wb"); if (!f) { - fprintf(stderr, "Error opening profile %s\n", name); + fprintf(stderr, "Error opening profile %s\n", path); return; } @@ -129,7 +129,7 @@ void profilerDump(const char* name) fclose(f); - printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", name, double(total) / 1e6, + printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", path, double(total) / 1e6, static_cast(gProfiler.samples.load()), static_cast(gProfiler.data.size())); uint64_t totalgc = 0; diff --git a/CLI/Profiler.h b/CLI/Profiler.h index 0a407e476..67b1acfd0 100644 --- a/CLI/Profiler.h +++ b/CLI/Profiler.h @@ -5,4 +5,4 @@ struct lua_State; void profilerStart(lua_State* L, int frequency); void profilerStop(); -void profilerDump(const char* name); \ No newline at end of file +void profilerDump(const char* path); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 2cdd0062f..35c02f2c8 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -8,6 +8,7 @@ #include "FileUtils.h" #include "Profiler.h" +#include "Coverage.h" #include "linenoise.hpp" @@ -24,6 +25,16 @@ enum class CompileFormat Binary }; +static Luau::CompileOptions copts() +{ + Luau::CompileOptions result = {}; + result.optimizationLevel = 1; + result.debugLevel = 1; + result.coverageLevel = coverageActive() ? 2 : 0; + + return result; +} + static int lua_loadstring(lua_State* L) { size_t l = 0; @@ -32,7 +43,7 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); + std::string bytecode = Luau::compile(std::string(s, l), copts()); if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0) return 1; @@ -79,9 +90,12 @@ static int lua_require(lua_State* L) luaL_sandboxthread(ML); // now we can compile & run module on the new thread - std::string bytecode = Luau::compile(*source); + std::string bytecode = Luau::compile(*source, copts()); if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (coverageActive()) + coverageTrack(ML, -1); + int status = lua_resume(ML, L, 0); if (status == 0) @@ -149,7 +163,7 @@ static void setupState(lua_State* L) static std::string runCode(lua_State* L, const std::string& source) { - std::string bytecode = Luau::compile(source); + std::string bytecode = Luau::compile(source, copts()); if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) { @@ -329,11 +343,14 @@ static bool runFile(const char* name, lua_State* GL) std::string chunkname = "=" + std::string(name); - std::string bytecode = Luau::compile(*source); + std::string bytecode = Luau::compile(*source, copts()); int status = 0; if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (coverageActive()) + coverageTrack(L, -1); + status = lua_resume(L, NULL, 0); } else @@ -437,6 +454,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available options:\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); + printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); } static int assertionHandler(const char* expr, const char* file, int line) @@ -495,6 +513,7 @@ int main(int argc, char** argv) setupState(L); int profile = 0; + bool coverage = false; for (int i = 1; i < argc; ++i) { @@ -505,11 +524,16 @@ int main(int argc, char** argv) profile = 10000; // default to 10 KHz else if (strncmp(argv[i], "--profile=", 10) == 0) profile = atoi(argv[i] + 10); + else if (strcmp(argv[i], "--coverage") == 0) + coverage = true; } if (profile) profilerStart(L, profile); + if (coverage) + coverageInit(L); + std::vector files = getSourceFiles(argc, argv); int failed = 0; @@ -523,6 +547,9 @@ int main(int argc, char** argv) profilerDump("profile.out"); } + if (coverage) + coverageDump("coverage.out"); + return failed; } } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 2c1e85ff0..8f74ffedd 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) @@ -462,20 +461,17 @@ struct Compiler bool shared = false; - if (FFlag::LuauPreloadClosures) + // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure + // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it + // is used) + if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) { - // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure - // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it - // is used) - if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); + int32_t cid = bytecode.addConstantClosure(f->id); - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); - shared = true; - } + if (cid >= 0 && cid < 32768) + { + bytecode.emitAD(LOP_DUPCLOSURE, target, cid); + shared = true; } } diff --git a/Makefile b/Makefile index 15c7ff7a4..b144cac60 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ TESTS_SOURCES=$(wildcard tests/*.cpp) TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Repl.cpp +REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau @@ -128,10 +128,10 @@ luau-size: luau # executable target aliases luau: $(REPL_CLI_TARGET) - cp $^ $@ + ln -fs $^ $@ luau-analyze: $(ANALYZE_CLI_TARGET) - cp $^ $@ + ln -fs $^ $@ # executable targets $(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) diff --git a/Sources.cmake b/Sources.cmake index 57df9b91e..14834b3a5 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -133,6 +133,7 @@ target_sources(Luau.VM PRIVATE VM/src/ltable.cpp VM/src/ltablib.cpp VM/src/ltm.cpp + VM/src/ludata.cpp VM/src/lutf8lib.cpp VM/src/lvmexecute.cpp VM/src/lvmload.cpp @@ -152,12 +153,15 @@ target_sources(Luau.VM PRIVATE VM/src/lstring.h VM/src/ltable.h VM/src/ltm.h + VM/src/ludata.h VM/src/lvm.h ) if(TARGET Luau.Repl.CLI) # Luau.Repl.CLI Sources target_sources(Luau.Repl.CLI PRIVATE + CLI/Coverage.h + CLI/Coverage.cpp CLI/FileUtils.h CLI/FileUtils.cpp CLI/Profiler.h diff --git a/VM/include/lua.h b/VM/include/lua.h index 7078acd0f..55902160c 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -334,6 +334,10 @@ LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); LUA_API void lua_singlestep(lua_State* L, int enabled); LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); +typedef void (*lua_Coverage)(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size); + +LUA_API void lua_getcoverage(lua_State* L, int funcindex, void* context, lua_Coverage callback); + /* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ LUA_API const char* lua_debugtrace(lua_State* L); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 76043b9cd..a65b03253 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -8,6 +8,7 @@ #include "lfunc.h" #include "lgc.h" #include "ldo.h" +#include "ludata.h" #include "lvm.h" #include "lnumutils.h" @@ -43,36 +44,30 @@ static Table* getcurrenv(lua_State* L) } } -static LUAU_NOINLINE TValue* index2adrslow(lua_State* L, int idx) +static LUAU_NOINLINE TValue* pseudo2addr(lua_State* L, int idx) { - api_check(L, idx <= 0); - if (idx > LUA_REGISTRYINDEX) + api_check(L, lua_ispseudo(idx)); + switch (idx) + { /* pseudo-indices */ + case LUA_REGISTRYINDEX: + return registry(L); + case LUA_ENVIRONINDEX: { - api_check(L, idx != 0 && -idx <= L->top - L->base); - return L->top + idx; + sethvalue(L, &L->env, getcurrenv(L)); + return &L->env; + } + case LUA_GLOBALSINDEX: + return gt(L); + default: + { + Closure* func = curr_func(L); + idx = LUA_GLOBALSINDEX - idx; + return (idx <= func->nupvalues) ? &func->c.upvals[idx - 1] : cast_to(TValue*, luaO_nilobject); + } } - else - switch (idx) - { /* pseudo-indices */ - case LUA_REGISTRYINDEX: - return registry(L); - case LUA_ENVIRONINDEX: - { - sethvalue(L, &L->env, getcurrenv(L)); - return &L->env; - } - case LUA_GLOBALSINDEX: - return gt(L); - default: - { - Closure* func = curr_func(L); - idx = LUA_GLOBALSINDEX - idx; - return (idx <= func->nupvalues) ? &func->c.upvals[idx - 1] : cast_to(TValue*, luaO_nilobject); - } - } } -static LUAU_FORCEINLINE TValue* index2adr(lua_State* L, int idx) +static LUAU_FORCEINLINE TValue* index2addr(lua_State* L, int idx) { if (idx > 0) { @@ -83,15 +78,20 @@ static LUAU_FORCEINLINE TValue* index2adr(lua_State* L, int idx) else return o; } + else if (idx > LUA_REGISTRYINDEX) + { + api_check(L, idx != 0 && -idx <= L->top - L->base); + return L->top + idx; + } else { - return index2adrslow(L, idx); + return pseudo2addr(L, idx); } } const TValue* luaA_toobject(lua_State* L, int idx) { - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); return (p == luaO_nilobject) ? NULL : p; } @@ -145,7 +145,7 @@ void lua_xpush(lua_State* from, lua_State* to, int idx) { api_check(from, from->global == to->global); luaC_checkthreadsleep(to); - setobj2s(to, to->top, index2adr(from, idx)); + setobj2s(to, to->top, index2addr(from, idx)); api_incr_top(to); return; } @@ -202,7 +202,7 @@ void lua_settop(lua_State* L, int idx) void lua_remove(lua_State* L, int idx) { - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); api_checkvalidindex(L, p); while (++p < L->top) setobjs2s(L, p - 1, p); @@ -213,7 +213,7 @@ void lua_remove(lua_State* L, int idx) void lua_insert(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); api_checkvalidindex(L, p); for (StkId q = L->top; q > p; q--) setobjs2s(L, q, q - 1); @@ -228,7 +228,7 @@ void lua_replace(lua_State* L, int idx) luaG_runerror(L, "no calling environment"); api_checknelems(L, 1); luaC_checkthreadsleep(L); - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); api_checkvalidindex(L, o); if (idx == LUA_ENVIRONINDEX) { @@ -250,7 +250,7 @@ void lua_replace(lua_State* L, int idx) void lua_pushvalue(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); setobj2s(L, L->top, o); api_incr_top(L); return; @@ -262,7 +262,7 @@ void lua_pushvalue(lua_State* L, int idx) int lua_type(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (o == luaO_nilobject) ? LUA_TNONE : ttype(o); } @@ -273,20 +273,20 @@ const char* lua_typename(lua_State* L, int t) int lua_iscfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return iscfunction(o); } int lua_isLfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return isLfunction(o); } int lua_isnumber(lua_State* L, int idx) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return tonumber(o, &n); } @@ -298,14 +298,14 @@ int lua_isstring(lua_State* L, int idx) int lua_isuserdata(lua_State* L, int idx) { - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return (ttisuserdata(o) || ttislightuserdata(o)); } int lua_rawequal(lua_State* L, int index1, int index2) { - StkId o1 = index2adr(L, index1); - StkId o2 = index2adr(L, index2); + StkId o1 = index2addr(L, index1); + StkId o2 = index2addr(L, index2); return (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : luaO_rawequalObj(o1, o2); } @@ -313,8 +313,8 @@ int lua_equal(lua_State* L, int index1, int index2) { StkId o1, o2; int i; - o1 = index2adr(L, index1); - o2 = index2adr(L, index2); + o1 = index2addr(L, index1); + o2 = index2addr(L, index2); i = (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : equalobj(L, o1, o2); return i; } @@ -323,8 +323,8 @@ int lua_lessthan(lua_State* L, int index1, int index2) { StkId o1, o2; int i; - o1 = index2adr(L, index1); - o2 = index2adr(L, index2); + o1 = index2addr(L, index1); + o2 = index2addr(L, index2); i = (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : luaV_lessthan(L, o1, o2); return i; } @@ -332,7 +332,7 @@ int lua_lessthan(lua_State* L, int index1, int index2) double lua_tonumberx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { if (isnum) @@ -350,7 +350,7 @@ double lua_tonumberx(lua_State* L, int idx, int* isnum) int lua_tointegerx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { int res; @@ -371,7 +371,7 @@ int lua_tointegerx(lua_State* L, int idx, int* isnum) unsigned lua_tounsignedx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { unsigned res; @@ -391,13 +391,13 @@ unsigned lua_tounsignedx(lua_State* L, int idx, int* isnum) int lua_toboolean(lua_State* L, int idx) { - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return !l_isfalse(o); } const char* lua_tolstring(lua_State* L, int idx, size_t* len) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisstring(o)) { luaC_checkthreadsleep(L); @@ -408,7 +408,7 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len) return NULL; } luaC_checkGC(L); - o = index2adr(L, idx); /* previous call may reallocate the stack */ + o = index2addr(L, idx); /* previous call may reallocate the stack */ } if (len != NULL) *len = tsvalue(o)->len; @@ -417,7 +417,7 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len) const char* lua_tostringatom(lua_State* L, int idx, int* atom) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisstring(o)) return NULL; const TString* s = tsvalue(o); @@ -438,7 +438,7 @@ const char* lua_namecallatom(lua_State* L, int* atom) const float* lua_tovector(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisvector(o)) { return NULL; @@ -448,7 +448,7 @@ const float* lua_tovector(lua_State* L, int idx) int lua_objlen(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TSTRING: @@ -469,13 +469,13 @@ int lua_objlen(lua_State* L, int idx) lua_CFunction lua_tocfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (!iscfunction(o)) ? NULL : cast_to(lua_CFunction, clvalue(o)->c.f); } void* lua_touserdata(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TUSERDATA: @@ -489,13 +489,13 @@ void* lua_touserdata(lua_State* L, int idx) void* lua_touserdatatagged(lua_State* L, int idx, int tag) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (ttisuserdata(o) && uvalue(o)->tag == tag) ? uvalue(o)->data : NULL; } int lua_userdatatag(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (ttisuserdata(o)) return uvalue(o)->tag; return -1; @@ -503,13 +503,13 @@ int lua_userdatatag(lua_State* L, int idx) lua_State* lua_tothread(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (!ttisthread(o)) ? NULL : thvalue(o); } const void* lua_topointer(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TTABLE: @@ -657,7 +657,7 @@ int lua_pushthread(lua_State* L) void lua_gettable(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_gettable(L, t, L->top - 1, L->top - 1); return; @@ -666,7 +666,7 @@ void lua_gettable(lua_State* L, int idx) void lua_getfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); TValue key; setsvalue(L, &key, luaS_new(L, k)); @@ -678,7 +678,7 @@ void lua_getfield(lua_State* L, int idx, const char* k) void lua_rawgetfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); TValue key; setsvalue(L, &key, luaS_new(L, k)); @@ -690,7 +690,7 @@ void lua_rawgetfield(lua_State* L, int idx, const char* k) void lua_rawget(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top - 1, luaH_get(hvalue(t), L->top - 1)); return; @@ -699,7 +699,7 @@ void lua_rawget(lua_State* L, int idx) void lua_rawgeti(lua_State* L, int idx, int n) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top, luaH_getnum(hvalue(t), n)); api_incr_top(L); @@ -717,7 +717,7 @@ void lua_createtable(lua_State* L, int narray, int nrec) void lua_setreadonly(lua_State* L, int objindex, int enabled) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); api_check(L, t != hvalue(registry(L))); @@ -727,7 +727,7 @@ void lua_setreadonly(lua_State* L, int objindex, int enabled) int lua_getreadonly(lua_State* L, int objindex) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); int res = t->readonly; @@ -736,7 +736,7 @@ int lua_getreadonly(lua_State* L, int objindex) void lua_setsafeenv(lua_State* L, int objindex, int enabled) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); t->safeenv = bool(enabled); @@ -748,7 +748,7 @@ int lua_getmetatable(lua_State* L, int objindex) const TValue* obj; Table* mt = NULL; int res; - obj = index2adr(L, objindex); + obj = index2addr(L, objindex); switch (ttype(obj)) { case LUA_TTABLE: @@ -775,7 +775,7 @@ int lua_getmetatable(lua_State* L, int objindex) void lua_getfenv(lua_State* L, int idx) { StkId o; - o = index2adr(L, idx); + o = index2addr(L, idx); api_checkvalidindex(L, o); switch (ttype(o)) { @@ -801,7 +801,7 @@ void lua_settable(lua_State* L, int idx) { StkId t; api_checknelems(L, 2); - t = index2adr(L, idx); + t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_settable(L, t, L->top - 2, L->top - 1); L->top -= 2; /* pop index and value */ @@ -813,7 +813,7 @@ void lua_setfield(lua_State* L, int idx, const char* k) StkId t; TValue key; api_checknelems(L, 1); - t = index2adr(L, idx); + t = index2addr(L, idx); api_checkvalidindex(L, t); setsvalue(L, &key, luaS_new(L, k)); luaV_settable(L, t, &key, L->top - 1); @@ -825,7 +825,7 @@ void lua_rawset(lua_State* L, int idx) { StkId t; api_checknelems(L, 2); - t = index2adr(L, idx); + t = index2addr(L, idx); api_check(L, ttistable(t)); if (hvalue(t)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -839,7 +839,7 @@ void lua_rawseti(lua_State* L, int idx, int n) { StkId o; api_checknelems(L, 1); - o = index2adr(L, idx); + o = index2addr(L, idx); api_check(L, ttistable(o)); if (hvalue(o)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -854,7 +854,7 @@ int lua_setmetatable(lua_State* L, int objindex) TValue* obj; Table* mt; api_checknelems(L, 1); - obj = index2adr(L, objindex); + obj = index2addr(L, objindex); api_checkvalidindex(L, obj); if (ttisnil(L->top - 1)) mt = NULL; @@ -896,7 +896,7 @@ int lua_setfenv(lua_State* L, int idx) StkId o; int res = 1; api_checknelems(L, 1); - o = index2adr(L, idx); + o = index2addr(L, idx); api_checkvalidindex(L, o); api_check(L, ttistable(L->top - 1)); switch (ttype(o)) @@ -987,7 +987,7 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) func = 0; else { - StkId o = index2adr(L, errfunc); + StkId o = index2addr(L, errfunc); api_checkvalidindex(L, o); func = savestack(L, o); } @@ -1150,7 +1150,7 @@ l_noret lua_error(lua_State* L) int lua_next(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); int more = luaH_next(L, hvalue(t), L->top - 1); if (more) @@ -1187,7 +1187,7 @@ void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); luaC_checkGC(L); luaC_checkthreadsleep(L); - Udata* u = luaS_newudata(L, sz, tag); + Udata* u = luaU_newudata(L, sz, tag); setuvalue(L, L->top, u); api_incr_top(L); return u->data; @@ -1197,7 +1197,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) { luaC_checkGC(L); luaC_checkthreadsleep(L); - Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); + Udata* u = luaU_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); memcpy(&u->data + sz, &dtor, sizeof(dtor)); setuvalue(L, L->top, u); api_incr_top(L); @@ -1232,7 +1232,7 @@ const char* lua_getupvalue(lua_State* L, int funcindex, int n) { luaC_checkthreadsleep(L); TValue* val; - const char* name = aux_upvalue(index2adr(L, funcindex), n, &val); + const char* name = aux_upvalue(index2addr(L, funcindex), n, &val); if (name) { setobj2s(L, L->top, val); @@ -1246,7 +1246,7 @@ const char* lua_setupvalue(lua_State* L, int funcindex, int n) const char* name; TValue* val; StkId fi; - fi = index2adr(L, funcindex); + fi = index2addr(L, funcindex); api_checknelems(L, 1); name = aux_upvalue(fi, n, &val); if (name) @@ -1270,7 +1270,7 @@ int lua_ref(lua_State* L, int idx) api_check(L, idx != LUA_REGISTRYINDEX); /* idx is a stack index for value */ int ref = LUA_REFNIL; global_State* g = L->global; - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); if (!ttisnil(p)) { Table* reg = hvalue(registry(L)); diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index d77f84ef9..9fe1885fb 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -370,6 +370,69 @@ void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); } +static int getmaxline(Proto* p) +{ + int result = -1; + + for (int i = 0; i < p->sizecode; ++i) + { + int line = luaG_getline(p, i); + result = result < line ? line : result; + } + + for (int i = 0; i < p->sizep; ++i) + { + int psize = getmaxline(p->p[i]); + result = result < psize ? psize : result; + } + + return result; +} + +static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* context, lua_Coverage callback) +{ + memset(buffer, -1, size * sizeof(int)); + + for (int i = 0; i < p->sizecode; ++i) + { + Instruction insn = p->code[i]; + if (LUAU_INSN_OP(insn) != LOP_COVERAGE) + continue; + + int line = luaG_getline(p, i); + int hits = LUAU_INSN_E(insn); + + LUAU_ASSERT(size_t(line) < size); + buffer[line] = buffer[line] < hits ? hits : buffer[line]; + } + + const char* debugname = p->debugname ? getstr(p->debugname) : NULL; + int linedefined = luaG_getline(p, 0); + + callback(context, debugname, linedefined, depth, buffer, size); + + for (int i = 0; i < p->sizep; ++i) + getcoverage(p->p[i], depth + 1, buffer, size, context, callback); +} + +void lua_getcoverage(lua_State* L, int funcindex, void* context, lua_Coverage callback) +{ + const TValue* func = luaA_toobject(L, funcindex); + api_check(L, ttisfunction(func) && !clvalue(func)->isC); + + Proto* p = clvalue(func)->l.p; + + size_t size = getmaxline(p) + 1; + if (size == 0) + return; + + int* buffer = luaM_newarray(L, size, int, 0); + + getcoverage(p, 0, buffer, size, context, callback); + + luaM_freearray(L, buffer, size, int, 0); +} + static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) { size_t size = strlen(data); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index ab416041e..7393fc74f 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -8,11 +8,10 @@ #include "lfunc.h" #include "lstring.h" #include "ldo.h" +#include "ludata.h" #include -LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) - LUAU_FASTFLAG(LuauArrayBoundary) #define GC_SWEEPMAX 40 @@ -59,10 +58,6 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, case GCSpropagate: case GCSpropagateagain: g->gcstats.currcycle.marktime += seconds; - - // atomic step had to be performed during the switch and it's tracked separately - if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring) - g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime; break; case GCSatomic: g->gcstats.currcycle.atomictime += seconds; @@ -488,7 +483,7 @@ static void freeobj(lua_State* L, GCObject* o) luaS_free(L, gco2ts(o)); break; case LUA_TUSERDATA: - luaS_freeudata(L, gco2u(o)); + luaU_freeudata(L, gco2u(o)); break; default: LUAU_ASSERT(0); @@ -632,17 +627,9 @@ static size_t remarkupvals(global_State* g) static size_t atomic(lua_State* L) { global_State* g = L->global; - size_t work = 0; - - if (FFlag::LuauSeparateAtomic) - { - LUAU_ASSERT(g->gcstate == GCSatomic); - } - else - { - g->gcstate = GCSatomic; - } + LUAU_ASSERT(g->gcstate == GCSatomic); + size_t work = 0; /* remark occasional upvalues of (maybe) dead threads */ work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ @@ -666,11 +653,6 @@ static size_t atomic(lua_State* L) g->sweepgc = &g->rootgc; g->gcstate = GCSsweepstring; - if (!FFlag::LuauSeparateAtomic) - { - GC_INTERRUPT(GCSatomic); - } - return work; } @@ -716,22 +698,7 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { - if (FFlag::LuauSeparateAtomic) - { - g->gcstate = GCSatomic; - } - else - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSatomic; } break; } @@ -853,7 +820,7 @@ static size_t getheaptrigger(global_State* g, size_t heapgoal) void luaC_step(lua_State* L, bool assist) { global_State* g = L->global; - ptrdiff_t lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ + int lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -908,7 +875,7 @@ void luaC_fullgc(lua_State* L) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) + if (g->gcstate <= GCSatomic) { /* reset sweep marks to sweep all elements (returning them to white) */ g->sweepstrgc = 0; @@ -1049,7 +1016,7 @@ int64_t luaC_allocationrate(lua_State* L) global_State* g = L->global; const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms - if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) + if (g->gcstate <= GCSatomic) { double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index a79e7b953..f6f7a878f 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -7,6 +7,7 @@ #include "ltable.h" #include "lfunc.h" #include "lstring.h" +#include "ludata.h" #include #include diff --git a/VM/src/lobject.h b/VM/src/lobject.h index ba040af6c..fd0a15b75 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -78,15 +78,7 @@ typedef struct lua_TValue #define thvalue(o) check_exp(ttisthread(o), &(o)->value.gc->th) #define upvalue(o) check_exp(ttisupval(o), &(o)->value.gc->uv) -// beware bit magic: a value is false if it's nil or boolean false -// baseline implementation: (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) -// we'd like a branchless version of this which helps with performance, and a very fast version -// so our strategy is to always read the boolean value (not using bvalue(o) because that asserts when type isn't boolean) -// we then combine it with type to produce 0/1 as follows: -// - when type is nil (0), & makes the result 0 -// - when type is boolean (1), we effectively only look at the bottom bit, so result is 0 iff boolean value is 0 -// - when type is different, it must have some of the top bits set - we keep all top bits of boolean value so the result is non-0 -#define l_isfalse(o) (!(((o)->value.b | ~1) & ttype(o))) +#define l_isfalse(o) (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) /* ** for internal debug only diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index 18ee1cda5..a9e90d17a 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -206,32 +206,3 @@ void luaS_free(lua_State* L, TString* ts) L->global->strt.nuse--; luaM_free(L, ts, sizestring(ts->len), ts->memcat); } - -Udata* luaS_newudata(lua_State* L, size_t s, int tag) -{ - if (s > INT_MAX - sizeof(Udata)) - luaM_toobig(L); - Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); - luaC_link(L, u, LUA_TUSERDATA); - u->len = int(s); - u->metatable = NULL; - LUAU_ASSERT(tag >= 0 && tag <= 255); - u->tag = uint8_t(tag); - return u; -} - -void luaS_freeudata(lua_State* L, Udata* u) -{ - LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); - - void (*dtor)(void*) = nullptr; - if (u->tag == UTAG_IDTOR) - memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); - else if (u->tag) - dtor = L->global->udatagc[u->tag]; - - if (dtor) - dtor(u->data); - - luaM_free(L, u, sizeudata(u->len), u->memcat); -} diff --git a/VM/src/lstring.h b/VM/src/lstring.h index 612da28d5..3fd0bd39b 100644 --- a/VM/src/lstring.h +++ b/VM/src/lstring.h @@ -8,11 +8,7 @@ /* string size limit */ #define MAXSSIZE (1 << 30) -/* special tag value is used for user data with inline dtors */ -#define UTAG_IDTOR LUA_UTAG_LIMIT - #define sizestring(len) (offsetof(TString, data) + len + 1) -#define sizeudata(len) (offsetof(Udata, data) + len) #define luaS_new(L, s) (luaS_newlstr(L, s, strlen(s))) #define luaS_newliteral(L, s) (luaS_newlstr(L, "" s, (sizeof(s) / sizeof(char)) - 1)) @@ -26,8 +22,5 @@ LUAI_FUNC void luaS_resize(lua_State* L, int newsize); LUAI_FUNC TString* luaS_newlstr(lua_State* L, const char* str, size_t l); LUAI_FUNC void luaS_free(lua_State* L, TString* ts); -LUAI_FUNC Udata* luaS_newudata(lua_State* L, size_t s, int tag); -LUAI_FUNC void luaS_freeudata(lua_State* L, Udata* u); - LUAI_FUNC TString* luaS_bufstart(lua_State* L, size_t size); LUAI_FUNC TString* luaS_buffinish(lua_State* L, TString* ts); diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp new file mode 100644 index 000000000..d180c388e --- /dev/null +++ b/VM/src/ludata.cpp @@ -0,0 +1,37 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "ludata.h" + +#include "lgc.h" +#include "lmem.h" + +#include + +Udata* luaU_newudata(lua_State* L, size_t s, int tag) +{ + if (s > INT_MAX - sizeof(Udata)) + luaM_toobig(L); + Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); + luaC_link(L, u, LUA_TUSERDATA); + u->len = int(s); + u->metatable = NULL; + LUAU_ASSERT(tag >= 0 && tag <= 255); + u->tag = uint8_t(tag); + return u; +} + +void luaU_freeudata(lua_State* L, Udata* u) +{ + LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); + + void (*dtor)(void*) = nullptr; + if (u->tag == UTAG_IDTOR) + memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); + else if (u->tag) + dtor = L->global->udatagc[u->tag]; + + if (dtor) + dtor(u->data); + + luaM_free(L, u, sizeudata(u->len), u->memcat); +} diff --git a/VM/src/ludata.h b/VM/src/ludata.h new file mode 100644 index 000000000..59cb85bd1 --- /dev/null +++ b/VM/src/ludata.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" + +/* special tag value is used for user data with inline dtors */ +#define UTAG_IDTOR LUA_UTAG_LIMIT + +#define sizeudata(len) (offsetof(Udata, data) + len) + +LUAI_FUNC Udata* luaU_newudata(lua_State* L, size_t s, int tag); +LUAI_FUNC void luaU_freeudata(lua_State* L, Udata* u); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index bf8d493eb..cebeeb584 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -63,7 +63,8 @@ #define VM_KV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->l.p->sizek)), &k[i]) #define VM_UV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->nupvalues)), &cl->l.uprefs[i]) -#define VM_PATCH_C(pc, slot) ((uint8_t*)(pc))[3] = uint8_t(slot) +#define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) +#define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) // NOTE: If debugging the Luau code, disable this macro to prevent timeouts from // occurring when tracing code in Visual Studio / XCode @@ -120,7 +121,7 @@ */ #if VM_USE_CGOTO #define VM_CASE(op) CASE_##op: -#define VM_NEXT() goto*(SingleStep ? &&dispatch : kDispatchTable[*(uint8_t*)pc]) +#define VM_NEXT() goto*(SingleStep ? &&dispatch : kDispatchTable[LUAU_INSN_OP(*pc)]) #define VM_CONTINUE(op) goto* kDispatchTable[uint8_t(op)] #else #define VM_CASE(op) case op: @@ -325,7 +326,7 @@ static void luau_execute(lua_State* L) // ... and singlestep logic :) if (SingleStep) { - if (L->global->cb.debugstep && !luau_skipstep(*(uint8_t*)pc)) + if (L->global->cb.debugstep && !luau_skipstep(LUAU_INSN_OP(*pc))) { VM_PROTECT(luau_callhook(L, L->global->cb.debugstep, NULL)); @@ -335,13 +336,12 @@ static void luau_execute(lua_State* L) } #if VM_USE_CGOTO - VM_CONTINUE(*(uint8_t*)pc); + VM_CONTINUE(LUAU_INSN_OP(*pc)); #endif } #if !VM_USE_CGOTO - // Note: this assumes that LUAU_INSN_OP() decodes the first byte (aka least significant byte in the little endian encoding) - size_t dispatchOp = *(uint8_t*)pc; + size_t dispatchOp = LUAU_INSN_OP(*pc); dispatchContinue: switch (dispatchOp) @@ -2577,7 +2577,7 @@ static void luau_execute(lua_State* L) // update hits with saturated add and patch the instruction in place hits = (hits < (1 << 23) - 1) ? hits + 1 : hits; - ((uint32_t*)pc)[-1] = LOP_COVERAGE | (uint32_t(hits) << 8); + VM_PATCH_E(pc - 1, hits); VM_NEXT(); } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 740a4cfd2..5d802277a 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -53,7 +53,7 @@ const float* luaV_tovector(const TValue* obj) static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) { ptrdiff_t result = savestack(L, res); - // RBOLOX: using stack room beyond top is technically safe here, but for very complicated reasons: + // using stack room beyond top is technically safe here, but for very complicated reasons: // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers @@ -74,7 +74,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3) { - // RBOLOX: using stack room beyond top is technically safe here, but for very complicated reasons: + // using stack room beyond top is technically safe here, but for very complicated reasons: // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua index 60e4f61e4..c8f6b5dcd 100644 --- a/bench/tests/sunspider/3d-raytrace.lua +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -451,15 +451,16 @@ function raytraceScene() end function arrayToCanvasCommands(pixels) - local s = 'Test\nvar pixels = ['; + local s = {}; + table.insert(s, 'Test\nvar pixels = ['); for y = 0,size-1 do - s = s .. "["; + table.insert(s, "["); for x = 0,size-1 do - s = s .. "[" .. math.floor(pixels[y + 1][x + 1][1] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][2] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][3] * 255) .. "],"; + table.insert(s, "[" .. math.floor(pixels[y + 1][x + 1][1] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][2] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][3] * 255) .. "],"); end - s = s .. "],"; + table.insert(s, "],"); end - s = s .. '];\n var canvas = document.getElementById("renderCanvas").getContext("2d");\n\ + table.insert(s, '];\n var canvas = document.getElementById("renderCanvas").getContext("2d");\n\ \n\ \n\ var size = ' .. size .. ';\n\ @@ -479,9 +480,9 @@ for (var y = 0; y < size; y++) {\n\ canvas.setFillColor(l[0], l[1], l[2], 1);\n\ canvas.fillRect(x, y, 1, 1);\n\ }\n\ -}'; +}'); - return s; + return table.concat(s); end testOutput = arrayToCanvasCommands(raytraceScene()); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 3b74a99e4..62a9999b0 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -513,8 +513,6 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_the_end_of_a_comme TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check(R"( --[[ @1 )"); @@ -526,8 +524,6 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment_at_the_very_end_of_the_file") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check("--[[@1"); auto ac = autocomplete('1'); @@ -2625,4 +2621,55 @@ local a: A<(number, s@1> CHECK(ac.entryMap.count("string")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag luauAutocompleteFirstArg("LuauAutocompleteFirstArg", true); + + check(R"( +local function foo1() return 1 end +local function foo2() return "1" end + +local function bar0() return "got" .. a end +local function bar1(a: number) return "got " .. a end +local function bar2(a: number, b: string) return "got " .. a .. b end + +local t = {} +function t:bar1(a: number) return "got " .. a end + +local r1 = bar0(@1) +local r2 = bar1(@2) +local r3 = bar2(@3) +local r4 = t:bar1(@4) + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::None); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('2'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('3'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('4'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 6ba39adab..95811b3f4 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -10,9 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauPreloadClosures) -LUAU_FASTFLAG(LuauGenericSpecialGlobals) - using namespace Luau; static std::string compileFunction(const char* source, uint32_t id) @@ -74,20 +71,10 @@ TEST_CASE("BasicFunction") bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); Luau::compileOrThrow(bcb, "local function foo(a, b) return b end"); - if (FFlag::LuauPreloadClosures) - { - CHECK_EQ("\n" + bcb.dumpFunction(1), R"( + CHECK_EQ("\n" + bcb.dumpFunction(1), R"( DUPCLOSURE R0 K0 RETURN R0 0 )"); - } - else - { - CHECK_EQ("\n" + bcb.dumpFunction(1), R"( -NEWCLOSURE R0 P0 -RETURN R0 0 -)"); - } CHECK_EQ("\n" + bcb.dumpFunction(0), R"( RETURN R1 1 @@ -2859,47 +2846,35 @@ CAPTURE UPVAL U1 RETURN R0 1 )"); - if (FFlag::LuauPreloadClosures) - { - // recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( + // recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( DUPCLOSURE R0 K0 CAPTURE VAL R0 RETURN R0 0 )"); - // multi-level recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( + // multi-level recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( DUPCLOSURE R0 K0 CAPTURE UPVAL U0 RETURN R0 1 )"); - // multi-level recursive capture where function isn't top-level - // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler - CHECK_EQ("\n" + compileFunction(R"( + // multi-level recursive capture where function isn't top-level + // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler + CHECK_EQ("\n" + compileFunction(R"( local function foo() local function bar() return function() return bar() end end end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 CAPTURE UPVAL U0 RETURN R0 1 )"); - } - else - { - // recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( -NEWCLOSURE R0 P0 -CAPTURE VAL R0 -RETURN R0 0 -)"); - } } TEST_CASE("OutOfLocals") @@ -3504,8 +3479,6 @@ local t = { TEST_CASE("ConstantClosure") { - ScopedFastFlag sff("LuauPreloadClosures", true); - // closures without upvalues are created when bytecode is loaded CHECK_EQ("\n" + compileFunction(R"( return function() end @@ -3570,8 +3543,6 @@ RETURN R0 1 TEST_CASE("SharedClosure") { - ScopedFastFlag sff1("LuauPreloadClosures", true); - // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index b2aad3163..b055a38e4 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -123,8 +123,8 @@ int lua_silence(lua_State* L) using StateRef = std::unique_ptr; -static StateRef runConformance( - const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, lua_State* initialLuaState = nullptr) +static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, + lua_State* initialLuaState = nullptr, lua_CompileOptions* copts = nullptr) { std::string path = __FILE__; path.erase(path.find_last_of("\\/")); @@ -180,13 +180,8 @@ static StateRef runConformance( std::string chunkname = "=" + std::string(name); - lua_CompileOptions copts = {}; - copts.optimizationLevel = 1; // default - copts.debugLevel = 2; // for debugger tests - copts.vectorCtor = "vector"; // for vector tests - size_t bytecodeSize = 0; - char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize); + char* bytecode = luau_compile(source.data(), source.size(), copts, &bytecodeSize); int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); free(bytecode); @@ -373,29 +368,37 @@ TEST_CASE("Vector") { ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - runConformance("vector.lua", [](lua_State* L) { - lua_pushcfunction(L, lua_vector, "vector"); - lua_setglobal(L, "vector"); + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 1; + copts.vectorCtor = "vector"; + + runConformance( + "vector.lua", + [](lua_State* L) { + lua_pushcfunction(L, lua_vector, "vector"); + lua_setglobal(L, "vector"); #if LUA_VECTOR_SIZE == 4 - lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); #else - lua_pushvector(L, 0.0f, 0.0f, 0.0f); + lua_pushvector(L, 0.0f, 0.0f, 0.0f); #endif - luaL_newmetatable(L, "vector"); + luaL_newmetatable(L, "vector"); - lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index, nullptr); - lua_settable(L, -3); + lua_pushstring(L, "__index"); + lua_pushcfunction(L, lua_vector_index, nullptr); + lua_settable(L, -3); - lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall, nullptr); - lua_settable(L, -3); + lua_pushstring(L, "__namecall"); + lua_pushcfunction(L, lua_vector_namecall, nullptr); + lua_settable(L, -3); - lua_setreadonly(L, -1, true); - lua_setmetatable(L, -2); - lua_pop(L, 1); - }); + lua_setreadonly(L, -1, true); + lua_setmetatable(L, -2); + lua_pop(L, 1); + }, + nullptr, nullptr, &copts); } static void populateRTTI(lua_State* L, Luau::TypeId type) @@ -499,6 +502,10 @@ TEST_CASE("Debugger") breakhits = 0; interruptedthread = nullptr; + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 2; + runConformance( "debugger.lua", [](lua_State* L) { @@ -614,7 +621,8 @@ TEST_CASE("Debugger") lua_resume(interruptedthread, nullptr, 0); interruptedthread = nullptr; } - }); + }, + nullptr, &copts); CHECK(breakhits == 10); // 2 hits per breakpoint } @@ -863,4 +871,46 @@ TEST_CASE("TagMethodError") }); } +TEST_CASE("Coverage") +{ + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 1; + copts.coverageLevel = 2; + + runConformance( + "coverage.lua", + [](lua_State* L) { + lua_pushcfunction( + L, + [](lua_State* L) -> int { + luaL_argexpected(L, lua_isLfunction(L, 1), 1, "function"); + + lua_newtable(L); + lua_getcoverage(L, 1, L, [](void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) { + lua_State* L = static_cast(context); + + lua_newtable(L); + + lua_pushstring(L, function); + lua_setfield(L, -2, "name"); + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + lua_pushinteger(L, hits[i]); + lua_rawseti(L, -2, int(i)); + } + + lua_rawseti(L, -2, lua_objlen(L, -2) + 1); + }); + + return 1; + }, + "getcoverage"); + lua_setglobal(L, "getcoverage"); + }, + nullptr, nullptr, &copts); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 2800d2fe6..e3993cc53 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -278,7 +278,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") #if defined(_DEBUG) || defined(_NOOPT) int limit = 250; #else - int limit = 500; + int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; @@ -287,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TypeId table = src.addType(TableTypeVar{}); TypeId nested = table; - for (unsigned i = 0; i < limit + 100; i++) + for (int i = 0; i < limit + 100; i++) { TableTypeVar* ttv = getMutable(nested); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 72d3a9a64..5abcb09ac 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2303,8 +2303,6 @@ TEST_CASE_FIXTURE(Fixture, "capture_comments") TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - ParseOptions options; options.captureComments = true; @@ -2319,8 +2317,6 @@ TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") TEST_CASE_FIXTURE(Fixture, "capture_broken_comment") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - ParseOptions options; options.captureComments = true; diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 091c2f012..71ff4e1b8 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -207,6 +207,36 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_does_not_propagate_type_info") CHECK_EQ("number", toString(requireType("b"))); } +TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") +{ + ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; + + CheckResult result = check(R"( + local a = 55 :: number? + local b = a :: number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number?", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "as_expr_warns_on_unrelated_cast") +{ + ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; + ScopedFastFlag sff2{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + local a = 55 :: string + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Cannot cast 'number' into 'string' because the types are unrelated", toString(result.errors[0])); + CHECK_EQ("string", toString(requireType("a"))); +} + TEST_CASE_FIXTURE(Fixture, "type_annotations_inside_function_bodies") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 1e2eae147..1d8135d4c 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauFixTonumberReturnType) + using namespace Luau; TEST_SUITE_BEGIN("BuiltinTests"); @@ -814,6 +816,30 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[2].data); } +TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") +{ + CheckResult result = check(R"( + --!strict + local b: number = tonumber('asdf') + )"); + + if (FFlag::LuauFixTonumberReturnType) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); + } +} + +TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") +{ + CheckResult result = check(R"( + --!strict + local b: number = tonumber('asdf') or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index aba508918..b62044fa9 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) + TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -644,4 +646,42 @@ f(1, 2, 3) CHECK_EQ(toString(*ty, opts), "(a: number, number, number) -> ()"); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_types") +{ + CheckResult result = check(R"( +type C = () -> () +type D = () -> () + +local c: C +local d: D = c + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauExtendedFunctionMismatchError) + CHECK_EQ( + toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack") +{ + CheckResult result = check(R"( +type C = () -> () +type D = () -> () + +local c: C +local d: D = c + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauExtendedFunctionMismatchError) + CHECK_EQ(toString(result.errors[0]), + R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index fe8e7ff90..688680c10 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -279,7 +279,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1085,4 +1085,19 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } +TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") +{ + ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; + + CheckResult result = check(R"( + type T = { [string]: { prop: number }? } + local t: T = {} + if t["hello"] then + local foo = t["hello"].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 5f95efd52..1621ef32f 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -374,4 +374,54 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + }; + + CheckResult result = check(R"( +type Cat = { tag: 'cat', catfood: string } +type Dog = { tag: 'dog', dogfood: string } +type Animal = Cat | Dog + +local a: Animal = { tag = 'cat', cafood = 'something' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'a' could not be converted into 'Cat | Dog' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'Cat' because the former is missing field 'catfood')", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + }; + + CheckResult result = check(R"( +type Good = { success: true, result: string } +type Bad = { success: false, error: string } +type Result = Good | Bad + +local a: Result = { success = false, result = 'something' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'a' could not be converted into 'Bad | Good' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'Bad' because the former is missing field 'error')", + toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index cb72faaf4..3ea9b80c3 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -12,6 +12,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) + TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -275,7 +277,7 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local a = {} @@ -346,7 +348,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -369,7 +371,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local T = {} @@ -476,7 +478,7 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -511,7 +513,7 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_ TEST_CASE_FIXTURE(Fixture, "width_subtyping") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -771,7 +773,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } @@ -782,7 +784,7 @@ TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") TEST_CASE_FIXTURE(Fixture, "array_factory_function") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( function empty() return {} end @@ -1465,7 +1467,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end @@ -1550,7 +1552,7 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multipl TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local vec3 = {x = 1, y = 2, z = 3} @@ -1937,7 +1939,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -1952,7 +1954,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1971,7 +1973,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1995,7 +1997,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -2015,11 +2017,22 @@ caused by: caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + } + else + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); + } } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") @@ -2027,7 +2040,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, }; CheckResult result = check(R"( @@ -2048,7 +2061,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, {"LuauExtendedTypeMismatchError", true}, }; @@ -2076,7 +2089,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, }; CheckResult result = check(R"( @@ -2092,4 +2105,18 @@ a.p = { x = 9 } LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") +{ + ScopedFastFlag luauFixRecursiveMetatableCall{"LuauFixRecursiveMetatableCall", true}; + + CheckResult result = check(R"( +local b +b = setmetatable({}, {__call = b}) +b() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable {| __call: t1 |}, { } })"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index e3222a410..ad9ea8276 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) using namespace Luau; @@ -2084,7 +2085,7 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") { CheckResult result = check(R"( function add(a: number, b: string) - return a + tonumber(b), a .. b + return a + (tonumber(b) :: number), a .. b end local n, s = add(2,"3") )"); @@ -2485,7 +2486,7 @@ TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") CheckResult result = check(R"( --!strict function f(U) - U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() + U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() end )"); @@ -3329,7 +3330,7 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable { CheckResult result = check(R"( local x - print((x == true and (x .. "y")) .. 1) + print((x == true and (x .. "y")) .. 1) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -4473,7 +4474,18 @@ f(function(a, b, c, ...) return a + b end) )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'", toString(result.errors[0])); + + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number')", toString(result.errors[0])); + } // Infer from variadic packs into elements result = check(R"( @@ -4604,7 +4616,17 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ( + "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); + } + else + { + CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") @@ -4799,4 +4821,211 @@ local ModuleA = require(game.A) CHECK_EQ("*unknown*", toString(*oty)); } +/* + * If it wasn't instantly obvious, we have the fuzzer to thank for this gem of a test. + * + * We had an issue here where the scope for the `if` block here would + * have an elevated TypeLevel even though there is no function nesting going on. + * This would result in a free typevar for the type of _ that was much higher than + * it should be. This type would be erroneously quantified in the definition of `aaa`. + * This in turn caused an ice when evaluating `_()` in the while loop. + */ +TEST_CASE_FIXTURE(Fixture, "free_typevars_introduced_within_control_flow_constructs_do_not_get_an_elevated_TypeLevel") +{ + check(R"( + --!strict + if _ then + _[_], _ = nil + _() + end + + local aaa = function():typeof(_) return 1 end + + if aaa then + while _() do + end + end + )"); + + // No ice()? No problem. +} + +/* + * This is a bit elaborate. Bear with me. + * + * The type of _ becomes free with the first statement. With the second, we unify it with a function. + * + * At this point, it is important that the newly created fresh types of this new function type are promoted + * to the same level as the original free type. If we do not, they are incorrectly ascribed the level of the + * containing function. + * + * If this is allowed to happen, the final lambda erroneously quantifies the type of _ to something ridiculous + * just before we typecheck the invocation to _. + */ +TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") +{ + check(R"( + l0, _ = nil + + local function p() + _() + end + + a = _( + function():(typeof(p),typeof(_)) + end + )[nil] + )"); +} + +/* + * We had an issue where part of the type of pairs() was an unsealed table. + * This test depends on FFlagDebugLuauFreezeArena to trigger it. + */ +TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") +{ + check(R"( + function _(l0:{n0:any}) + _ = pairs + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table") +{ + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + + Base64FileReader() + + function ReadMidiEvents(data) + + local reader = Base64FileReader(data) + + while reader:HasMore() do + (reader:Byte() % 128) + end + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, string) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' +caused by: + Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> (number) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Function only returns 1 value. 2 are required here)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, number) -> number + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' +caused by: + Return type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> (number, string) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + toString(result.errors[0]), R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); +} + +TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") +{ + ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true}; + + CheckResult result = check(R"( +local t = {} + +function t.x(value) + for k,v in pairs(t) do end +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 9f9a007f1..f55b46a40 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -214,4 +214,32 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unifica CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } +TEST_CASE_FIXTURE(TryUnifyFixture, "undo_new_prop_on_unsealed_table") +{ + ScopedFastFlag flags[] = { + {"LuauTableSubtypingVariance2", true}, + }; + // I am not sure how to make this happen in Luau code. + + TypeId unsealedTable = arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); + TypeId sealedTable = arena.addType(TableTypeVar{ + {{"prop", Property{getSingletonTypes().numberType}}}, + std::nullopt, + TypeLevel{}, + TableState::Sealed + }); + + const TableTypeVar* ttv = get(unsealedTable); + REQUIRE(ttv); + + state.tryUnify(unsealedTable, sealedTable); + + // To be honest, it's really quite spooky here that we're amending an unsealed table in this case. + CHECK(!ttv->props.empty()); + + state.log.rollback(); + + CHECK(ttv->props.empty()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 48496b895..b095a0db0 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -462,4 +462,20 @@ local a: XYZ = { w = 4 } CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X | Y | Z'; none of the union options are compatible)"); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_optional") +{ + ScopedFastFlag luauExtendedUnionMismatchError{"LuauExtendedUnionMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } + +local a: X? = { w = 4 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X?' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'X' because the former is missing field 'x')"); +} + TEST_SUITE_END(); diff --git a/tests/conformance/coverage.lua b/tests/conformance/coverage.lua new file mode 100644 index 000000000..f899603f9 --- /dev/null +++ b/tests/conformance/coverage.lua @@ -0,0 +1,64 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing coverage") + +function foo() + local x = 1 + local y = 2 + assert(x + y) +end + +function bar() + local function one(x) + return x + end + + local two = function(x) + return x + end + + one(1) +end + +function validate(stats, hits, misses) + local checked = {} + + for _,l in ipairs(hits) do + if not (stats[l] and stats[l] > 0) then + return false, string.format("expected line %d to be hit", l) + end + checked[l] = true + end + + for _,l in ipairs(misses) do + if not (stats[l] and stats[l] == 0) then + return false, string.format("expected line %d to be missed", l) + end + checked[l] = true + end + + for k,v in pairs(stats) do + if type(k) == "number" and not checked[k] then + return false, string.format("expected line %d to be absent", k) + end + end + + return true +end + +foo() +c = getcoverage(foo) +assert(#c == 1) +assert(c[1].name == "foo") +assert(validate(c[1], {5, 6, 7}, {})) + +bar() +c = getcoverage(bar) +assert(#c == 3) +assert(c[1].name == "bar") +assert(validate(c[1], {11, 15, 19}, {})) +assert(c[2].name == "one") +assert(validate(c[2], {12}, {})) +assert(c[3].name == nil) +assert(validate(c[3], {}, {16})) + +return 'OK' From a9aa4faf24e6cea1ac0e33d0054a7328a35f9d4a Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 6 Jan 2022 14:08:56 -0800 Subject: [PATCH 11/14] Sync to upstream/release/508 This version isn't for release because we've skipped some internal numbers due to year-end schedule changes, but it's better to merge separately. --- Analysis/include/Luau/LValue.h | 63 +++++++ Analysis/include/Luau/Predicate.h | 34 +--- Analysis/include/Luau/TypeInfer.h | 1 + Analysis/src/{Predicate.cpp => LValue.cpp} | 88 ++++++++- Analysis/src/TypeInfer.cpp | 83 ++++++++- Analysis/src/TypeVar.cpp | 11 +- Analysis/src/Unifier.cpp | 121 ++++--------- Ast/src/Parser.cpp | 4 - Sources.cmake | 5 +- VM/src/lbaselib.cpp | 8 +- tests/Autocomplete.test.cpp | 6 +- tests/LValue.test.cpp | 198 +++++++++++++++++++++ tests/Predicate.test.cpp | 117 ------------ tests/Symbol.test.cpp | 33 +++- tests/Transpiler.test.cpp | 2 - tests/TypeInfer.builtins.test.cpp | 12 +- tests/TypeInfer.classes.test.cpp | 2 - tests/TypeInfer.intersectionTypes.test.cpp | 4 - tests/TypeInfer.refinements.test.cpp | 37 +++- tests/TypeInfer.singletons.test.cpp | 1 - tests/TypeInfer.tables.test.cpp | 4 - tests/TypeInfer.test.cpp | 13 ++ tests/TypeInfer.unionTypes.test.cpp | 4 - tests/conformance/math.lua | 1 + 24 files changed, 569 insertions(+), 283 deletions(-) create mode 100644 Analysis/include/Luau/LValue.h rename Analysis/src/{Predicate.cpp => LValue.cpp} (50%) create mode 100644 tests/LValue.test.cpp delete mode 100644 tests/Predicate.test.cpp diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h new file mode 100644 index 000000000..8fd96f05a --- /dev/null +++ b/Analysis/include/Luau/LValue.h @@ -0,0 +1,63 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Variant.h" +#include "Luau/Symbol.h" + +#include // TODO: Kill with LuauLValueAsKey. +#include +#include + +namespace Luau +{ + +struct TypeVar; +using TypeId = const TypeVar*; + +struct Field; +using LValue = Variant; + +struct Field +{ + std::shared_ptr parent; + std::string key; + + bool operator==(const Field& rhs) const; + bool operator!=(const Field& rhs) const; +}; + +struct LValueHasher +{ + size_t operator()(const LValue& lvalue) const; +}; + +const LValue* baseof(const LValue& lvalue); + +std::optional tryGetLValue(const class AstExpr& expr); + +// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. +std::pair> getFullName(const LValue& lvalue); + +// Kill with LuauLValueAsKey. +std::string toString(const LValue& lvalue); + +template +const T* get(const LValue& lvalue) +{ + return get_if(&lvalue); +} + +using NEW_RefinementMap = std::unordered_map; +using DEPRECATED_RefinementMap = std::map; + +// Transient. Kill with LuauLValueAsKey. +struct RefinementMap +{ + NEW_RefinementMap NEW_refinements; + DEPRECATED_RefinementMap DEPRECATED_refinements; +}; + +void merge(RefinementMap& l, const RefinementMap& r, std::function f); +void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/Predicate.h b/Analysis/include/Luau/Predicate.h index a5e8b6ae1..df93b4f49 100644 --- a/Analysis/include/Luau/Predicate.h +++ b/Analysis/include/Luau/Predicate.h @@ -1,12 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Variant.h" #include "Luau/Location.h" -#include "Luau/Symbol.h" +#include "Luau/LValue.h" +#include "Luau/Variant.h" -#include -#include #include namespace Luau @@ -15,34 +13,6 @@ namespace Luau struct TypeVar; using TypeId = const TypeVar*; -struct Field; -using LValue = Variant; - -struct Field -{ - std::shared_ptr parent; // TODO: Eventually use unique_ptr to enforce non-copyable trait. - std::string key; -}; - -std::optional tryGetLValue(const class AstExpr& expr); - -// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. -std::pair> getFullName(const LValue& lvalue); - -std::string toString(const LValue& lvalue); - -template -const T* get(const LValue& lvalue) -{ - return get_if(&lvalue); -} - -// Key is a stringified encoding of an LValue. -using RefinementMap = std::map; - -void merge(RefinementMap& l, const RefinementMap& r, std::function f); -void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); - struct TruthyPredicate; struct IsAPredicate; struct TypeGuardPredicate; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 451976e48..862f50d79 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -350,6 +350,7 @@ struct TypeChecker private: std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); + std::optional DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/LValue.cpp similarity index 50% rename from Analysis/src/Predicate.cpp rename to Analysis/src/LValue.cpp index 7bd8001e3..da6804c6b 100644 --- a/Analysis/src/Predicate.cpp +++ b/Analysis/src/LValue.cpp @@ -1,11 +1,59 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Predicate.h" +#include "Luau/LValue.h" #include "Luau/Ast.h" +#include + +LUAU_FASTFLAG(LuauLValueAsKey) + namespace Luau { +bool Field::operator==(const Field& rhs) const +{ + LUAU_ASSERT(parent && rhs.parent); + return key == rhs.key && (parent == rhs.parent || *parent == *rhs.parent); +} + +bool Field::operator!=(const Field& rhs) const +{ + return !(*this == rhs); +} + +size_t LValueHasher::operator()(const LValue& lvalue) const +{ + // Most likely doesn't produce high quality hashes, but we're probably ok enough with it. + // When an evidence is shown that operator==(LValue) is used more often than it should, we can have a look at improving the hash quality. + size_t acc = 0; + size_t offset = 0; + + const LValue* current = &lvalue; + while (current) + { + if (auto field = get(*current)) + acc ^= (std::hash{}(field->key) << 1) >> ++offset; + else if (auto symbol = get(*current)) + acc ^= std::hash{}(*symbol) << 1; + else + LUAU_ASSERT(!"Hash not accumulated for this new LValue alternative."); + + current = baseof(*current); + } + + return acc; +} + +const LValue* baseof(const LValue& lvalue) +{ + if (auto field = get(lvalue)) + return field->parent.get(); + + auto symbol = get(lvalue); + LUAU_ASSERT(symbol); + return nullptr; // Base of root is null. +} + std::optional tryGetLValue(const AstExpr& node) { const AstExpr* expr = &node; @@ -38,15 +86,15 @@ std::pair> getFullName(const LValue& lvalue) while (auto field = get(*current)) { keys.push_back(field->key); - current = field->parent.get(); - if (!current) - LUAU_ASSERT(!"LValue root is a Field?"); + current = baseof(*current); } const Symbol* symbol = get(*current); + LUAU_ASSERT(symbol); return {*symbol, std::vector(keys.rbegin(), keys.rend())}; } +// Kill with LuauLValueAsKey. std::string toString(const LValue& lvalue) { auto [symbol, keys] = getFullName(lvalue); @@ -56,7 +104,18 @@ std::string toString(const LValue& lvalue) return s; } -void merge(RefinementMap& l, const RefinementMap& r, std::function f) +static void merge(NEW_RefinementMap& l, const NEW_RefinementMap& r, std::function f) +{ + for (const auto& [k, a] : r) + { + if (auto it = l.find(k); it != l.end()) + l[k] = f(it->second, a); + else + l[k] = a; + } +} + +static void merge(DEPRECATED_RefinementMap& l, const DEPRECATED_RefinementMap& r, std::function f) { auto itL = l.begin(); auto itR = r.begin(); @@ -69,21 +128,32 @@ void merge(RefinementMap& l, const RefinementMap& r, std::functionfirst > k) + else if (itL->first < k) + ++itL; + else { l[k] = a; ++itR; } - else - ++itL; } l.insert(itR, r.end()); } +void merge(RefinementMap& l, const RefinementMap& r, std::function f) +{ + if (FFlag::LuauLValueAsKey) + return merge(l.NEW_refinements, r.NEW_refinements, f); + else + return merge(l.DEPRECATED_refinements, r.DEPRECATED_refinements, f); +} + void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty) { - refis[toString(lvalue)] = ty; + if (FFlag::LuauLValueAsKey) + refis.NEW_refinements[lvalue] = ty; + else + refis.DEPRECATED_refinements[toString(lvalue)] = ty; } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index abbc2901b..e29b6ec65 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) +LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) @@ -1626,6 +1627,10 @@ std::optional TypeChecker::getIndexTypeFromType( { RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + // Not needed when we normalize types. + if (FFlag::LuauLValueAsKey && get(follow(t))) + return t; + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) goodOptions.push_back(*ty); else @@ -4967,13 +4972,83 @@ std::pair, std::vector> TypeChecker::createGener std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue) { - std::string path = toString(lvalue); + if (!FFlag::LuauLValueAsKey) + return DEPRECATED_resolveLValue(scope, lvalue); + + // We want to be walking the Scope parents. + // We'll also want to walk up the LValue path. As we do this, we need to save each LValue because we must walk back. + // For example: + // There exists an entry t.x. + // We are asked to look for t.x.y. + // We need to search in the provided Scope. Find t.x.y first. + // We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x. + // If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate. + const auto& [symbol, keys] = getFullName(lvalue); + + ScopePtr currentScope = scope; + while (currentScope) + { + std::optional found; + + std::vector childKeys; + const LValue* currentLValue = &lvalue; + while (currentLValue) + { + if (auto it = currentScope->refinements.NEW_refinements.find(*currentLValue); it != currentScope->refinements.NEW_refinements.end()) + { + found = it->second; + break; + } + + childKeys.push_back(*currentLValue); + currentLValue = baseof(*currentLValue); + } + + if (!found) + { + // Should not be using scope->lookup. This is already recursive. + if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) + found = it->second.typeId; + else + { + // Nothing exists in this Scope. Just skip and try the parent one. + currentScope = currentScope->parent; + continue; + } + } + + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) + { + const LValue& key = *it; + + // Symbol can happen. Skip. + if (get(key)) + continue; + else if (auto field = get(key)) + { + found = getIndexTypeFromType(scope, *found, field->key, Location(), false); + if (!found) + return std::nullopt; // Turns out this type doesn't have the property at all. We're done. + } + else + LUAU_ASSERT(!"New LValue alternative not handled here."); + } + + return found; + } + + // No entry for it at all. Can happen when LValue root is a global. + return std::nullopt; +} + +std::optional TypeChecker::DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue) +{ auto [symbol, keys] = getFullName(lvalue); ScopePtr currentScope = scope; while (currentScope) { - if (auto it = currentScope->refinements.find(path); it != currentScope->refinements.end()) + if (auto it = currentScope->refinements.DEPRECATED_refinements.find(toString(lvalue)); it != currentScope->refinements.DEPRECATED_refinements.end()) return it->second; // Should not be using scope->lookup. This is already recursive. @@ -5000,7 +5075,9 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV std::optional TypeChecker::resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue) { - if (auto it = refis.find(toString(lvalue)); it != refis.end()) + if (auto it = refis.DEPRECATED_refinements.find(toString(lvalue)); it != refis.DEPRECATED_refinements.end()) + return it->second; + else if (auto it = refis.NEW_refinements.find(lvalue); it != refis.NEW_refinements.end()) return it->second; else return resolveLValue(scope, lvalue); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 571b13ca0..fb75aa02e 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -996,18 +996,19 @@ std::optional> magicFunctionFormat( std::vector expected = parseFormatString(typechecker, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(paramPack); - const size_t dataOffset = 1; + size_t paramOffset = 1; + size_t dataOffset = expr.self ? 0 : 1; // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + dataOffset < params.size(); ++i) + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) { - Location location = expr.args.data[std::min(i, expr.args.size - 1)]->location; + Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - typechecker.unify(expected[i], params[i + dataOffset], location); + typechecker.unify(expected[i], params[i + paramOffset], location); } // if we know the argument count or if we have too many arguments for sure, we can issue an error - const size_t actualParamSize = params.size() - dataOffset; + size_t actualParamSize = params.size() - paramOffset; if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index c5aab8562..43ea37e7b 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,9 +18,7 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAG(LuauSingletonTypes) -LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) @@ -416,7 +414,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (!innerState.errors.empty()) { // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' - if (FFlag::LuauExtendedTypeMismatchError && !firstFailedOption && !isNil(type)) + if (!firstFailedOption && !isNil(type)) firstFailedOption = {innerState.errors.front()}; failed = true; @@ -434,7 +432,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(*unificationTooComplex); else if (failed) { - if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption) + if (firstFailedOption) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); @@ -536,49 +534,36 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption) errors.push_back( TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); - else if (FFlag::LuauExtendedTypeMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); } } else if (const IntersectionTypeVar* uv = get(superTy)) { - if (FFlag::LuauExtendedTypeMismatchError) + std::optional unificationTooComplex; + std::optional firstFailedOption; + + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) { - std::optional unificationTooComplex; - std::optional firstFailedOption; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); - - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - else if (!innerState.errors.empty()) - { - if (!firstFailedOption) - firstFailedOption = {innerState.errors.front()}; - } - - log.concat(std::move(innerState.log)); + if (!firstFailedOption) + firstFailedOption = {innerState.errors.front()}; } - if (unificationTooComplex) - errors.push_back(*unificationTooComplex); - else if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); - } - else - { - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) - { - tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); - } + log.concat(std::move(innerState.log)); } + + if (unificationTooComplex) + errors.push_back(*unificationTooComplex); + else if (firstFailedOption) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } else if (const IntersectionTypeVar* uv = get(subTy)) { @@ -626,10 +611,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(*unificationTooComplex); else if (!found) { - if (FFlag::LuauExtendedTypeMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); - else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } else if (get(superTy) && get(subTy)) @@ -1241,10 +1223,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, r->second.type); - if (FFlag::LuauExtendedTypeMismatchError) - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); - else - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1261,10 +1240,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, rt->indexer->indexResultType); - if (FFlag::LuauExtendedTypeMismatchError) - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); - else - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1302,10 +1278,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, lt->indexer->indexResultType); - if (FFlag::LuauExtendedTypeMismatchError) - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); - else - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1723,18 +1696,11 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse innerState.tryUnify_(lhs->table, rhs->table); innerState.tryUnify_(lhs->metatable, rhs->metatable); - if (FFlag::LuauExtendedTypeMismatchError) - { - if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); - else if (!innerState.errors.empty()) - errors.push_back( - TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); - } - else - { - checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); - } + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty()) + errors.push_back( + TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); log.concat(std::move(innerState.log)); } @@ -1821,31 +1787,22 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, getSingletonTypes().errorRecoveryType()); } else { - if (FFlag::LuauExtendedClassMismatchError) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, classProp->type); + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, classProp->type); - checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); + checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - if (innerState.errors.empty()) - { - log.concat(std::move(innerState.log)); - } - else - { - ok = false; - innerState.log.rollback(); - } + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); } else { - tryUnify_(prop.type, classProp->type); + ok = false; + innerState.log.rollback(); } } } @@ -2185,8 +2142,6 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) { - LUAU_ASSERT(FFlag::LuauExtendedTypeMismatchError || FFlag::LuauExtendedClassMismatchError); - if (auto e = hasUnificationTooComplex(innerErrors)) errors.push_back(*e); else if (!innerErrors.empty()) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index dd24f27cb..72f616497 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau { @@ -1368,9 +1367,6 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) Lexeme parameterStart = lexer.current(); - if (!FFlag::LuauParseGenericFunctionTypeBegin) - begin = parameterStart; - expectAndConsume('(', "function parameters"); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; diff --git a/Sources.cmake b/Sources.cmake index 14834b3a5..a7153eb37 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -45,6 +45,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/IostreamHelpers.h Analysis/include/Luau/JsonEncoder.h Analysis/include/Luau/Linter.h + Analysis/include/Luau/LValue.h Analysis/include/Luau/Module.h Analysis/include/Luau/ModuleResolver.h Analysis/include/Luau/Predicate.h @@ -80,8 +81,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/IostreamHelpers.cpp Analysis/src/JsonEncoder.cpp Analysis/src/Linter.cpp + Analysis/src/LValue.cpp Analysis/src/Module.cpp - Analysis/src/Predicate.cpp Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp @@ -194,10 +195,10 @@ if(TARGET Luau.UnitTest) tests/Frontend.test.cpp tests/JsonEncoder.test.cpp tests/Linter.test.cpp + tests/LValue.test.cpp tests/Module.test.cpp tests/NonstrictMode.test.cpp tests/Parser.test.cpp - tests/Predicate.test.cpp tests/RequireTracer.test.cpp tests/StringUtils.test.cpp tests/Symbol.test.cpp diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 881c804db..988fd315e 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -36,12 +36,14 @@ static int luaB_tonumber(lua_State* L) int base = luaL_optinteger(L, 2, 10); if (base == 10) { /* standard conversion */ - luaL_checkany(L, 1); - if (lua_isnumber(L, 1)) + int isnum = 0; + double n = lua_tonumberx(L, 1, &isnum); + if (isnum) { - lua_pushnumber(L, lua_tonumber(L, 1)); + lua_pushnumber(L, n); return 1; } + luaL_checkany(L, 1); /* error if we don't have any argument */ } else { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 62a9999b0..8ca09c0e0 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1394,7 +1394,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_function_return_types") check(R"( local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end return target(b@1 )"); @@ -1422,7 +1422,7 @@ return target(bar1, b@1 check(R"( local function target(a: number, b: string) return a + #b end local function bar1(a: number): (...number) return -a, a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end return target(b@1 )"); @@ -1918,7 +1918,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end return target(b@1 )"); diff --git a/tests/LValue.test.cpp b/tests/LValue.test.cpp new file mode 100644 index 000000000..8a092779c --- /dev/null +++ b/tests/LValue.test.cpp @@ -0,0 +1,198 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeInfer.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) +{ + Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId { + // TODO: normalize here also. + std::unordered_set s; + + if (auto utv = get(follow(a))) + s.insert(begin(utv), end(utv)); + else + s.insert(a); + + if (auto utv = get(follow(b))) + s.insert(begin(utv), end(utv)); + else + s.insert(b); + + std::vector options(s.begin(), s.end()); + return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)}); + }); +} + +static LValue mkSymbol(const std::string& s) +{ + return Symbol{AstName{s.data()}}; +} + +TEST_SUITE_BEGIN("LValue"); + +TEST_CASE("Luau_merge_hashmap_order") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + + RefinementMap m{{ + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().numberType}, + }}; + + RefinementMap other{{ + {mkSymbol(a), getSingletonTypes().stringType}, + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(3, m.NEW_refinements.size()); + REQUIRE(m.NEW_refinements.count(mkSymbol(a))); + REQUIRE(m.NEW_refinements.count(mkSymbol(b))); + REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)])); +} + +TEST_CASE("Luau_merge_hashmap_order2") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + + RefinementMap m{{ + {mkSymbol(a), getSingletonTypes().stringType}, + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().numberType}, + }}; + + RefinementMap other{{ + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(3, m.NEW_refinements.size()); + REQUIRE(m.NEW_refinements.count(mkSymbol(a))); + REQUIRE(m.NEW_refinements.count(mkSymbol(b))); + REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)])); +} + +TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + std::string d = "d"; + std::string e = "e"; + + RefinementMap m{{ + {mkSymbol(a), getSingletonTypes().stringType}, + {mkSymbol(b), getSingletonTypes().numberType}, + {mkSymbol(c), getSingletonTypes().booleanType}, + }}; + + RefinementMap other{{ + {mkSymbol(c), getSingletonTypes().stringType}, + {mkSymbol(d), getSingletonTypes().numberType}, + {mkSymbol(e), getSingletonTypes().booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(5, m.NEW_refinements.size()); + REQUIRE(m.NEW_refinements.count(mkSymbol(a))); + REQUIRE(m.NEW_refinements.count(mkSymbol(b))); + REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + REQUIRE(m.NEW_refinements.count(mkSymbol(d))); + REQUIRE(m.NEW_refinements.count(mkSymbol(e))); + + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); + CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(b)])); + CHECK_EQ("boolean | string", toString(m.NEW_refinements[mkSymbol(c)])); + CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(d)])); + CHECK_EQ("boolean", toString(m.NEW_refinements[mkSymbol(e)])); +} + +TEST_CASE("hashing_lvalue_global_prop_access") +{ + std::string t1 = "t"; + std::string x1 = "x"; + + LValue t_x1{Field{std::make_shared(Symbol{AstName{t1.data()}}), x1}}; + + std::string t2 = "t"; + std::string x2 = "x"; + + LValue t_x2{Field{std::make_shared(Symbol{AstName{t2.data()}}), x2}}; + + CHECK_EQ(t_x1, t_x1); + CHECK_EQ(t_x1, t_x2); + CHECK_EQ(t_x2, t_x2); + + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x1)); + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); + CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); + + NEW_RefinementMap m; + m[t_x1] = getSingletonTypes().stringType; + m[t_x2] = getSingletonTypes().numberType; + + CHECK_EQ(1, m.size()); +} + +TEST_CASE("hashing_lvalue_local_prop_access") +{ + std::string t1 = "t"; + std::string x1 = "x"; + + AstLocal localt1{AstName{t1.data()}, Location(), nullptr, 0, 0, nullptr}; + LValue t_x1{Field{std::make_shared(Symbol{&localt1}), x1}}; + + std::string t2 = "t"; + std::string x2 = "x"; + + AstLocal localt2{AstName{t2.data()}, Location(), &localt1, 0, 0, nullptr}; + LValue t_x2{Field{std::make_shared(Symbol{&localt2}), x2}}; + + CHECK_EQ(t_x1, t_x1); + CHECK_NE(t_x1, t_x2); + CHECK_EQ(t_x2, t_x2); + + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x1)); + CHECK_NE(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); + CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); + + NEW_RefinementMap m; + m[t_x1] = getSingletonTypes().stringType; + m[t_x2] = getSingletonTypes().numberType; + + CHECK_EQ(2, m.size()); +} + +TEST_SUITE_END(); diff --git a/tests/Predicate.test.cpp b/tests/Predicate.test.cpp deleted file mode 100644 index 7081693e2..000000000 --- a/tests/Predicate.test.cpp +++ /dev/null @@ -1,117 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeInfer.h" - -#include "Fixture.h" -#include "ScopedFlags.h" - -#include "doctest.h" - -using namespace Luau; - -static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) -{ - Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId { - // TODO: normalize here also. - std::unordered_set s; - - if (auto utv = get(follow(a))) - s.insert(begin(utv), end(utv)); - else - s.insert(a); - - if (auto utv = get(follow(b))) - s.insert(begin(utv), end(utv)); - else - s.insert(b); - - std::vector options(s.begin(), s.end()); - return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)}); - }); -} - -TEST_SUITE_BEGIN("Predicate"); - -TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") -{ - RefinementMap m{ - {"b", typeChecker.stringType}, - {"c", typeChecker.numberType}, - }; - - RefinementMap other{ - {"a", typeChecker.stringType}, - {"b", typeChecker.stringType}, - {"c", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(3, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("string", toString(m["b"])); - CHECK_EQ("boolean | number", toString(m["c"])); -} - -TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") -{ - RefinementMap m{ - {"a", typeChecker.stringType}, - {"b", typeChecker.stringType}, - {"c", typeChecker.numberType}, - }; - - RefinementMap other{ - {"b", typeChecker.stringType}, - {"c", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(3, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("string", toString(m["b"])); - CHECK_EQ("boolean | number", toString(m["c"])); -} - -TEST_CASE_FIXTURE(Fixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start") -{ - RefinementMap m{ - {"a", typeChecker.stringType}, - {"b", typeChecker.numberType}, - {"c", typeChecker.booleanType}, - }; - - RefinementMap other{ - {"c", typeChecker.stringType}, - {"d", typeChecker.numberType}, - {"e", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(5, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - REQUIRE(m.count("d")); - REQUIRE(m.count("e")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("number", toString(m["b"])); - CHECK_EQ("boolean | string", toString(m["c"])); - CHECK_EQ("number", toString(m["d"])); - CHECK_EQ("boolean", toString(m["e"])); -} - -TEST_SUITE_END(); diff --git a/tests/Symbol.test.cpp b/tests/Symbol.test.cpp index 44fe3a3c7..e7d2973b8 100644 --- a/tests/Symbol.test.cpp +++ b/tests/Symbol.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; TEST_SUITE_BEGIN("SymbolTests"); -TEST_CASE("hashing") +TEST_CASE("hashing_globals") { std::string s1 = "name"; std::string s2 = "name"; @@ -31,10 +31,37 @@ TEST_CASE("hashing") CHECK_EQ(std::hash()(two), std::hash()(two)); std::unordered_map theMap; - theMap[AstName{s1.data()}] = 5; - theMap[AstName{s2.data()}] = 1; + theMap[n1] = 5; + theMap[n2] = 1; REQUIRE_EQ(1, theMap.size()); } +TEST_CASE("hashing_locals") +{ + std::string s1 = "name"; + std::string s2 = "name"; + + // These two names point to distinct memory areas. + AstLocal one{AstName{s1.data()}, Location(), nullptr, 0, 0, nullptr}; + AstLocal two{AstName{s2.data()}, Location(), &one, 0, 0, nullptr}; + + Symbol n1{&one}; + Symbol n2{&two}; + + CHECK(n1 == n1); + CHECK(n1 != n2); + CHECK(n2 == n2); + + CHECK_EQ(std::hash()(&one), std::hash()(&one)); + CHECK_NE(std::hash()(&one), std::hash()(&two)); + CHECK_EQ(std::hash()(&two), std::hash()(&two)); + + std::unordered_map theMap; + theMap[n1] = 5; + theMap[n2] = 1; + + REQUIRE_EQ(2, theMap.size()); +} + TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 327fa0bbd..47c3883c1 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -555,8 +555,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") TEST_CASE_FIXTURE(Fixture, "transpile_generic_function") { - ScopedFastFlag luauParseGenericFunctionTypeBegin("LuauParseGenericFunctionTypeBegin", true); - std::string code = R"( local function foo(a: T, ...: S...) return 1 end local f: (T, S...)->(number) = foo diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 1d8135d4c..506279b94 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -798,13 +798,14 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi { CheckResult result = check(R"( ("%s%d%s"):format(1, "hello", true) + string.format("%s%d%s", 1, "hello", true) )"); TypeId stringType = typeChecker.stringType; TypeId numberType = typeChecker.numberType; TypeId booleanType = typeChecker.booleanType; - LUAU_REQUIRE_ERROR_COUNT(3, result); + LUAU_REQUIRE_ERROR_COUNT(6, result); CHECK_EQ(Location(Position{1, 26}, Position{1, 27}), result.errors[0].location); CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[0].data); @@ -814,6 +815,15 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(Location(Position{1, 38}, Position{1, 42}), result.errors[2].location); CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[2].data); + + CHECK_EQ(Location(Position{2, 32}, Position{2, 33}), result.errors[3].location); + CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[3].data); + + CHECK_EQ(Location(Position{2, 35}, Position{2, 42}), result.errors[4].location); + CHECK_EQ(TypeErrorData(TypeMismatch{numberType, stringType}), result.errors[4].data); + + CHECK_EQ(Location(Position{2, 44}, Position{2, 48}), result.errors[5].location); + CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[5].data); } TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 1ff23fe69..0283ae192 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -449,8 +449,6 @@ b.X = 2 -- real Vector2.X is also read-only TEST_CASE_FIXTURE(ClassFixture, "detailed_class_unification_error") { - ScopedFastFlag luauExtendedClassMismatchError{"LuauExtendedClassMismatchError", true}; - CheckResult result = check(R"( local function foo(v) return v.X :: number + string.len(v.Y) diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 893bc2b30..93c0baf6d 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -343,8 +343,6 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -363,8 +361,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 688680c10..503b613f4 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -280,7 +280,6 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( local t: {x: number?} = {x = nil} @@ -1085,6 +1084,41 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + local foo: string? = "hi" + assert(foo) + local foo: number = 5 + print(foo:sub(1, 1)) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'number' does not have key 'sub'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + type T = {x: string | number} + local t: T? = {x = "hi"} + if t then + if type(t.x) == "string" then + local foo = t.x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({5, 30}))); +} + TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") { ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; @@ -1092,6 +1126,7 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip CheckResult result = check(R"( type T = { [string]: { prop: number }? } local t: T = {} + if t["hello"] then local foo = t["hello"].prop end diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 1621ef32f..68dc1b4fa 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -202,7 +202,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauExtendedTypeMismatchError", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 3ea9b80c3..80f40407e 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1955,7 +1955,6 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( type A = { x: number, y: number } @@ -1974,7 +1973,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( type AS = { x: number, y: number } @@ -1998,7 +1996,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); @@ -2062,7 +2059,6 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, - {"LuauExtendedTypeMismatchError", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index ad9ea8276..76324556a 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5013,6 +5013,19 @@ caused by: Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); } +TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + local function f(thing: any | string) + local foo = thing.SomeRandomKey + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") { ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true}; diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index b095a0db0..2357869e9 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -425,8 +425,6 @@ y = x TEST_CASE_FIXTURE(Fixture, "error_detailed_union_part") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -446,8 +444,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_union_all") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index d5bca44f0..bfea0e1f1 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -26,6 +26,7 @@ function f(...) end end +assert(pcall(tonumber) == false) assert(tonumber{} == nil) assert(tonumber'+0.01' == 1/100 and tonumber'+.01' == 0.01 and tonumber'.01' == 0.01 and tonumber'-1.' == -1 and From 44ccd8282244228a3a3108cb8e5237c45b9d92c2 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 6 Jan 2022 14:10:07 -0800 Subject: [PATCH 12/14] Sync to upstream/release/509 --- .gitignore | 18 +- Analysis/include/Luau/Frontend.h | 10 +- Analysis/include/Luau/TxnLog.h | 280 ++- Analysis/include/Luau/TypeInfer.h | 40 +- Analysis/include/Luau/TypePack.h | 8 + Analysis/include/Luau/TypeVar.h | 1 + Analysis/include/Luau/TypedAllocator.h | 23 +- Analysis/include/Luau/Unifier.h | 48 +- Analysis/src/Autocomplete.cpp | 33 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 7 +- Analysis/src/Frontend.cpp | 28 +- Analysis/src/Module.cpp | 2 +- Analysis/src/Quantify.cpp | 14 +- Analysis/src/ToString.cpp | 48 +- Analysis/src/TxnLog.cpp | 295 ++- Analysis/src/TypeInfer.cpp | 807 +++++-- Analysis/src/TypePack.cpp | 51 +- Analysis/src/TypeVar.cpp | 19 +- Analysis/src/TypedAllocator.cpp | 1 + Analysis/src/Unifier.cpp | 2250 +++++++++++++------ Ast/include/Luau/Common.h | 8 +- CLI/Analyze.cpp | 2 +- CLI/Repl.cpp | 46 +- CLI/Web.cpp | 24 +- Compiler/include/Luau/Bytecode.h | 1 + Compiler/include/Luau/BytecodeBuilder.h | 2 + Compiler/src/BytecodeBuilder.cpp | 69 +- Compiler/src/Compiler.cpp | 7 +- Sources.cmake | 1 + VM/include/luaconf.h | 4 - VM/src/lapi.cpp | 34 +- VM/src/laux.cpp | 38 +- VM/src/lbitlib.cpp | 8 - VM/src/ldebug.cpp | 17 +- VM/src/ldo.cpp | 29 +- VM/src/lgc.cpp | 2 - VM/src/lgcdebug.cpp | 7 +- VM/src/lnumprint.cpp | 375 ++++ VM/src/lnumutils.h | 7 +- VM/src/lobject.h | 1 + VM/src/ltable.cpp | 6 +- VM/src/lvmload.cpp | 28 +- VM/src/lvmutils.cpp | 7 +- fuzz/number.cpp | 35 + tests/AstQuery.test.cpp | 2 - tests/Autocomplete.test.cpp | 43 +- tests/Conformance.test.cpp | 31 +- tests/Fixture.cpp | 5 - tests/Fixture.h | 9 - tests/Frontend.test.cpp | 7 +- tests/TypeInfer.annotations.test.cpp | 2 +- tests/TypeInfer.builtins.test.cpp | 37 +- tests/TypeInfer.generics.test.cpp | 2 - tests/TypeInfer.test.cpp | 2 - tests/TypeInfer.tryUnify.test.cpp | 54 +- tests/TypeInfer.typePacks.cpp | 14 +- tests/TypeInfer.unionTypes.test.cpp | 10 +- tests/TypeVar.test.cpp | 4 +- tests/conformance/debug.lua | 9 + tests/conformance/strconv.lua | 51 + tests/main.cpp | 23 +- tools/numprint.py | 82 + 62 files changed, 3827 insertions(+), 1301 deletions(-) create mode 100644 VM/src/lnumprint.cpp create mode 100644 fuzz/number.cpp create mode 100644 tests/conformance/strconv.lua create mode 100644 tools/numprint.py diff --git a/.gitignore b/.gitignore index fa11b45b5..5688dff52 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ -^build/ -^coverage/ -^fuzz/luau.pb.* -^crash-* -^default.prof* -^fuzz-* -^luau$ -/.vs +/build/ +/build[.-]*/ +/coverage/ +/.vs/ +/.vscode/ +/fuzz/luau.pb.* +/crash-* +/default.prof* +/fuzz-* +/luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 07a0296a2..1f64db30c 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -68,7 +68,7 @@ struct FrontendOptions // is complete. bool retainFullTypeGraphs = false; - // When true, we run typechecking twice, one in the regular mode, ond once in strict mode + // When true, we run typechecking twice, once in the regular mode, and once in strict mode // in order to get more precise type information (e.g. for autocomplete). bool typecheckTwice = false; }; @@ -109,18 +109,18 @@ struct Frontend Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {}); - CheckResult check(const ModuleName& name); // new shininess - LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); + CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess + LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); /** Lint some code that has no associated DataModel object * * Since this source fragment has no name, we cannot cache its AST. Instead, * we return it to the caller to use as they wish. */ - std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); + std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); CheckResult check(const SourceModule& module); // OLD. TODO KILL - LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); + LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); bool isDirty(const ModuleName& name) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 29988a3b9..dc45bebf4 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -1,7 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include +#include + #include "Luau/TypeVar.h" +#include "Luau/TypePack.h" LUAU_FASTFLAG(LuauShareTxnSeen); @@ -9,27 +13,28 @@ namespace Luau { // Log of where what TypeIds we are rebinding and what they used to be -struct TxnLog +// Remove with LuauUseCommitTxnLog +struct DEPRECATED_TxnLog { - TxnLog() + DEPRECATED_TxnLog() : originalSeenSize(0) , ownedSeen() , sharedSeen(&ownedSeen) { } - explicit TxnLog(std::vector>* sharedSeen) + explicit DEPRECATED_TxnLog(std::vector>* sharedSeen) : originalSeenSize(sharedSeen->size()) , ownedSeen() , sharedSeen(sharedSeen) { } - TxnLog(const TxnLog&) = delete; - TxnLog& operator=(const TxnLog&) = delete; + DEPRECATED_TxnLog(const DEPRECATED_TxnLog&) = delete; + DEPRECATED_TxnLog& operator=(const DEPRECATED_TxnLog&) = delete; - TxnLog(TxnLog&&) = default; - TxnLog& operator=(TxnLog&&) = default; + DEPRECATED_TxnLog(DEPRECATED_TxnLog&&) = default; + DEPRECATED_TxnLog& operator=(DEPRECATED_TxnLog&&) = default; void operator()(TypeId a); void operator()(TypePackId a); @@ -37,7 +42,7 @@ struct TxnLog void rollback(); - void concat(TxnLog rhs); + void concat(DEPRECATED_TxnLog rhs); bool haveSeen(TypeId lhs, TypeId rhs); void pushSeen(TypeId lhs, TypeId rhs); @@ -54,4 +59,263 @@ struct TxnLog std::vector>* sharedSeen; // shared with all the descendent logs }; +// Pending state for a TypeVar. Generated by a TxnLog and committed via +// TxnLog::commit. +struct PendingType +{ + // The pending TypeVar state. + TypeVar pending; + + explicit PendingType(TypeVar state) + : pending(std::move(state)) + { + } +}; + +// Pending state for a TypePackVar. Generated by a TxnLog and committed via +// TxnLog::commit. +struct PendingTypePack +{ + // The pending TypePackVar state. + TypePackVar pending; + + explicit PendingTypePack(TypePackVar state) + : pending(std::move(state)) + { + } +}; + +template +T* getMutable(PendingType* pending) +{ + // We use getMutable here because this state is intended to be mutated freely. + return getMutable(&pending->pending); +} + +template +T* getMutable(PendingTypePack* pending) +{ + // We use getMutable here because this state is intended to be mutated freely. + return getMutable(&pending->pending); +} + +// Log of what TypeIds we are rebinding, to be committed later. +struct TxnLog +{ + TxnLog() + : ownedSeen() + , sharedSeen(&ownedSeen) + { + } + + explicit TxnLog(TxnLog* parent) + : parent(parent) + { + if (parent) + { + sharedSeen = parent->sharedSeen; + } + else + { + sharedSeen = &ownedSeen; + } + } + + explicit TxnLog(std::vector>* sharedSeen) + : sharedSeen(sharedSeen) + { + } + + TxnLog(TxnLog* parent, std::vector>* sharedSeen) + : parent(parent) + , sharedSeen(sharedSeen) + { + } + + TxnLog(const TxnLog&) = delete; + TxnLog& operator=(const TxnLog&) = delete; + + TxnLog(TxnLog&&) = default; + TxnLog& operator=(TxnLog&&) = default; + + // Gets an empty TxnLog pointer. This is useful for constructs that + // take a TxnLog, like TypePackIterator - use the empty log if you + // don't have a TxnLog to give it. + static const TxnLog* empty(); + + // Joins another TxnLog onto this one. You should use std::move to avoid + // copying the rhs TxnLog. + // + // If both logs talk about the same type, pack, or table, the rhs takes + // priority. + void concat(TxnLog rhs); + + // Commits the TxnLog, rebinding all type pointers to their pending states. + // Clears the TxnLog afterwards. + void commit(); + + // Clears the TxnLog without committing any pending changes. + void clear(); + + // Computes an inverse of this TxnLog at the current time. + // This method should be called before commit is called in order to give an + // accurate result. Committing the inverse of a TxnLog will undo the changes + // made by commit, assuming the inverse log is accurate. + TxnLog inverse(); + + bool haveSeen(TypeId lhs, TypeId rhs) const; + void pushSeen(TypeId lhs, TypeId rhs); + void popSeen(TypeId lhs, TypeId rhs); + + // Queues a type for modification. The original type will not change until commit + // is called. Use pending to get the pending state. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* queue(TypeId ty); + + // Queues a type pack for modification. The original type pack will not change + // until commit is called. Use pending to get the pending state. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* queue(TypePackId tp); + + // Returns the pending state of a type, or nullptr if there isn't any. It is important + // to note that this pending state is not transitive: the pending state may reference + // non-pending types freely, so you may need to call pending multiple times to view the + // entire pending state of a type graph. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* pending(TypeId ty) const; + + // Returns the pending state of a type pack, or nullptr if there isn't any. It is + // important to note that this pending state is not transitive: the pending state may + // reference non-pending types freely, so you may need to call pending multiple times + // to view the entire pending state of a type graph. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* pending(TypePackId tp) const; + + // Queues a replacement of a type with another type. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* replace(TypeId ty, TypeVar replacement); + + // Queues a replacement of a type pack with another type pack. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* replace(TypePackId tp, TypePackVar replacement); + + // Queues a replacement of a table type with another table type that is bound + // to a specific value. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* bindTable(TypeId ty, std::optional newBoundTo); + + // Queues a replacement of a type with a level with a duplicate of that type + // with a new type level. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* changeLevel(TypeId ty, TypeLevel newLevel); + + // Queues a replacement of a type pack with a level with a duplicate of that + // type pack with a new type level. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* changeLevel(TypePackId tp, TypeLevel newLevel); + + // Queues a replacement of a table type with another table type with a new + // indexer. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* changeIndexer(TypeId ty, std::optional indexer); + + // Returns the type level of the pending state of the type, or the level of that + // type, if no pending state exists. If the type doesn't have a notion of a level, + // returns nullopt. If the pending state doesn't have a notion of a level, but the + // original state does, returns nullopt. + std::optional getLevel(TypeId ty) const; + + // Follows a type, accounting for pending type states. The returned type may have + // pending state; you should use `pending` or `get` to find out. + TypeId follow(TypeId ty); + + // Follows a type pack, accounting for pending type states. The returned type pack + // may have pending state; you should use `pending` or `get` to find out. + TypePackId follow(TypePackId tp) const; + + // Replaces a given type's state with a new variant. Returns the new pending state + // of that type. + // + // The pointer returned lives until `commit` or `clear` is called. + template + PendingType* replace(TypeId ty, T replacement) + { + return replace(ty, TypeVar(replacement)); + } + + // Replaces a given type pack's state with a new variant. Returns the new + // pending state of that type pack. + // + // The pointer returned lives until `commit` or `clear` is called. + template + PendingTypePack* replace(TypePackId tp, T replacement) + { + return replace(tp, TypePackVar(replacement)); + } + + // Returns T if a given type or type pack is this variant, respecting the + // log's pending state. + // + // Do not retain this pointer; it has the potential to be invalidated when + // commit or clear is called. + template + T* getMutable(TID ty) const + { + auto* pendingTy = pending(ty); + if (pendingTy) + return Luau::getMutable(pendingTy); + + return Luau::getMutable(ty); + } + + // Returns whether a given type or type pack is a given state, respecting the + // log's pending state. + // + // This method will not assert if called on a BoundTypeVar or BoundTypePack. + template + bool is(TID ty) const + { + // We do not use getMutable here because this method can be called on + // BoundTypeVars, which triggers an assertion. + auto* pendingTy = pending(ty); + if (pendingTy) + return Luau::get_if(&pendingTy->pending.ty) != nullptr; + + return Luau::get_if(&ty->ty) != nullptr; + } + +private: + // unique_ptr is used to give us stable pointers across insertions into the + // map. Otherwise, it would be really easy to accidentally invalidate the + // pointers returned from queue/pending. + // + // We can't use a DenseHashMap here because we need a non-const iterator + // over the map when we concatenate. + std::unordered_map> typeVarChanges; + std::unordered_map> typePackChanges; + + TxnLog* parent = nullptr; + + // Owned version of sharedSeen. This should not be accessed directly in + // TxnLogs; use sharedSeen instead. This field exists because in the tree + // of TxnLogs, the root must own its seen set. In all descendant TxnLogs, + // this is an empty vector. + std::vector> ownedSeen; + +public: + // Used to avoid infinite recursion when types are cyclic. + // Shared with all the descendent TxnLogs. + std::vector>* sharedSeen; +}; + } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 862f50d79..312283b05 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -198,32 +198,32 @@ struct TypeChecker */ TypeId anyIfNonstrict(TypeId ty) const; - /** Attempt to unify the types left and right. Treat any failures as type errors - * in the final typecheck report. + /** Attempt to unify the types. + * Treat any failures as type errors in the final typecheck report. */ - bool unify(TypeId left, TypeId right, const Location& location); - bool unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); + bool unify(TypeId subTy, TypeId superTy, const Location& location); + bool unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); - /** Attempt to unify the types left and right. - * If this fails, and the right type can be instantiated, do so and try unification again. + /** Attempt to unify the types. + * If this fails, and the subTy type can be instantiated, do so and try unification again. */ - bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location); - void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state); + bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, const Location& location); + void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, Unifier& state); - /** Attempt to unify left with right. + /** Attempt to unify. * If there are errors, undo everything and return the errors. * If there are no errors, commit and return an empty error vector. */ - ErrorVec tryUnify(TypeId left, TypeId right, const Location& location); - ErrorVec tryUnify(TypePackId left, TypePackId right, const Location& location); + template + ErrorVec tryUnify_(Id subTy, Id superTy, const Location& location); + ErrorVec tryUnify(TypeId subTy, TypeId superTy, const Location& location); + ErrorVec tryUnify(TypePackId subTy, TypePackId superTy, const Location& location); // Test whether the two type vars unify. Never commits the result. - ErrorVec canUnify(TypeId superTy, TypeId subTy, const Location& location); - ErrorVec canUnify(TypePackId superTy, TypePackId subTy, const Location& location); - - // Variant that takes a preexisting 'seen' set. We need this in certain cases to avoid infinitely recursing - // into cyclic types. - ErrorVec canUnify(const std::vector>& seen, TypeId left, TypeId right, const Location& location); + template + ErrorVec canUnify_(Id subTy, Id superTy, const Location& location); + ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); + ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); @@ -237,12 +237,6 @@ struct TypeChecker std::optional tryStripUnionFromNil(TypeId ty); TypeId stripFromNilAndReport(TypeId ty, const Location& location); - template - ErrorVec tryUnify_(Id left, Id right, const Location& location); - - template - ErrorVec canUnify_(Id left, Id right, const Location& location); - public: /* * Convert monotype into a a polytype, by replacing any metavariables in descendant scopes diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index e72808da7..ca588ccb7 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -18,6 +18,8 @@ struct VariadicTypePack; struct TypePackVar; +struct TxnLog; + using TypePackId = const TypePackVar*; using FreeTypePack = Unifiable::Free; using BoundTypePack = Unifiable::Bound; @@ -84,6 +86,7 @@ struct TypePackIterator TypePackIterator() = default; explicit TypePackIterator(TypePackId tp); + TypePackIterator(TypePackId tp, const TxnLog* log); TypePackIterator& operator++(); TypePackIterator operator++(int); @@ -104,9 +107,13 @@ struct TypePackIterator TypePackId currentTypePack = nullptr; const TypePack* tp = nullptr; size_t currentIndex = 0; + + // Only used if LuauUseCommittingTxnLog is true. + const TxnLog* log; }; TypePackIterator begin(TypePackId tp); +TypePackIterator begin(TypePackId tp, TxnLog* log); TypePackIterator end(TypePackId tp); using SeenSet = std::set>; @@ -114,6 +121,7 @@ using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); +TypePackId follow(TypePackId tp, std::function mapper); size_t size(TypePackId tp); bool finite(TypePackId tp); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index f6829ec3e..d6e177142 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -453,6 +453,7 @@ bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); // Follow BoundTypeVars until we get to something real TypeId follow(TypeId t); +TypeId follow(TypeId t, std::function mapper); std::vector flattenIntersection(TypeId ty); diff --git a/Analysis/include/Luau/TypedAllocator.h b/Analysis/include/Luau/TypedAllocator.h index 0ded14890..64227e7c1 100644 --- a/Analysis/include/Luau/TypedAllocator.h +++ b/Analysis/include/Luau/TypedAllocator.h @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAG(LuauTypedAllocatorZeroStart) + namespace Luau { @@ -20,7 +22,10 @@ class TypedAllocator public: TypedAllocator() { - appendBlock(); + if (FFlag::LuauTypedAllocatorZeroStart) + currentBlockSize = kBlockSize; + else + appendBlock(); } ~TypedAllocator() @@ -59,12 +64,18 @@ class TypedAllocator bool empty() const { - return stuff.size() == 1 && currentBlockSize == 0; + if (FFlag::LuauTypedAllocatorZeroStart) + return stuff.empty(); + else + return stuff.size() == 1 && currentBlockSize == 0; } size_t size() const { - return kBlockSize * (stuff.size() - 1) + currentBlockSize; + if (FFlag::LuauTypedAllocatorZeroStart) + return stuff.empty() ? 0 : kBlockSize * (stuff.size() - 1) + currentBlockSize; + else + return kBlockSize * (stuff.size() - 1) + currentBlockSize; } void clear() @@ -72,7 +83,11 @@ class TypedAllocator if (frozen) unfreeze(); free(); - appendBlock(); + + if (FFlag::LuauTypedAllocatorZeroStart) + currentBlockSize = kBlockSize; + else + appendBlock(); } void freeze() diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 7681b9662..a3be739a6 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -25,6 +25,7 @@ struct Unifier Mode mode; ScopePtr globalScope; // sigh. Needed solely to get at string's metatable. + DEPRECATED_TxnLog DEPRECATED_log; TxnLog log; ErrorVec errors; Location location; @@ -33,44 +34,45 @@ struct Unifier UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState, + TxnLog* parentLog = nullptr); Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState); + Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. - ErrorVec canUnify(TypeId superTy, TypeId subTy); - ErrorVec canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + ErrorVec canUnify(TypeId subTy, TypeId superTy); + ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); - /** Attempt to unify left with right. + /** Attempt to unify. * Populate the vector errors with any type errors that may arise. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. */ - void tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); private: - void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); - void tryUnifyPrimitives(TypeId superTy, TypeId subTy); - void tryUnifySingletons(TypeId superTy, TypeId subTy); - void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); - void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); - void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); - void tryUnifyFreeTable(TypeId free, TypeId other); - void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection); - void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); - void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); - void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); + void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnifyPrimitives(TypeId subTy, TypeId superTy); + void tryUnifySingletons(TypeId subTy, TypeId superTy); + void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); + void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void tryUnifyFreeTable(TypeId subTy, TypeId superTy); + void tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection); + void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer); TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); - void cacheResult(TypeId superTy, TypeId subTy); + void cacheResult(TypeId subTy, TypeId superTy); public: - void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); private: - void tryUnify_(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); - void tryUnifyVariadics(TypePackId superTy, TypePackId subTy, bool reversed, int subOffset = 0); + void tryUnify_(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); + void tryUnifyVariadics(TypePackId subTy, TypePackId superTy, bool reversed, int subOffset = 0); - void tryUnifyWithAny(TypeId any, TypeId ty); - void tryUnifyWithAny(TypePackId any, TypePackId ty); + void tryUnifyWithAny(TypeId subTy, TypeId anyTy); + void tryUnifyWithAny(TypePackId subTy, TypePackId anyTp); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 4b583792c..67ebd0755 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,10 +12,12 @@ #include #include +LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); +LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -236,28 +238,31 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { ty = follow(ty); - auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) { + auto canUnify = [&typeArena, &module](TypeId subTy, TypeId superTy) { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); - if (FFlag::LuauAutocompleteAvoidMutation) + if (FFlag::LuauAutocompleteAvoidMutation && !FFlag::LuauUseCommittingTxnLog) { SeenTypes seenTypes; SeenTypePacks seenTypePacks; CloneState cloneState; - expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, cloneState); - actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, cloneState); + superTy = clone(superTy, *typeArena, seenTypes, seenTypePacks, cloneState); + subTy = clone(subTy, *typeArena, seenTypes, seenTypePacks, cloneState); - auto errors = unifier.canUnify(expectedType, actualType); + auto errors = unifier.canUnify(subTy, superTy); return errors.empty(); } else { - unifier.tryUnify(expectedType, actualType); + unifier.tryUnify(subTy, superTy); bool ok = unifier.errors.empty(); - unifier.log.rollback(); + + if (!FFlag::LuauUseCommittingTxnLog) + unifier.DEPRECATED_log.rollback(); + return ok; } }; @@ -293,22 +298,22 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { auto [retHead, retTail] = flatten(ftv->retType); - if (!retHead.empty() && canUnify(expectedType, retHead.front())) + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return TypeCorrectKind::CorrectFunctionResult; // We might only have a variadic tail pack, check if the element is compatible if (retTail) { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(expectedType, vtp->ty)) + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) return TypeCorrectKind::CorrectFunctionResult; } } - return canUnify(expectedType, ty) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } else { - if (canUnify(expectedType, ty)) + if (canUnify(ty, expectedType)) return TypeCorrectKind::Correct; // We also want to suggest functions that return compatible result @@ -320,13 +325,13 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ auto [retHead, retTail] = flatten(ftv->retType); if (!retHead.empty()) - return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + return canUnify(retHead.front(), expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; // We might only have a variadic tail pack, check if the element is compatible if (retTail) { if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + return canUnify(vtp->ty, expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; } return TypeCorrectKind::None; @@ -1319,7 +1324,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - if (!nodes.back()->is()) + if (!nodes.back()->is() && (!FFlag::LuauCompleteBrokenStringParams || !nodes.back()->is())) { return std::nullopt; } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index d0afa7424..249825067 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -138,12 +138,7 @@ declare function gcinfo(): number -- (nil, string). declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) - -- a userdata object is "roughly" the same as a sealed empty table - -- except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too. - -- another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT - -- setmetatable. - -- FIXME: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`. - declare function newproxy(mt: boolean?): {} + declare function newproxy(mt: boolean?): any declare coroutine: { create: ((A...) -> R...) -> thread, diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index e332f07d4..fe4b6529a 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -351,7 +351,7 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) { } -CheckResult Frontend::check(const ModuleName& name) +CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) { LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); @@ -372,6 +372,8 @@ CheckResult Frontend::check(const ModuleName& name) std::vector buildQueue; bool cycleDetected = parseGraph(buildQueue, checkResult, name); + FrontendOptions frontendOptions = optionOverride.value_or(options); + // Keep track of which AST nodes we've reported cycles in std::unordered_set reportedCycles; @@ -411,31 +413,11 @@ CheckResult Frontend::check(const ModuleName& name) // If we're typechecking twice, we do so. // The second typecheck is always in strict mode with DM awareness // to provide better typen information for IDE features. - if (options.typecheckTwice) + if (frontendOptions.typecheckTwice) { ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; } - else if (options.retainFullTypeGraphs && options.typecheckTwice && mode != Mode::Strict) - { - ModulePtr strictModule = typeChecker.check(sourceModule, Mode::Strict, environmentScope); - module->astTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astExpectedTypes.clear(); - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - - for (const auto& [expr, strictTy] : strictModule->astTypes) - module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); - - for (const auto& [expr, strictTy] : strictModule->astOriginalCallTypes) - module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); - - for (const auto& [expr, strictTy] : strictModule->astExpectedTypes) - module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); - } stats.timeCheck += getTimestamp() - timestamp; stats.filesStrict += mode == Mode::Strict; @@ -444,7 +426,7 @@ CheckResult Frontend::check(const ModuleName& name) if (module == nullptr) throw std::runtime_error("Frontend::check produced a nullptr module for " + moduleName); - if (!options.retainFullTypeGraphs) + if (!frontendOptions.retainFullTypeGraphs) { // copyErrors needs to allocate into interfaceTypes as it copies // types out of internalTypes, so we unfreeze it here. diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index e1e53c971..cff85897c 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -13,7 +13,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 0) +LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) namespace Luau { diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index c773e208b..04ebffc1b 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,8 +4,6 @@ #include "Luau/VisitTypeVar.h" -LUAU_FASTFLAGVARIABLE(LuauQuantifyVisitOnce, false) - namespace Luau { @@ -81,16 +79,8 @@ struct Quantifier void quantify(ModulePtr module, TypeId ty, TypeLevel level) { Quantifier q{std::move(module), level}; - - if (FFlag::LuauQuantifyVisitOnce) - { - DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, q, seen); - } - else - { - visitTypeVar(ty, q); - } + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, q, seen); FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index a6be53482..889dd6dc5 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) -LUAU_FASTFLAGVARIABLE(LuauFunctionArgumentNameSize, false) /* * Prefix generic typenames with gen- @@ -766,24 +765,12 @@ struct TypePackStringifier else state.emit(", "); - if (FFlag::LuauFunctionArgumentNameSize) + if (elemIndex < elemNames.size() && elemNames[elemIndex]) { - if (elemIndex < elemNames.size() && elemNames[elemIndex]) - { - state.emit(elemNames[elemIndex]->name); - state.emit(": "); - } + state.emit(elemNames[elemIndex]->name); + state.emit(": "); } - else - { - LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); - if (!elemNames.empty() && elemNames[elemIndex]) - { - state.emit(elemNames[elemIndex]->name); - state.emit(": "); - } - } elemIndex++; stringify(typeId); @@ -1151,38 +1138,19 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV s += ", "; first = false; - if (FFlag::LuauFunctionArgumentNameSize) + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) { - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (argNameIter != ftv.argNames.end()) - { - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; - ++argNameIter; - } - else - { - s += "_: "; - } + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + ++argNameIter; } else { - // argNames is guaranteed to be equal to argTypes iff argNames is not empty. - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (!ftv.argNames.empty()) - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + s += "_: "; } s += toString_(*argPackIter); ++argPackIter; - - if (!FFlag::LuauFunctionArgumentNameSize) - { - if (!ftv.argNames.empty()) - { - LUAU_ASSERT(argNameIter != ftv.argNames.end()); - ++argNameIter; - } - } } if (argPackIter.tail()) diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index f6a61581e..a46ac0c35 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -4,27 +4,34 @@ #include "Luau/TypePack.h" #include +#include + +LUAU_FASTFLAGVARIABLE(LuauUseCommittingTxnLog, false) namespace Luau { -void TxnLog::operator()(TypeId a) +void DEPRECATED_TxnLog::operator()(TypeId a) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); typeVarChanges.emplace_back(a, *a); } -void TxnLog::operator()(TypePackId a) +void DEPRECATED_TxnLog::operator()(TypePackId a) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); typePackChanges.emplace_back(a, *a); } -void TxnLog::operator()(TableTypeVar* a) +void DEPRECATED_TxnLog::operator()(TableTypeVar* a) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); tableChanges.emplace_back(a, a->boundTo); } -void TxnLog::rollback() +void DEPRECATED_TxnLog::rollback() { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); for (auto it = typeVarChanges.rbegin(); it != typeVarChanges.rend(); ++it) std::swap(*asMutable(it->first), it->second); @@ -38,8 +45,9 @@ void TxnLog::rollback() sharedSeen->resize(originalSeenSize); } -void TxnLog::concat(TxnLog rhs) +void DEPRECATED_TxnLog::concat(DEPRECATED_TxnLog rhs) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); typeVarChanges.insert(typeVarChanges.end(), rhs.typeVarChanges.begin(), rhs.typeVarChanges.end()); rhs.typeVarChanges.clear(); @@ -50,23 +58,298 @@ void TxnLog::concat(TxnLog rhs) rhs.tableChanges.clear(); } -bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) +bool DEPRECATED_TxnLog::haveSeen(TypeId lhs, TypeId rhs) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); } +void DEPRECATED_TxnLog::pushSeen(TypeId lhs, TypeId rhs) +{ + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + sharedSeen->push_back(sortedPair); +} + +void DEPRECATED_TxnLog::popSeen(TypeId lhs, TypeId rhs) +{ + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); +} + +static const TxnLog emptyLog; + +const TxnLog* TxnLog::empty() +{ + return &emptyLog; +} + +void TxnLog::concat(TxnLog rhs) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (auto& [ty, rep] : rhs.typeVarChanges) + typeVarChanges[ty] = std::move(rep); + + for (auto& [tp, rep] : rhs.typePackChanges) + typePackChanges[tp] = std::move(rep); +} + +void TxnLog::commit() +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (auto& [ty, rep] : typeVarChanges) + *asMutable(ty) = rep.get()->pending; + + for (auto& [tp, rep] : typePackChanges) + *asMutable(tp) = rep.get()->pending; + + clear(); +} + +void TxnLog::clear() +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + typeVarChanges.clear(); + typePackChanges.clear(); +} + +TxnLog TxnLog::inverse() +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + TxnLog inversed(sharedSeen); + + for (auto& [ty, _rep] : typeVarChanges) + inversed.typeVarChanges[ty] = std::make_unique(*ty); + + for (auto& [tp, _rep] : typePackChanges) + inversed.typePackChanges[tp] = std::make_unique(*tp); + + return inversed; +} + +bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) + { + return true; + } + + if (parent) + { + return parent->haveSeen(lhs, rhs); + } + + return false; +} + void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); sharedSeen->push_back(sortedPair); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); LUAU_ASSERT(sortedPair == sharedSeen->back()); sharedSeen->pop_back(); } +PendingType* TxnLog::queue(TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(!ty->persistent); + + // Explicitly don't look in ancestors. If we have discovered something new + // about this type, we don't want to mutate the parent's state. + auto& pending = typeVarChanges[ty]; + if (!pending) + pending = std::make_unique(*ty); + + return pending.get(); +} + +PendingTypePack* TxnLog::queue(TypePackId tp) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(!tp->persistent); + + // Explicitly don't look in ancestors. If we have discovered something new + // about this type, we don't want to mutate the parent's state. + auto& pending = typePackChanges[tp]; + if (!pending) + pending = std::make_unique(*tp); + + return pending.get(); +} + +PendingType* TxnLog::pending(TypeId ty) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) + return it->second.get(); + } + + return nullptr; +} + +PendingTypePack* TxnLog::pending(TypePackId tp) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) + return it->second.get(); + } + + return nullptr; +} + +PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + PendingType* newTy = queue(ty); + newTy->pending = replacement; + return newTy; +} + +PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + PendingTypePack* newTp = queue(tp); + newTp->pending = replacement; + return newTp; +} + +PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(ty)); + + PendingType* newTy = queue(ty); + if (TableTypeVar* ttv = Luau::getMutable(newTy)) + ttv->boundTo = newBoundTo; + + return newTy; +} + +PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + + PendingType* newTy = queue(ty); + if (FreeTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->level = newLevel; + } + else if (TableTypeVar* ttv = Luau::getMutable(newTy)) + { + LUAU_ASSERT(ttv->state == TableState::Free || ttv->state == TableState::Generic); + ttv->level = newLevel; + } + else if (FunctionTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->level = newLevel; + } + + return newTy; +} + +PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(tp)); + + PendingTypePack* newTp = queue(tp); + if (FreeTypePack* ftp = Luau::getMutable(newTp)) + { + ftp->level = newLevel; + } + + return newTp; +} + +PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexer) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(ty)); + + PendingType* newTy = queue(ty); + if (TableTypeVar* ttv = Luau::getMutable(newTy)) + { + ttv->indexer = indexer; + } + + return newTy; +} + +std::optional TxnLog::getLevel(TypeId ty) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + if (FreeTypeVar* ftv = getMutable(ty)) + return ftv->level; + else if (TableTypeVar* ttv = getMutable(ty); ttv && (ttv->state == TableState::Free || ttv->state == TableState::Generic)) + return ttv->level; + else if (FunctionTypeVar* ftv = getMutable(ty)) + return ftv->level; + + return std::nullopt; +} + +TypeId TxnLog::follow(TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + return Luau::follow(ty, [this](TypeId ty) { + PendingType* state = this->pending(ty); + + if (state == nullptr) + return ty; + + // Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants + // that normally apply. This is safe because follow will only call get<> + // on the returned pointer. + return const_cast(&state->pending); + }); +} + +TypePackId TxnLog::follow(TypePackId tp) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + return Luau::follow(tp, [this](TypePackId tp) { + PendingTypePack* state = this->pending(tp); + + if (state == nullptr) + return tp; + + // Ugly: Fabricate a TypePackId that doesn't adhere to most of the + // invariants that normally apply. This is safe because follow will + // only call get<> on the returned pointer. + return const_cast(&state->pending); + }); +} + } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e29b6ec65..1689a5c3d 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -27,15 +27,16 @@ LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) +LUAU_FASTFLAG(LuauUseCommittingTxnLog) +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) +LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) -LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) @@ -450,7 +451,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) ++subLevel; TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); - unify(leftType, funTy, fun->location); + unify(funTy, leftType, fun->location); } else if (auto fun = (*protoIter)->as()) { @@ -556,21 +557,21 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) } } -ErrorVec TypeChecker::canUnify(TypeId left, TypeId right, const Location& location) +template +ErrorVec TypeChecker::canUnify_(Id subTy, Id superTy, const Location& location) { - return canUnify_(left, right, location); + Unifier state = mkUnifier(location); + return state.canUnify(subTy, superTy); } -ErrorVec TypeChecker::canUnify(TypePackId left, TypePackId right, const Location& location) +ErrorVec TypeChecker::canUnify(TypeId subTy, TypeId superTy, const Location& location) { - return canUnify_(left, right, location); + return canUnify_(subTy, superTy, location); } -template -ErrorVec TypeChecker::canUnify_(Id superTy, Id subTy, const Location& location) +ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Location& location) { - Unifier state = mkUnifier(location); - return state.canUnify(superTy, subTy); + return canUnify_(subTy, superTy, location); } void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) @@ -619,7 +620,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) // start typechecking everything across module boundaries. if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) { - ErrorVec errors = tryUnify(scope->returnType, retPack, return_.location); + ErrorVec errors = tryUnify(retPack, scope->returnType, return_.location); if (!errors.empty()) currentModule->getModuleScope()->returnType = addTypePack({anyType}); @@ -627,29 +628,39 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) return; } - unify(scope->returnType, retPack, return_.location, CountMismatch::Context::Return); + unify(retPack, scope->returnType, return_.location, CountMismatch::Context::Return); } -ErrorVec TypeChecker::tryUnify(TypeId left, TypeId right, const Location& location) +template +ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const Location& location) { - return tryUnify_(left, right, location); + Unifier state = mkUnifier(location); + + if (FFlag::LuauUseCommittingTxnLog && FFlag::DebugLuauFreezeDuringUnification) + freeze(currentModule->internalTypes); + + state.tryUnify(subTy, superTy); + + if (FFlag::LuauUseCommittingTxnLog && FFlag::DebugLuauFreezeDuringUnification) + unfreeze(currentModule->internalTypes); + + if (!state.errors.empty() && !FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + + if (state.errors.empty() && FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + + return state.errors; } -ErrorVec TypeChecker::tryUnify(TypePackId left, TypePackId right, const Location& location) +ErrorVec TypeChecker::tryUnify(TypeId subTy, TypeId superTy, const Location& location) { - return tryUnify_(left, right, location); + return tryUnify_(subTy, superTy, location); } -template -ErrorVec TypeChecker::tryUnify_(Id left, Id right, const Location& location) +ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const Location& location) { - Unifier state = mkUnifier(location); - state.tryUnify(left, right); - - if (!state.errors.empty()) - state.log.rollback(); - - return state.errors; + return tryUnify_(subTy, superTy, location); } void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) @@ -743,9 +754,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. if (isNonstrictMode() && get(follow(left)) && !get(follow(right))) - unify(left, anyType, loc); + unify(anyType, left, loc); else - unify(left, right, loc); + unify(right, left, loc); } } } @@ -760,7 +771,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assi TypeId result = checkBinaryOperation(scope, expr, left, right); - unify(left, result, assign.location); + unify(result, left, assign.location); } void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) @@ -817,9 +828,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) Unifier state = mkUnifier(local.location); state.ctx = CountMismatch::Result; - state.tryUnify(variablePack, valuePack); + state.tryUnify(valuePack, variablePack); reportErrors(state.errors); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + // In the code 'local T = {}', we wish to ascribe the name 'T' to the type of the table for error-reporting purposes. // We also want to do this for 'local T = setmetatable(...)'. if (local.vars.size == 1 && local.values.size == 1) @@ -889,7 +903,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) TypeId loopVarType = numberType; if (expr.var->annotation) - unify(resolveType(scope, *expr.var->annotation), loopVarType, expr.location); + unify(loopVarType, resolveType(scope, *expr.var->annotation), expr.location); loopScope->bindings[expr.var] = {loopVarType, expr.var->location}; @@ -899,11 +913,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) if (!expr.to) ice("Bad AstStatFor has no to expr"); - unify(loopVarType, checkExpr(loopScope, *expr.from).type, expr.from->location); - unify(loopVarType, checkExpr(loopScope, *expr.to).type, expr.to->location); + unify(checkExpr(loopScope, *expr.from).type, loopVarType, expr.from->location); + unify(checkExpr(loopScope, *expr.to).type, loopVarType, expr.to->location); if (expr.step) - unify(loopVarType, checkExpr(loopScope, *expr.step).type, expr.step->location); + unify(checkExpr(loopScope, *expr.step).type, loopVarType, expr.step->location); check(loopScope, *expr.body); } @@ -956,12 +970,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) if (get(callRetPack)) { iterTy = freshType(scope); - unify(addTypePack({{iterTy}, freshTypePack(scope)}), callRetPack, forin.location); + unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), forin.location); } else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) - unify(var, errorRecoveryType(scope), forin.location); + unify(errorRecoveryType(scope), var, forin.location); return check(loopScope, *forin.body); } @@ -982,7 +996,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) - unify(var, varTy, forin.location); + unify(varTy, var, forin.location); if (!get(iterTy) && !get(iterTy) && !get(iterTy)) reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); @@ -1010,6 +1024,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) Unifier state = mkUnifier(firstValue->location); checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + reportErrors(state.errors); } @@ -1024,10 +1041,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; TypePackId retPack = checkExprPack(scope, exprCall).type; - unify(varPack, retPack, forin.location); + unify(retPack, varPack, forin.location); } else - unify(varPack, iterFunc->retType, forin.location); + unify(iterFunc->retType, varPack, forin.location); check(loopScope, *forin.body); } @@ -1112,7 +1129,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - unify(leftType, ty, function.location); + unify(ty, leftType, function.location); if (FFlag::LuauUpdateFunctionNameBinding) { @@ -1242,7 +1259,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias else if (auto mtv = getMutable(follow(ty))) mtv->syntheticName = name; - unify(bindingsMap[name].type, ty, typealias.location); + unify(ty, bindingsMap[name].type, typealias.location); } } @@ -1526,7 +1543,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa { TypeId head = freshType(scope); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); - unify(retPack, pack, expr.location); + unify(pack, retPack, expr.location); return {head, std::move(result.predicates)}; } if (get(retPack)) @@ -1598,7 +1615,7 @@ std::optional TypeChecker::getIndexTypeFromType( return it->second.type; else if (auto indexer = tableType->indexer) { - tryUnify(indexer->indexType, stringType, location); + tryUnify(stringType, indexer->indexType, location); return indexer->indexResultType; } else if (tableType->state == TableState::Free) @@ -1824,7 +1841,7 @@ TypeId TypeChecker::checkExprTable( indexer = expectedTable->indexer; if (indexer) - unify(indexer->indexResultType, valueType, value->location); + unify(valueType, indexer->indexResultType, value->location); else indexer = TableIndexer{numberType, anyIfNonstrict(valueType)}; } @@ -1842,13 +1859,13 @@ TypeId TypeChecker::checkExprTable( if (it != expectedTable->props.end()) { Property expectedProp = it->second; - ErrorVec errors = tryUnify(expectedProp.type, exprType, k->location); + ErrorVec errors = tryUnify(exprType, expectedProp.type, k->location); if (errors.empty()) exprType = expectedProp.type; } else if (expectedTable->indexer && isString(expectedTable->indexer->indexType)) { - ErrorVec errors = tryUnify(expectedTable->indexer->indexResultType, exprType, k->location); + ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); if (errors.empty()) exprType = expectedTable->indexer->indexResultType; } @@ -1863,8 +1880,8 @@ TypeId TypeChecker::checkExprTable( if (indexer) { - unify(indexer->indexType, keyType, k->location); - unify(indexer->indexResultType, valueType, value->location); + unify(keyType, indexer->indexType, k->location); + unify(valueType, indexer->indexResultType, value->location); } else if (isNonstrictMode()) { @@ -1992,7 +2009,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); - state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) @@ -2006,7 +2026,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn return {errorRecoveryType(scope)}; } - reportErrors(tryUnify(numberType, operandType, expr.location)); + reportErrors(tryUnify(operandType, numberType, expr.location)); return {numberType}; } case AstExprUnary::Len: @@ -2072,7 +2092,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b { if (unifyFreeTypes && (get(a) || get(b))) { - if (unify(a, b, location)) + if (unify(b, a, location)) return a; return errorRecoveryType(anyType); @@ -2175,7 +2195,13 @@ TypeId TypeChecker::checkRelationalOperation( */ Unifier state = mkUnifier(expr.location); if (!isEquality) - state.tryUnify(lhsType, rhsType); + { + state.tryUnify(rhsType, lhsType); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + } + bool needsMetamethod = !isEquality; @@ -2216,13 +2242,16 @@ TypeId TypeChecker::checkRelationalOperation( if (isEquality) { Unifier state = mkUnifier(expr.location); - state.tryUnify(ftv->retType, addTypePack({booleanType})); + state.tryUnify(addTypePack({booleanType}), ftv->retType); if (!state.errors.empty()) { reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); return errorRecoveryType(booleanType); } + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); } } @@ -2230,7 +2259,10 @@ TypeId TypeChecker::checkRelationalOperation( TypeId actualFunctionType = addType(FunctionTypeVar(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType}))); state.tryUnify( - instantiate(scope, *metamethod, expr.location), instantiate(scope, actualFunctionType, expr.location), /*isFunctionCall*/ true); + instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return booleanType; @@ -2323,7 +2355,7 @@ TypeId TypeChecker::checkBinaryOperation( } if (get(rhsType)) - unify(lhsType, rhsType, expr.location); + unify(rhsType, lhsType, expr.location); if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) { @@ -2334,7 +2366,7 @@ TypeId TypeChecker::checkBinaryOperation( TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); - state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); reportErrors(state.errors); bool hasErrors = !state.errors.empty(); @@ -2345,11 +2377,28 @@ TypeId TypeChecker::checkBinaryOperation( // so we loosen the argument types to see if that helps. TypePackId fallbackArguments = freshTypePack(scope); TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); - state.log.rollback(); state.errors.clear(); - state.tryUnify(fallbackFunctionType, actualFunctionType, /*isFunctionCall*/ true); - if (!state.errors.empty()) - state.log.rollback(); + + if (FFlag::LuauUseCommittingTxnLog) + { + state.log.clear(); + } + else + { + state.DEPRECATED_log.rollback(); + } + + state.tryUnify(actualFunctionType, fallbackFunctionType, /*isFunctionCall*/ true); + + if (FFlag::LuauUseCommittingTxnLog && state.errors.empty()) + state.log.commit(); + else if (!state.errors.empty() && !FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + } + + if (FFlag::LuauUseCommittingTxnLog && !hasErrors) + { + state.log.commit(); } TypeId retType = first(retTypePack).value_or(nilType); @@ -2377,8 +2426,8 @@ TypeId TypeChecker::checkBinaryOperation( switch (expr.op) { case AstExprBinary::Concat: - reportErrors(tryUnify(addType(UnionTypeVar{{stringType, numberType}}), lhsType, expr.left->location)); - reportErrors(tryUnify(addType(UnionTypeVar{{stringType, numberType}}), rhsType, expr.right->location)); + reportErrors(tryUnify(lhsType, addType(UnionTypeVar{{stringType, numberType}}), expr.left->location)); + reportErrors(tryUnify(rhsType, addType(UnionTypeVar{{stringType, numberType}}), expr.right->location)); return stringType; case AstExprBinary::Add: case AstExprBinary::Sub: @@ -2386,8 +2435,8 @@ TypeId TypeChecker::checkBinaryOperation( case AstExprBinary::Div: case AstExprBinary::Mod: case AstExprBinary::Pow: - reportErrors(tryUnify(numberType, lhsType, expr.left->location)); - reportErrors(tryUnify(numberType, rhsType, expr.right->location)); + reportErrors(tryUnify(lhsType, numberType, expr.left->location)); + reportErrors(tryUnify(rhsType, numberType, expr.right->location)); return numberType; default: // These should have been handled with checkRelationalOperation @@ -2466,10 +2515,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy if (FFlag::LuauBidirectionalAsExpr) { // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (canUnify(result.type, annotationType, expr.location).empty()) + if (canUnify(annotationType, result.type, expr.location).empty()) return {annotationType, std::move(result.predicates)}; - if (canUnify(annotationType, result.type, expr.location).empty()) + if (canUnify(result.type, annotationType, expr.location).empty()) return {annotationType, std::move(result.predicates)}; reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); @@ -2477,7 +2526,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy } else { - ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + ErrorVec errorVec = canUnify(annotationType, result.type, expr.location); reportErrors(errorVec); if (!errorVec.empty()) annotationType = errorRecoveryType(annotationType); @@ -2512,7 +2561,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIf resolve(result.predicates, falseScope, false); ExprResult falseType = checkExpr(falseScope, *expr.falseExpr); - unify(trueType.type, falseType.type, expr.location); + unify(falseType.type, trueType.type, expr.location); // TODO: normalize(UnionTypeVar{{trueType, falseType}}) // For now both trueType and falseType must be the same type. @@ -2607,14 +2656,18 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope else if (auto indexer = lhsTable->indexer) { Unifier state = mkUnifier(expr.location); - state.tryUnify(indexer->indexType, stringType); + state.tryUnify(stringType, indexer->indexType); TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + reportError(expr.location, UnknownProperty{lhs, name}); retType = errorRecoveryType(retType); } + else if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); return std::pair(retType, nullptr); } @@ -2713,7 +2766,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (exprTable->indexer) { const TableIndexer& indexer = *exprTable->indexer; - unify(indexer.indexType, indexType, expr.index->location); + unify(indexType, indexer.indexType, expr.index->location); return std::pair(indexer.indexResultType, nullptr); } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) @@ -3106,204 +3159,402 @@ void TypeChecker::checkArgumentList( * A function requires parameters. * To call a function, you supply arguments. */ - TypePackIterator argIter = begin(argPack); - TypePackIterator paramIter = begin(paramPack); + TypePackIterator argIter = begin(argPack, &state.log); + TypePackIterator paramIter = begin(paramPack, &state.log); TypePackIterator endIter = end(argPack); // Important subtlety: All end TypePackIterators are equivalent size_t paramIndex = 0; size_t minParams = getMinParameterCount(paramPack); - while (true) + if (FFlag::LuauUseCommittingTxnLog) { - state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; - - if (argIter == endIter && paramIter == endIter) + while (true) { - std::optional argTail = argIter.tail(); - std::optional paramTail = paramIter.tail(); - - // If we hit the end of both type packs simultaneously, then there are definitely no further type - // errors to report. All we need to do is tie up any free tails. - // - // If one side has a free tail and the other has none at all, we create an empty pack and bind the - // free tail to that. + state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; - if (argTail) + if (argIter == endIter && paramIter == endIter) { - if (get(*argTail)) + std::optional argTail = argIter.tail(); + std::optional paramTail = paramIter.tail(); + + // If we hit the end of both type packs simultaneously, then there are definitely no further type + // errors to report. All we need to do is tie up any free tails. + // + // If one side has a free tail and the other has none at all, we create an empty pack and bind the + // free tail to that. + + if (argTail) { - if (paramTail) - state.tryUnify(*argTail, *paramTail); - else + if (state.log.getMutable(state.log.follow(*argTail))) { - state.log(*argTail); - *asMutable(*argTail) = TypePack{{}}; + if (paramTail) + state.tryUnify(*paramTail, *argTail); + else + state.log.replace(*argTail, TypePackVar(TypePack{{}})); } } - } - else if (paramTail) - { - // argTail is definitely empty - if (get(*paramTail)) + else if (paramTail) { - state.log(*paramTail); - *asMutable(*paramTail) = TypePack{{}}; + // argTail is definitely empty + if (state.log.getMutable(state.log.follow(*paramTail))) + state.log.replace(*paramTail, TypePackVar(TypePack{{}})); } + + return; } + else if (argIter == endIter) + { + // Not enough arguments. - return; - } - else if (argIter == endIter) - { - // Not enough arguments. + // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. + if (argIter.tail()) + { + TypePackId tail = *argIter.tail(); + if (state.log.getMutable(tail)) + { + // Unify remaining parameters so we don't leave any free-types hanging around. + while (paramIter != endIter) + { + state.tryUnify(errorRecoveryType(anyType), *paramIter); + ++paramIter; + } + return; + } + else if (auto vtp = state.log.getMutable(tail)) + { + while (paramIter != endIter) + { + state.tryUnify(vtp->ty, *paramIter); + ++paramIter; + } + + return; + } + else if (state.log.getMutable(tail)) + { + std::vector rest; + rest.reserve(std::distance(paramIter, endIter)); + while (paramIter != endIter) + { + rest.push_back(*paramIter); + ++paramIter; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); + state.tryUnify(varPack, tail); + return; + } + } - // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. - if (argIter.tail()) + // If any remaining unfulfilled parameters are nonoptional, this is a problem. + while (paramIter != endIter) + { + TypeId t = state.log.follow(*paramIter); + if (isOptional(t)) + { + } // ok + else if (state.log.getMutable(t)) + { + } // ok + else if (isNonstrictMode() && state.log.getMutable(t)) + { + } // ok + else + { + state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + return; + } + ++paramIter; + } + } + else if (paramIter == endIter) { - TypePackId tail = *argIter.tail(); - if (get(tail)) + // too many parameters passed + if (!paramIter.tail()) { - // Unify remaining parameters so we don't leave any free-types hanging around. - while (paramIter != endIter) + while (argIter != endIter) { - state.tryUnify(*paramIter, errorRecoveryType(anyType)); - ++paramIter; + // The use of unify here is deliberate. We don't want this unification + // to be undoable. + unify(errorRecoveryType(scope), *argIter, state.location); + ++argIter; } + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } - else if (auto vtp = get(tail)) + TypePackId tail = state.log.follow(*paramIter.tail()); + + if (state.log.getMutable(tail)) { - while (paramIter != endIter) + // Function is variadic. Ok. + return; + } + else if (auto vtp = state.log.getMutable(tail)) + { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. + size_t argIndex = paramIndex; + while (argIter != endIter) { - state.tryUnify(*paramIter, vtp->ty); - ++paramIter; + Location location = state.location; + + if (argIndex < argLocations.size()) + location = argLocations[argIndex]; + + unify(*argIter, vtp->ty, location); + ++argIter; + ++argIndex; } return; } - else if (get(tail)) + else if (state.log.getMutable(tail)) { + // Create a type pack out of the remaining argument types + // and unify it with the tail. std::vector rest; - rest.reserve(std::distance(paramIter, endIter)); - while (paramIter != endIter) + rest.reserve(std::distance(argIter, endIter)); + while (argIter != endIter) { - rest.push_back(*paramIter); - ++paramIter; + rest.push_back(*argIter); + ++argIter; } - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); - state.tryUnify(tail, varPack); + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); + state.tryUnify(varPack, tail); return; } - } - - // If any remaining unfulfilled parameters are nonoptional, this is a problem. - while (paramIter != endIter) - { - TypeId t = follow(*paramIter); - if (isOptional(t)) + else if (state.log.getMutable(tail)) { - } // ok - else if (get(t)) - { - } // ok - else if (isNonstrictMode() && get(t)) - { - } // ok - else + state.log.replace(tail, TypePackVar(TypePack{{}})); + return; + } + else if (state.log.getMutable(tail)) { - state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + // TODO: Better error message? + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } + } + else + { + unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); + ++argIter; ++paramIter; } + + ++paramIndex; } - else if (paramIter == endIter) + } + else + { + while (true) { - // too many parameters passed - if (!paramIter.tail()) + state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; + + if (argIter == endIter && paramIter == endIter) { - while (argIter != endIter) + std::optional argTail = argIter.tail(); + std::optional paramTail = paramIter.tail(); + + // If we hit the end of both type packs simultaneously, then there are definitely no further type + // errors to report. All we need to do is tie up any free tails. + // + // If one side has a free tail and the other has none at all, we create an empty pack and bind the + // free tail to that. + + if (argTail) { - unify(*argIter, errorRecoveryType(scope), state.location); - ++argIter; + if (get(*argTail)) + { + if (paramTail) + state.tryUnify(*paramTail, *argTail); + else + { + state.DEPRECATED_log(*argTail); + *asMutable(*argTail) = TypePack{{}}; + } + } + } + else if (paramTail) + { + // argTail is definitely empty + if (get(*paramTail)) + { + state.DEPRECATED_log(*paramTail); + *asMutable(*paramTail) = TypePack{{}}; + } } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } - TypePackId tail = *paramIter.tail(); - if (get(tail)) - { - // Function is variadic. Ok. return; } - else if (auto vtp = get(tail)) + else if (argIter == endIter) { - // Function is variadic and requires that all subsequent parameters - // be compatible with a type. - size_t argIndex = paramIndex; - while (argIter != endIter) + // Not enough arguments. + + // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. + if (argIter.tail()) { - Location location = state.location; + TypePackId tail = *argIter.tail(); + if (get(tail)) + { + // Unify remaining parameters so we don't leave any free-types hanging around. + while (paramIter != endIter) + { + state.tryUnify(*paramIter, errorRecoveryType(anyType)); + ++paramIter; + } + return; + } + else if (auto vtp = get(tail)) + { + while (paramIter != endIter) + { + state.tryUnify(*paramIter, vtp->ty); + ++paramIter; + } - if (argIndex < argLocations.size()) - location = argLocations[argIndex]; + return; + } + else if (get(tail)) + { + std::vector rest; + rest.reserve(std::distance(paramIter, endIter)); + while (paramIter != endIter) + { + rest.push_back(*paramIter); + ++paramIter; + } - unify(vtp->ty, *argIter, location); - ++argIter; - ++argIndex; + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); + state.tryUnify(varPack, tail); + return; + } } - return; + // If any remaining unfulfilled parameters are nonoptional, this is a problem. + while (paramIter != endIter) + { + TypeId t = follow(*paramIter); + if (isOptional(t)) + { + } // ok + else if (get(t)) + { + } // ok + else if (isNonstrictMode() && get(t)) + { + } // ok + else + { + state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + return; + } + ++paramIter; + } } - else if (get(tail)) + else if (paramIter == endIter) { - // Create a type pack out of the remaining argument types - // and unify it with the tail. - std::vector rest; - rest.reserve(std::distance(argIter, endIter)); - while (argIter != endIter) + // too many parameters passed + if (!paramIter.tail()) { - rest.push_back(*argIter); - ++argIter; + while (argIter != endIter) + { + unify(*argIter, errorRecoveryType(scope), state.location); + ++argIter; + } + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; } + TypePackId tail = *paramIter.tail(); - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - state.tryUnify(tail, varPack); - return; - } - else if (get(tail)) - { - state.log(tail); - *asMutable(tail) = TypePack{}; + if (get(tail)) + { + // Function is variadic. Ok. + return; + } + else if (auto vtp = get(tail)) + { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. + size_t argIndex = paramIndex; + while (argIter != endIter) + { + Location location = state.location; - return; + if (argIndex < argLocations.size()) + location = argLocations[argIndex]; + + unify(*argIter, vtp->ty, location); + ++argIter; + ++argIndex; + } + + return; + } + else if (get(tail)) + { + // Create a type pack out of the remaining argument types + // and unify it with the tail. + std::vector rest; + rest.reserve(std::distance(argIter, endIter)); + while (argIter != endIter) + { + rest.push_back(*argIter); + ++argIter; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); + state.tryUnify(tail, varPack); + return; + } + else if (get(tail)) + { + if (FFlag::LuauUseCommittingTxnLog) + { + state.log.replace(tail, TypePackVar(TypePack{{}})); + } + else + { + state.DEPRECATED_log(tail); + *asMutable(tail) = TypePack{}; + } + + return; + } + else if (get(tail)) + { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + // TODO: Better error message? + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; + } } - else if (get(tail)) + else { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; + unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); + ++argIter; + ++paramIter; } - } - else - { - unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); - ++argIter; - ++paramIter; - } - ++paramIndex; + ++paramIndex; + } } } @@ -3475,7 +3726,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (get(fn)) { - unify(argPack, anyTypePack, expr.location); + unify(anyTypePack, argPack, expr.location); return {{anyTypePack}}; } @@ -3490,7 +3741,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - unify(r, fn, expr.location); + unify(fn, r, expr.location); return {{retPack}}; } @@ -3533,7 +3784,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!ftv) { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); - unify(retPack, errorRecoveryTypePack(scope), expr.func->location); + unify(errorRecoveryTypePack(scope), retPack, expr.func->location); return {{errorRecoveryTypePack(retPack)}}; } @@ -3552,7 +3803,9 @@ std::optional> TypeChecker::checkCallOverload(const Scope checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); if (!state.errors.empty()) { - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + return {}; } @@ -3580,10 +3833,15 @@ std::optional> TypeChecker::checkCallOverload(const Scope overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); - state.log.rollback(); + + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); } else { + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + if (isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) { // If we are running in nonstrict mode, passing fewer arguments than the function is declared to take AND @@ -3640,6 +3898,9 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal if (editedState.errors.empty()) { + if (FFlag::LuauUseCommittingTxnLog) + editedState.log.commit(); + reportError(TypeError{expr.location, FunctionDoesNotTakeSelf{}}); // This is a little bit suspect: If this overload would work with a . replaced by a : // we eagerly assume that that's what you actually meant and we commit to it. @@ -3648,8 +3909,8 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); return true; } - else - editedState.log.rollback(); + else if (!FFlag::LuauUseCommittingTxnLog) + editedState.DEPRECATED_log.rollback(); } else if (ftv->hasSelf) { @@ -3671,6 +3932,9 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal if (editedState.errors.empty()) { + if (FFlag::LuauUseCommittingTxnLog) + editedState.log.commit(); + reportError(TypeError{expr.location, FunctionRequiresSelf{}}); // This is a little bit suspect: If this overload would work with a : replaced by a . // we eagerly assume that that's what you actually meant and we commit to it. @@ -3679,8 +3943,8 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); return true; } - else - editedState.log.rollback(); + else if (!FFlag::LuauUseCommittingTxnLog) + editedState.DEPRECATED_log.rollback(); } } } @@ -3740,6 +4004,9 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); } + if (FFlag::LuauUseCommittingTxnLog && state.errors.empty()) + state.log.commit(); + if (i > 0) s += "; "; @@ -3748,7 +4015,8 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast s += toString(overload); - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); } if (overloadsThatMatchArgCount.size() == 0) @@ -3781,6 +4049,8 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L Unifier state = mkUnifier(location); + std::vector inverseLogs; + for (size_t i = 0; i < exprs.size; ++i) { AstExpr* expr = exprs.data[i]; @@ -3791,18 +4061,15 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); - if (FFlag::LuauTailArgumentTypeInfo) + if (std::optional firstTy = first(typePack)) { - if (std::optional firstTy = first(typePack)) - { - if (!currentModule->astTypes.find(expr)) - currentModule->astTypes[expr] = follow(*firstTy); - } - - if (expectedType) - currentModule->astExpectedTypes[expr] = *expectedType; + if (!currentModule->astTypes.find(expr)) + currentModule->astTypes[expr] = follow(*firstTy); } + if (expectedType) + currentModule->astExpectedTypes[expr] = *expectedType; + tp->tail = typePack; } else @@ -3816,13 +4083,31 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L actualType = instantiate(scope, actualType, expr->location); if (expectedType) - state.tryUnify(*expectedType, actualType); + { + state.tryUnify(actualType, *expectedType); + + // Ugly: In future iterations of the loop, we might need the state of the unification we + // just performed. There's not a great way to pass that into checkExpr. Instead, we store + // the inverse of the current log, and commit it. When we're done, we'll commit all the + // inverses. This isn't optimal, and a better solution is welcome here. + if (FFlag::LuauUseCommittingTxnLog) + { + inverseLogs.push_back(state.log.inverse()); + state.log.commit(); + } + } tp->head.push_back(actualType); } } - state.log.rollback(); + if (FFlag::LuauUseCommittingTxnLog) + { + for (TxnLog& log : inverseLogs) + log.commit(); + } + else + state.DEPRECATED_log.rollback(); return {pack, predicates}; } @@ -3884,7 +4169,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module TypePackId modulePack = module->getModuleScope()->returnType; - if (FFlag::LuauModuleRequireErrorPack && get(modulePack)) + if (get(modulePack)) return errorRecoveryType(scope); std::optional moduleType = first(modulePack); @@ -3917,72 +4202,94 @@ TypeId TypeChecker::anyIfNonstrict(TypeId ty) const return ty; } -bool TypeChecker::unify(TypeId left, TypeId right, const Location& location) +bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location) { Unifier state = mkUnifier(location); - state.tryUnify(left, right); + state.tryUnify(subTy, superTy); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -bool TypeChecker::unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx) +bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx) { Unifier state = mkUnifier(location); state.ctx = ctx; - state.tryUnify(left, right); + state.tryUnify(subTy, superTy); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location) +bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, const Location& location) { Unifier state = mkUnifier(location); - unifyWithInstantiationIfNeeded(scope, left, right, state); + unifyWithInstantiationIfNeeded(scope, subTy, superTy, state); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state) +void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, Unifier& state) { - if (!maybeGeneric(right)) + if (!maybeGeneric(subTy)) // Quick check to see if we definitely can't instantiate - state.tryUnify(left, right, /*isFunctionCall*/ false); - else if (!maybeGeneric(left) && isGeneric(right)) + state.tryUnify(subTy, superTy, /*isFunctionCall*/ false); + else if (!maybeGeneric(superTy) && isGeneric(subTy)) { // Quick check to see if we definitely have to instantiate - TypeId instantiated = instantiate(scope, right, state.location); - state.tryUnify(left, instantiated, /*isFunctionCall*/ false); + TypeId instantiated = instantiate(scope, subTy, state.location); + state.tryUnify(instantiated, superTy, /*isFunctionCall*/ false); } else { // First try unifying with the original uninstantiated type // but if that fails, try the instantiated one. Unifier child = state.makeChildUnifier(); - child.tryUnify(left, right, /*isFunctionCall*/ false); + child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); if (!child.errors.empty()) { - TypeId instantiated = instantiate(scope, right, state.location); - if (right == instantiated) + TypeId instantiated = instantiate(scope, subTy, state.location); + if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors - state.log.concat(std::move(child.log)); + if (FFlag::LuauUseCommittingTxnLog) + state.log.concat(std::move(child.log)); + else + state.DEPRECATED_log.concat(std::move(child.DEPRECATED_log)); + state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); } else { - child.log.rollback(); - state.tryUnify(left, instantiated, /*isFunctionCall*/ false); + if (!FFlag::LuauUseCommittingTxnLog) + child.DEPRECATED_log.rollback(); + + state.tryUnify(instantiated, superTy, /*isFunctionCall*/ false); } } else { - state.log.concat(std::move(child.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + state.log.concat(std::move(child.log)); + } + else + { + state.DEPRECATED_log.concat(std::move(child.DEPRECATED_log)); + } } } } @@ -4139,7 +4446,7 @@ TypePackId Quantification::clean(TypePackId tp) bool Anyification::isDirty(TypeId ty) { if (const TableTypeVar* ttv = get(ty)) - return (ttv->state == TableState::Free); + return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); else if (get(ty)) return true; else @@ -4162,6 +4469,12 @@ TypeId Anyification::clean(TypeId ty) TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; + if (FFlag::LuauSealExports) + { + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.tags = ttv->tags; + } return addType(std::move(clone)); } else @@ -5194,8 +5507,8 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement // This by itself is not truly enough to determine that A is stronger than B or vice versa. // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) - bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); - bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + bool optionIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + bool targetIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. if (!optionIsSubtype && targetIsSubtype) @@ -5379,7 +5692,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa for (TypeId right : rhs) { // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) + if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) set.insert(left); } } @@ -5406,7 +5719,7 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp for (size_t i = 0; i < expectedLength; ++i) expectedPack->head.push_back(freshType(scope)); - unify(expectedTypePack, tp, location); + unify(tp, expectedTypePack, location); for (TypeId& tp : expectedPack->head) tp = follow(tp); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index d3221c732..b15548a8d 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -1,8 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypePack.h" +#include "Luau/TxnLog.h" + #include +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + namespace Luau { @@ -35,14 +39,28 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) } TypePackIterator::TypePackIterator(TypePackId typePack) + : TypePackIterator(typePack, TxnLog::empty()) +{ +} + +TypePackIterator::TypePackIterator(TypePackId typePack, const TxnLog* log) : currentTypePack(follow(typePack)) , tp(get(currentTypePack)) , currentIndex(0) + , log(log) { while (tp && tp->head.empty()) { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; + if (FFlag::LuauUseCommittingTxnLog) + { + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; + } + else + { + currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; + tp = currentTypePack ? get(currentTypePack) : nullptr; + } } } @@ -53,8 +71,17 @@ TypePackIterator& TypePackIterator::operator++() ++currentIndex; while (tp && currentIndex >= tp->head.size()) { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; + if (FFlag::LuauUseCommittingTxnLog) + { + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; + } + else + { + currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; + tp = currentTypePack ? get(currentTypePack) : nullptr; + } + currentIndex = 0; } @@ -95,6 +122,11 @@ TypePackIterator begin(TypePackId tp) return TypePackIterator{tp}; } +TypePackIterator begin(TypePackId tp, TxnLog* log) +{ + return TypePackIterator{tp, log}; +} + TypePackIterator end(TypePackId tp) { return TypePackIterator{}; @@ -160,8 +192,15 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId follow(TypePackId tp) { - auto advance = [](TypePackId ty) -> std::optional { - if (const Unifiable::Bound* btv = get>(ty)) + return follow(tp, [](TypePackId t) { + return t; + }); +} + +TypePackId follow(TypePackId tp, std::function mapper) +{ + auto advance = [&mapper](TypePackId ty) -> std::optional { + if (const Unifiable::Bound* btv = get>(mapper(ty))) return btv->boundTo; else return std::nullopt; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index fb75aa02e..4cab79c8a 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -31,17 +31,24 @@ std::optional> magicFunctionFormat( TypeId follow(TypeId t) { - auto advance = [](TypeId ty) -> std::optional { - if (auto btv = get>(ty)) + return follow(t, [](TypeId t) { + return t; + }); +} + +TypeId follow(TypeId t, std::function mapper) +{ + auto advance = [&mapper](TypeId ty) -> std::optional { + if (auto btv = get>(mapper(ty))) return btv->boundTo; - else if (auto ttv = get(ty)) + else if (auto ttv = get(mapper(ty))) return ttv->boundTo; else return std::nullopt; }; - auto force = [](TypeId ty) { - if (auto ltv = get_if(&ty->ty)) + auto force = [&mapper](TypeId ty) { + if (auto ltv = get_if(&mapper(ty)->ty)) { TypeId res = ltv->thunk(); if (get(res)) @@ -1004,7 +1011,7 @@ std::optional> magicFunctionFormat( { Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - typechecker.unify(expected[i], params[i + paramOffset], location); + typechecker.unify(params[i + paramOffset], expected[i], location); } // if we know the argument count or if we have too many arguments for sure, we can issue an error diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index f037351e5..1f7ef8c25 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -20,6 +20,7 @@ const size_t kPageSize = sysconf(_SC_PAGESIZE); #include LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAGVARIABLE(LuauTypedAllocatorZeroStart, false) namespace Luau { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 43ea37e7b..393a84a70 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -13,6 +13,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); +LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) @@ -29,27 +30,39 @@ namespace Luau struct PromoteTypeLevels { + DEPRECATED_TxnLog& DEPRECATED_log; TxnLog& log; TypeLevel minLevel; - explicit PromoteTypeLevels(TxnLog& log, TypeLevel minLevel) - : log(log) + explicit PromoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel) + : DEPRECATED_log(DEPRECATED_log) + , log(log) , minLevel(minLevel) - {} + { + } - template + template void promote(TID ty, T* t) { LUAU_ASSERT(t); if (minLevel.subsumesStrict(t->level)) { - log(ty); - t->level = minLevel; + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeLevel(ty, minLevel); + } + else + { + DEPRECATED_log(ty); + t->level = minLevel; + } } } template - void cycle(TID) {} + void cycle(TID) + { + } template bool operator()(TID, const T&) @@ -59,39 +72,47 @@ struct PromoteTypeLevels bool operator()(TypeId ty, const FreeTypeVar&) { - promote(ty, getMutable(ty)); + // Surprise, it's actually a BoundTypeVar that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (FFlag::LuauUseCommittingTxnLog && !log.is(ty)) + return true; + + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } bool operator()(TypeId ty, const FunctionTypeVar&) { - promote(ty, getMutable(ty)); + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } - bool operator()(TypeId ty, const TableTypeVar&) + bool operator()(TypeId ty, const TableTypeVar& ttv) { - promote(ty, getMutable(ty)); + if (ttv.state != TableState::Free && ttv.state != TableState::Generic) + return true; + + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } bool operator()(TypePackId tp, const FreeTypePack&) { - promote(tp, getMutable(tp)); + promote(tp, FFlag::LuauUseCommittingTxnLog ? log.getMutable(tp) : getMutable(tp)); return true; } }; -void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypeId ty) +void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel, TypeId ty) { - PromoteTypeLevels ptl{log, minLevel}; + PromoteTypeLevels ptl{DEPRECATED_log, log, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(ty, ptl, seen); } -void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypePackId tp) +void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel, TypePackId tp) { - PromoteTypeLevels ptl{log, minLevel}; + PromoteTypeLevels ptl{DEPRECATED_log, log, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(tp, ptl, seen); } @@ -221,10 +242,12 @@ static std::optional> getTableMat return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState, + TxnLog* parentLog) : types(types) , mode(mode) , globalScope(std::move(globalScope)) + , log(parentLog) , location(location) , variance(variance) , sharedState(sharedState) @@ -233,11 +256,12 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati } Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState) + Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) , globalScope(std::move(globalScope)) - , log(sharedSeen) + , DEPRECATED_log(sharedSeen) + , log(parentLog, sharedSeen) , location(location) , variance(variance) , sharedState(sharedState) @@ -245,14 +269,14 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector< LUAU_ASSERT(sharedState.iceHandler); } -void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { sharedState.counters.iterationCount = 0; - tryUnify_(superTy, subTy, isFunctionCall, isIntersection); + tryUnify_(subTy, superTy, isFunctionCall, isIntersection); } -void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); @@ -264,55 +288,112 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool return; } - superTy = follow(superTy); - subTy = follow(subTy); + if (FFlag::LuauUseCommittingTxnLog) + { + superTy = log.follow(superTy); + subTy = log.follow(subTy); + } + else + { + superTy = follow(superTy); + subTy = follow(subTy); + } if (superTy == subTy) return; - auto l = getMutable(superTy); - auto r = getMutable(subTy); + auto superFree = getMutable(superTy); + auto subFree = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superFree = log.getMutable(superTy); + subFree = log.getMutable(subTy); + } - if (l && r && l->level.subsumes(r->level)) + if (superFree && subFree && superFree->level.subsumes(subFree->level)) { occursCheck(subTy, superTy); // The occurrence check might have caused superTy no longer to be a free type - if (!get(subTy)) + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(subTy)); + else + occursFailed = bool(get(subTy)); + + if (!occursFailed) { - log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); + if (FFlag::LuauUseCommittingTxnLog) + { + log.replace(subTy, BoundTypeVar(superTy)); + } + else + { + DEPRECATED_log(subTy); + *asMutable(subTy) = BoundTypeVar(superTy); + } } return; } - else if (l && r) + else if (superFree && subFree) { - if (!FFlag::LuauErrorRecoveryType) - log(superTy); + if (!FFlag::LuauErrorRecoveryType && !FFlag::LuauUseCommittingTxnLog) + { + DEPRECATED_log(superTy); + subFree->level = min(subFree->level, superFree->level); + } + occursCheck(superTy, subTy); - r->level = min(r->level, l->level); - // The occurrence check might have caused superTy no longer to be a free type - if (!FFlag::LuauErrorRecoveryType) - *asMutable(superTy) = BoundTypeVar(subTy); - else if (!get(superTy)) + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(superTy)); + else + occursFailed = bool(get(superTy)); + + if (!FFlag::LuauErrorRecoveryType && !FFlag::LuauUseCommittingTxnLog) { - log(superTy); *asMutable(superTy) = BoundTypeVar(subTy); + return; + } + + if (!occursFailed) + { + if (FFlag::LuauUseCommittingTxnLog) + { + if (superFree->level.subsumes(subFree->level)) + { + log.changeLevel(subTy, superFree->level); + } + + log.replace(superTy, BoundTypeVar(subTy)); + } + else + { + DEPRECATED_log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + subFree->level = min(subFree->level, superFree->level); + } } return; } - else if (l) + else if (superFree) { occursCheck(superTy, subTy); + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(superTy)); + else + occursFailed = bool(get(superTy)); - TypeLevel superLevel = l->level; + TypeLevel superLevel = superFree->level; // Unification can't change the level of a generic. - auto rightGeneric = get(subTy); - if (rightGeneric && !rightGeneric->level.subsumes(superLevel)) + auto subGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy); + if (subGeneric && !subGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -320,63 +401,83 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } // The occurrence check might have caused superTy no longer to be a free type - if (!get(superTy)) + if (!occursFailed) { - if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(log, superLevel, subTy); - else if (auto rightLevel = getMutableLevel(subTy)) + if (FFlag::LuauUseCommittingTxnLog) { - if (!rightLevel->subsumes(l->level)) - *rightLevel = l->level; + promoteTypeLevels(DEPRECATED_log, log, superLevel, subTy); + log.replace(superTy, BoundTypeVar(subTy)); } + else + { + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(DEPRECATED_log, log, superLevel, subTy); + else if (auto subLevel = getMutableLevel(subTy)) + { + if (!subLevel->subsumes(superFree->level)) + *subLevel = superFree->level; + } - log(superTy); - *asMutable(superTy) = BoundTypeVar(subTy); + DEPRECATED_log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + } } return; } - else if (r) + else if (subFree) { - TypeLevel subLevel = r->level; + TypeLevel subLevel = subFree->level; occursCheck(subTy, superTy); + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(subTy)); + else + occursFailed = bool(get(subTy)); // Unification can't change the level of a generic. - auto leftGeneric = get(superTy); - if (leftGeneric && !leftGeneric->level.subsumes(r->level)) + auto superGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy); + if (superGeneric && !superGeneric->level.subsumes(subFree->level)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); return; } - if (!get(subTy)) + if (!occursFailed) { - if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(log, subLevel, superTy); - - if (auto superLevel = getMutableLevel(superTy)) + if (FFlag::LuauUseCommittingTxnLog) { - if (!superLevel->subsumes(r->level)) + promoteTypeLevels(DEPRECATED_log, log, subLevel, superTy); + log.replace(subTy, BoundTypeVar(superTy)); + } + else + { + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(DEPRECATED_log, log, subLevel, superTy); + else if (auto superLevel = getMutableLevel(superTy)) { - log(superTy); - *superLevel = r->level; + if (!superLevel->subsumes(subFree->level)) + { + DEPRECATED_log(superTy); + *superLevel = subFree->level; + } } - } - log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); + DEPRECATED_log(subTy); + *asMutable(subTy) = BoundTypeVar(superTy); + } } return; } if (get(superTy) || get(superTy)) - return tryUnifyWithAny(superTy, subTy); + return tryUnifyWithAny(subTy, superTy); if (get(subTy) || get(subTy)) - return tryUnifyWithAny(subTy, superTy); + return tryUnifyWithAny(superTy, subTy); bool cacheEnabled = !isFunctionCall && !isIntersection; auto& cache = sharedState.cachedUnify; @@ -389,12 +490,22 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // Here, we assume that the types unify. If they do not, we will find out as we roll back // the stack. - if (log.haveSeen(superTy, subTy)) - return; + if (FFlag::LuauUseCommittingTxnLog) + { + if (log.haveSeen(superTy, subTy)) + return; + + log.pushSeen(superTy, subTy); + } + else + { + if (DEPRECATED_log.haveSeen(superTy, subTy)) + return; - log.pushSeen(superTy, subTy); + DEPRECATED_log.pushSeen(superTy, subTy); + } - if (const UnionTypeVar* uv = get(subTy)) + if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { // A | B <: T if A <: T and B <: T bool failed = false; @@ -407,7 +518,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool for (TypeId type : uv->options) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTy, type); + innerState.tryUnify_(type, superTy); if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; @@ -420,10 +531,24 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool failed = true; } - if (i != count - 1) - innerState.log.rollback(); + if (FFlag::LuauUseCommittingTxnLog) + { + if (i == count - 1) + { + log.concat(std::move(innerState.log)); + } + } else - log.concat(std::move(innerState.log)); + { + if (i != count - 1) + { + innerState.DEPRECATED_log.rollback(); + } + else + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } + } ++i; } @@ -438,7 +563,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } } - else if (const UnionTypeVar* uv = get(superTy)) + else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) { // T <: A | B if T <: A or T <: B bool found = false; @@ -502,12 +627,16 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { TypeId type = uv->options[(i + startIndex) % uv->options.size()]; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, subTy, isFunctionCall); + innerState.tryUnify_(subTy, type, isFunctionCall); if (innerState.errors.empty()) { found = true; - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + break; } else if (auto e = hasUnificationTooComplex(innerState.errors)) @@ -522,7 +651,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool failedOption = {innerState.errors.front()}; } - innerState.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); } if (unificationTooComplex) @@ -538,7 +668,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); } } - else if (const IntersectionTypeVar* uv = get(superTy)) + else if (const IntersectionTypeVar* uv = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) { std::optional unificationTooComplex; std::optional firstFailedOption; @@ -547,7 +678,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool for (TypeId type : uv->parts) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; @@ -557,7 +688,10 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool firstFailedOption = {innerState.errors.front()}; } - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); } if (unificationTooComplex) @@ -565,7 +699,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (firstFailedOption) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } - else if (const IntersectionTypeVar* uv = get(subTy)) + else if (const IntersectionTypeVar* uv = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { // A & B <: T if T <: A or T <: B bool found = false; @@ -591,12 +726,15 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTy, type, isFunctionCall); + innerState.tryUnify_(type, superTy, isFunctionCall); if (innerState.errors.empty()) { found = true; - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); break; } else if (auto e = hasUnificationTooComplex(innerState.errors)) @@ -604,7 +742,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool unificationTooComplex = e; } - innerState.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); } if (unificationTooComplex) @@ -614,44 +753,56 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } - else if (get(superTy) && get(subTy)) - tryUnifyPrimitives(superTy, subTy); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + tryUnifyPrimitives(subTy, superTy); - else if (FFlag::LuauSingletonTypes && (get(superTy) || get(superTy)) && get(subTy)) - tryUnifySingletons(superTy, subTy); + else if (FFlag::LuauSingletonTypes && + ((FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) || + (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) && + (FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy))) + tryUnifySingletons(subTy, superTy); - else if (get(superTy) && get(subTy)) - tryUnifyFunctions(superTy, subTy, isFunctionCall); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + tryUnifyFunctions(subTy, superTy, isFunctionCall); - else if (get(superTy) && get(subTy)) + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) { - tryUnifyTables(superTy, subTy, isIntersection); + tryUnifyTables(subTy, superTy, isIntersection); if (cacheEnabled && errors.empty()) - cacheResult(superTy, subTy); + cacheResult(subTy, superTy); } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. - else if (get(superTy)) - tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false); - else if (get(subTy)) - tryUnifyWithMetatable(subTy, superTy, /*reversed*/ true); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + tryUnifyWithMetatable(subTy, superTy, /*reversed*/ false); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + tryUnifyWithMetatable(superTy, subTy, /*reversed*/ true); - else if (get(superTy)) - tryUnifyWithClass(superTy, subTy, /*reversed*/ false); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + tryUnifyWithClass(subTy, superTy, /*reversed*/ false); // Unification of nonclasses with classes is almost, but not quite symmetrical. // The order in which we perform this test is significant in the case that both types are classes. - else if (get(subTy)) - tryUnifyWithClass(superTy, subTy, /*reversed*/ true); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + tryUnifyWithClass(subTy, superTy, /*reversed*/ true); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); - log.popSeen(superTy, subTy); + if (FFlag::LuauUseCommittingTxnLog) + log.popSeen(superTy, subTy); + else + DEPRECATED_log.popSeen(superTy, subTy); } -void Unifier::cacheResult(TypeId superTy, TypeId subTy) +void Unifier::cacheResult(TypeId subTy, TypeId superTy) { bool* superTyInfo = sharedState.skipCacheForType.find(superTy); @@ -684,7 +835,7 @@ void Unifier::cacheResult(TypeId superTy, TypeId subTy) sharedState.cachedUnify.insert({subTy, superTy}); } -struct WeirdIter +struct DEPRECATED_WeirdIter { TypePackId packId; const TypePack* pack; @@ -692,7 +843,7 @@ struct WeirdIter bool growing; TypeLevel level; - WeirdIter(TypePackId packId) + DEPRECATED_WeirdIter(TypePackId packId) : packId(packId) , pack(get(packId)) , index(0) @@ -705,7 +856,7 @@ struct WeirdIter } } - WeirdIter(const WeirdIter&) = default; + DEPRECATED_WeirdIter(const DEPRECATED_WeirdIter&) = default; const TypeId& operator*() { @@ -756,34 +907,152 @@ struct WeirdIter } }; -ErrorVec Unifier::canUnify(TypeId superTy, TypeId subTy) +struct WeirdIter +{ + TypePackId packId; + TxnLog& log; + TypePack* pack; + size_t index; + bool growing; + TypeLevel level; + + WeirdIter(TypePackId packId, TxnLog& log) + : packId(packId) + , log(log) + , pack(log.getMutable(packId)) + , index(0) + , growing(false) + { + while (pack && pack->head.empty() && pack->tail) + { + packId = *pack->tail; + pack = log.getMutable(packId); + } + } + + WeirdIter(const WeirdIter&) = default; + + TypeId& operator*() + { + LUAU_ASSERT(good()); + return pack->head[index]; + } + + bool good() const + { + return pack != nullptr && index < pack->head.size(); + } + + bool advance() + { + if (!pack) + return good(); + + if (index < pack->head.size()) + ++index; + + if (growing || index < pack->head.size()) + return good(); + + if (pack->tail) + { + packId = log.follow(*pack->tail); + pack = log.getMutable(packId); + index = 0; + } + + return good(); + } + + bool canGrow() const + { + return nullptr != log.getMutable(packId); + } + + void grow(TypePackId newTail) + { + LUAU_ASSERT(canGrow()); + LUAU_ASSERT(log.getMutable(newTail)); + + level = log.getMutable(packId)->level; + log.replace(packId, Unifiable::Bound(newTail)); + packId = newTail; + pack = log.getMutable(newTail); + index = 0; + growing = true; + } + + void pushType(TypeId ty) + { + LUAU_ASSERT(pack); + PendingTypePack* pendingPack = log.queue(packId); + if (TypePack* pending = getMutable(pendingPack)) + { + pending->head.push_back(ty); + // We've potentially just replaced the TypePack* that we need to look + // in. We need to replace pack. + pack = pending; + } + else + { + LUAU_ASSERT(!"Pending state for this pack was not a TypePack"); + } + } +}; + +ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) { Unifier s = makeChildUnifier(); - s.tryUnify_(superTy, subTy); - s.log.rollback(); + s.tryUnify_(subTy, superTy); + + if (!FFlag::LuauUseCommittingTxnLog) + s.DEPRECATED_log.rollback(); + return s.errors; } -ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall) +ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall) { Unifier s = makeChildUnifier(); - s.tryUnify_(superTy, subTy, isFunctionCall); - s.log.rollback(); + s.tryUnify_(subTy, superTy, isFunctionCall); + + if (!FFlag::LuauUseCommittingTxnLog) + s.DEPRECATED_log.rollback(); + return s.errors; } -void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) +void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { sharedState.counters.iterationCount = 0; - tryUnify_(superTp, subTp, isFunctionCall); + tryUnify_(subTp, superTp, isFunctionCall); +} + +static std::pair, std::optional> logAwareFlatten(TypePackId tp, const TxnLog& log) +{ + tp = log.follow(tp); + + std::vector flattened; + std::optional tail = std::nullopt; + + TypePackIterator it(tp, &log); + + for (; it != end(tp); ++it) + { + flattened.push_back(*it); + } + + tail = it.tail(); + + return {flattened, tail}; } /* * This is quite tricky: we are walking two rope-like structures and unifying corresponding elements. * If one is longer than the other, but the short end is free, we grow it to the required length. */ -void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) +void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); @@ -795,252 +1064,458 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal return; } - superTp = follow(superTp); - subTp = follow(subTp); - - while (auto r = get(subTp)) - { - if (r->head.empty() && r->tail) - subTp = follow(*r->tail); - else - break; - } - - while (auto l = get(superTp)) - { - if (l->head.empty() && l->tail) - superTp = follow(*l->tail); - else - break; - } - - if (superTp == subTp) - return; - - if (get(superTp)) + if (FFlag::LuauUseCommittingTxnLog) { - occursCheck(superTp, subTp); + superTp = log.follow(superTp); + subTp = log.follow(subTp); - // The occurrence check might have caused superTp no longer to be a free type - if (!get(superTp)) + while (auto tp = log.getMutable(subTp)) { - log(superTp); - *asMutable(superTp) = Unifiable::Bound(subTp); + if (tp->head.empty() && tp->tail) + subTp = log.follow(*tp->tail); + else + break; } - } - else if (get(subTp)) - { - occursCheck(subTp, superTp); - // The occurrence check might have caused superTp no longer to be a free type - if (!get(subTp)) + while (auto tp = log.getMutable(superTp)) { - log(subTp); - *asMutable(subTp) = Unifiable::Bound(superTp); + if (tp->head.empty() && tp->tail) + superTp = log.follow(*tp->tail); + else + break; } - } - - else if (get(superTp)) - tryUnifyWithAny(superTp, subTp); - else if (get(subTp)) - tryUnifyWithAny(subTp, superTp); - - else if (get(superTp)) - tryUnifyVariadics(superTp, subTp, false); - else if (get(subTp)) - tryUnifyVariadics(subTp, superTp, true); - - else if (get(superTp) && get(subTp)) - { - auto l = get(superTp); - auto r = get(subTp); - - // If the size of two heads does not match, but both packs have free tail - // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = flatten(superTp); - auto [subTypes, subTail] = flatten(subTp); - - bool noInfiniteGrowth = - (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); - - auto superIter = WeirdIter{superTp}; - auto subIter = WeirdIter{subTp}; - - auto mkFreshType = [this](TypeLevel level) { - return types->freshType(level); - }; - - const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); - - int loopCount = 0; + if (superTp == subTp) + return; - do + if (log.getMutable(superTp)) { - if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) - ice("Detected possibly infinite TypePack growth"); - - ++loopCount; - - if (superIter.good() && subIter.growing) - asMutable(subIter.pack)->head.push_back(mkFreshType(subIter.level)); + occursCheck(superTp, subTp); - if (subIter.good() && superIter.growing) - asMutable(superIter.pack)->head.push_back(mkFreshType(superIter.level)); - - if (superIter.good() && subIter.good()) + if (!log.getMutable(superTp)) { - tryUnify_(*superIter, *subIter); - - if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) - firstPackErrorPos = loopCount; - - superIter.advance(); - subIter.advance(); - continue; + log.replace(superTp, Unifiable::Bound(subTp)); } + } + else if (log.getMutable(subTp)) + { + occursCheck(subTp, superTp); - // If both are at the end, we're done - if (!superIter.good() && !subIter.good()) + if (!log.getMutable(subTp)) { - const bool lFreeTail = l->tail && get(follow(*l->tail)) != nullptr; - const bool rFreeTail = r->tail && get(follow(*r->tail)) != nullptr; - if (lFreeTail && rFreeTail) - tryUnify_(*l->tail, *r->tail); - else if (lFreeTail) - tryUnify_(*l->tail, emptyTp); - else if (rFreeTail) - tryUnify_(*r->tail, emptyTp); - - break; + log.replace(subTp, Unifiable::Bound(superTp)); } + } + else if (log.getMutable(superTp)) + tryUnifyWithAny(subTp, superTp); + else if (log.getMutable(subTp)) + tryUnifyWithAny(superTp, subTp); + else if (log.getMutable(superTp)) + tryUnifyVariadics(subTp, superTp, false); + else if (log.getMutable(subTp)) + tryUnifyVariadics(superTp, subTp, true); + else if (log.getMutable(superTp) && log.getMutable(subTp)) + { + auto superTpv = log.getMutable(superTp); + auto subTpv = log.getMutable(subTp); - // If both tails are free, bind one to the other and call it a day - if (superIter.canGrow() && subIter.canGrow()) - return tryUnify_(*superIter.pack->tail, *subIter.pack->tail); + // If the size of two heads does not match, but both packs have free tail + // We set the sentinel variable to say so to avoid growing it forever. + auto [superTypes, superTail] = logAwareFlatten(superTp, log); + auto [subTypes, subTail] = logAwareFlatten(subTp, log); - // If just one side is free on its tail, grow it to fit the other side. - // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. - if (superIter.canGrow()) - superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + bool noInfiniteGrowth = + (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); - else if (subIter.canGrow()) - subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + auto superIter = WeirdIter(superTp, log); + auto subIter = WeirdIter(subTp, log); - else + auto mkFreshType = [this](TypeLevel level) { + return types->freshType(level); + }; + + const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); + + int loopCount = 0; + + do { - // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) - { - superIter.advance(); - continue; - } - else if (subIter.good() && isOptional(*subIter)) - { - subIter.advance(); - continue; - } + if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) + ice("Detected possibly infinite TypePack growth"); - // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) - { - superIter.advance(); - continue; - } + ++loopCount; - if (get(superIter.packId)) + if (superIter.good() && subIter.growing) { - tryUnifyVariadics(superIter.packId, subIter.packId, false, int(subIter.index)); - return; + subIter.pushType(mkFreshType(subIter.level)); } - if (get(subIter.packId)) + if (subIter.good() && superIter.growing) { - tryUnifyVariadics(subIter.packId, superIter.packId, true, int(superIter.index)); - return; + superIter.pushType(mkFreshType(superIter.level)); } - if (!isFunctionCall && subIter.good()) + if (superIter.good() && subIter.good()) { - // Sometimes it is ok to pass too many arguments - return; - } + tryUnify_(*subIter, *superIter); - // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking the values returned by a function, we swap - // these to produce the expected error message. - size_t expectedSize = size(superTp); - size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result) - std::swap(expectedSize, actualSize); - errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; - while (superIter.good()) - { - tryUnify_(getSingletonTypes().errorRecoveryType(), *superIter); superIter.advance(); + subIter.advance(); + continue; } - while (subIter.good()) + // If both are at the end, we're done + if (!superIter.good() && !subIter.good()) { - tryUnify_(getSingletonTypes().errorRecoveryType(), *subIter); - subIter.advance(); + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; + if (lFreeTail && rFreeTail) + tryUnify_(*subTpv->tail, *superTpv->tail); + else if (lFreeTail) + tryUnify_(emptyTp, *superTpv->tail); + else if (rFreeTail) + tryUnify_(emptyTp, *subTpv->tail); + + break; } - return; - } + // If both tails are free, bind one to the other and call it a day + if (superIter.canGrow() && subIter.canGrow()) + return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); + + // If just one side is free on its tail, grow it to fit the other side. + // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. + if (superIter.canGrow()) + superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + else if (subIter.canGrow()) + subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + else + { + // A union type including nil marks an optional argument + if (superIter.good() && isOptional(*superIter)) + { + superIter.advance(); + continue; + } + else if (subIter.good() && isOptional(*subIter)) + { + subIter.advance(); + continue; + } + + // In nonstrict mode, any also marks an optional argument. + else if (superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) + { + superIter.advance(); + continue; + } + + if (log.getMutable(superIter.packId)) + { + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + return; + } + + if (log.getMutable(subIter.packId)) + { + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + return; + } + + if (!isFunctionCall && subIter.good()) + { + // Sometimes it is ok to pass too many arguments + return; + } + + // This is a bit weird because we don't actually know expected vs actual. We just know + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. + size_t expectedSize = size(superTp); + size_t actualSize = size(subTp); + if (ctx == CountMismatch::Result) + std::swap(expectedSize, actualSize); + errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + + while (superIter.good()) + { + tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); + superIter.advance(); + } + + while (subIter.good()) + { + tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); + subIter.advance(); + } - } while (!noInfiniteGrowth); + return; + } + + } while (!noInfiniteGrowth); + } + else + { + errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + } } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + superTp = follow(superTp); + subTp = follow(subTp); + + while (auto tp = get(subTp)) + { + if (tp->head.empty() && tp->tail) + subTp = follow(*tp->tail); + else + break; + } + + while (auto tp = get(superTp)) + { + if (tp->head.empty() && tp->tail) + superTp = follow(*tp->tail); + else + break; + } + + if (superTp == subTp) + return; + + if (get(superTp)) + { + occursCheck(superTp, subTp); + + if (!get(superTp)) + { + DEPRECATED_log(superTp); + *asMutable(superTp) = Unifiable::Bound(subTp); + } + } + else if (get(subTp)) + { + occursCheck(subTp, superTp); + + if (!get(subTp)) + { + DEPRECATED_log(subTp); + *asMutable(subTp) = Unifiable::Bound(superTp); + } + } + + else if (get(superTp)) + tryUnifyWithAny(subTp, superTp); + + else if (get(subTp)) + tryUnifyWithAny(superTp, subTp); + + else if (get(superTp)) + tryUnifyVariadics(subTp, superTp, false); + else if (get(subTp)) + tryUnifyVariadics(superTp, subTp, true); + + else if (get(superTp) && get(subTp)) + { + auto superTpv = get(superTp); + auto subTpv = get(subTp); + + // If the size of two heads does not match, but both packs have free tail + // We set the sentinel variable to say so to avoid growing it forever. + auto [superTypes, superTail] = flatten(superTp); + auto [subTypes, subTail] = flatten(subTp); + + bool noInfiniteGrowth = + (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); + + auto superIter = DEPRECATED_WeirdIter{superTp}; + auto subIter = DEPRECATED_WeirdIter{subTp}; + + auto mkFreshType = [this](TypeLevel level) { + return types->freshType(level); + }; + + const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); + + int loopCount = 0; + + do + { + if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) + ice("Detected possibly infinite TypePack growth"); + + ++loopCount; + + if (superIter.good() && subIter.growing) + asMutable(subIter.pack)->head.push_back(mkFreshType(subIter.level)); + + if (subIter.good() && superIter.growing) + asMutable(superIter.pack)->head.push_back(mkFreshType(superIter.level)); + + if (superIter.good() && subIter.good()) + { + tryUnify_(*subIter, *superIter); + + if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; + + superIter.advance(); + subIter.advance(); + continue; + } + + // If both are at the end, we're done + if (!superIter.good() && !subIter.good()) + { + const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; + if (lFreeTail && rFreeTail) + tryUnify_(*subTpv->tail, *superTpv->tail); + else if (lFreeTail) + tryUnify_(emptyTp, *superTpv->tail); + else if (rFreeTail) + tryUnify_(emptyTp, *subTpv->tail); + + break; + } + + // If both tails are free, bind one to the other and call it a day + if (superIter.canGrow() && subIter.canGrow()) + return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); + + // If just one side is free on its tail, grow it to fit the other side. + // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. + if (superIter.canGrow()) + superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + + else if (subIter.canGrow()) + subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + + else + { + // A union type including nil marks an optional argument + if (superIter.good() && isOptional(*superIter)) + { + superIter.advance(); + continue; + } + else if (subIter.good() && isOptional(*subIter)) + { + subIter.advance(); + continue; + } + + // In nonstrict mode, any also marks an optional argument. + else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) + { + superIter.advance(); + continue; + } + + if (get(superIter.packId)) + { + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + return; + } + + if (get(subIter.packId)) + { + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + return; + } + + if (!isFunctionCall && subIter.good()) + { + // Sometimes it is ok to pass too many arguments + return; + } + + // This is a bit weird because we don't actually know expected vs actual. We just know + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. + size_t expectedSize = size(superTp); + size_t actualSize = size(subTp); + if (ctx == CountMismatch::Result) + std::swap(expectedSize, actualSize); + errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + + while (superIter.good()) + { + tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); + superIter.advance(); + } + + while (subIter.good()) + { + tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); + subIter.advance(); + } + + return; + } + + } while (!noInfiniteGrowth); + } + else + { + errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + } } } -void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy) +void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) { - const PrimitiveTypeVar* lp = get(superTy); - const PrimitiveTypeVar* rp = get(subTy); - if (!lp || !rp) + const PrimitiveTypeVar* superPrim = get(superTy); + const PrimitiveTypeVar* subPrim = get(subTy); + if (!superPrim || !subPrim) ice("passed non primitive types to unifyPrimitives"); - if (lp->type != rp->type) + if (superPrim->type != subPrim->type) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } -void Unifier::tryUnifySingletons(TypeId superTy, TypeId subTy) +void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) { - const PrimitiveTypeVar* lp = get(superTy); - const SingletonTypeVar* ls = get(superTy); - const SingletonTypeVar* rs = get(subTy); + const PrimitiveTypeVar* superPrim = get(superTy); + const SingletonTypeVar* superSingleton = get(superTy); + const SingletonTypeVar* subSingleton = get(subTy); - if ((!lp && !ls) || !rs) + if ((!superPrim && !superSingleton) || !subSingleton) ice("passed non singleton/primitive types to unifySingletons"); - if (ls && *ls == *rs) + if (superSingleton && *superSingleton == *subSingleton) return; - if (lp && lp->type == PrimitiveTypeVar::Boolean && get(rs) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) return; - if (lp && lp->type == PrimitiveTypeVar::String && get(rs) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } -void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) +void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) { - FunctionTypeVar* lf = getMutable(superTy); - FunctionTypeVar* rf = getMutable(subTy); - if (!lf || !rf) + FunctionTypeVar* superFunction = getMutable(superTy); + FunctionTypeVar* subFunction = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superFunction = log.getMutable(superTy); + subFunction = log.getMutable(subTy); + } + + if (!superFunction || !subFunction) ice("passed non-function types to unifyFunction"); - size_t numGenerics = lf->generics.size(); - if (numGenerics != rf->generics.size()) + size_t numGenerics = superFunction->generics.size(); + if (numGenerics != subFunction->generics.size()) { - numGenerics = std::min(lf->generics.size(), rf->generics.size()); + numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); if (FFlag::LuauExtendedFunctionMismatchError) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); @@ -1048,10 +1523,10 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } - size_t numGenericPacks = lf->genericPacks.size(); - if (numGenericPacks != rf->genericPacks.size()) + size_t numGenericPacks = superFunction->genericPacks.size(); + if (numGenericPacks != subFunction->genericPacks.size()) { - numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); + numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); if (FFlag::LuauExtendedFunctionMismatchError) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); @@ -1060,7 +1535,12 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal } for (size_t i = 0; i < numGenerics; i++) - log.pushSeen(lf->generics[i], rf->generics[i]); + { + if (FFlag::LuauUseCommittingTxnLog) + log.pushSeen(superFunction->generics[i], subFunction->generics[i]); + else + DEPRECATED_log.pushSeen(superFunction->generics[i], subFunction->generics[i]); + } CountMismatch::Context context = ctx; @@ -1071,7 +1551,7 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal if (FFlag::LuauExtendedFunctionMismatchError) { innerState.ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); bool reported = !innerState.errors.empty(); @@ -1085,13 +1565,13 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); innerState.ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + innerState.tryUnify_(subFunction->retType, superFunction->retType); if (!reported) { if (auto e = hasUnificationTooComplex(innerState.errors)) errors.push_back(*e); - else if (!innerState.errors.empty() && size(lf->retType) == 1 && finite(lf->retType)) + else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) errors.push_back( @@ -1104,38 +1584,70 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal else { ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + innerState.tryUnify_(subFunction->retType, superFunction->retType); checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); } - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + log.concat(std::move(innerState.log)); + } + else + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } } else { ctx = CountMismatch::Arg; - tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); ctx = CountMismatch::Result; - tryUnify_(lf->retType, rf->retType); + tryUnify_(subFunction->retType, superFunction->retType); } - if (lf->definition && !rf->definition && !subTy->persistent) + if (FFlag::LuauUseCommittingTxnLog) { - rf->definition = lf->definition; + if (superFunction->definition && !subFunction->definition && !subTy->persistent) + { + PendingType* newSubTy = log.queue(subTy); + FunctionTypeVar* newSubFtv = getMutable(newSubTy); + LUAU_ASSERT(newSubFtv); + newSubFtv->definition = superFunction->definition; + } + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + { + PendingType* newSuperTy = log.queue(superTy); + FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); + LUAU_ASSERT(newSuperFtv); + newSuperFtv->definition = subFunction->definition; + } } - else if (!lf->definition && rf->definition && !superTy->persistent) + else { - lf->definition = rf->definition; + if (superFunction->definition && !subFunction->definition && !subTy->persistent) + { + subFunction->definition = superFunction->definition; + } + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + { + superFunction->definition = subFunction->definition; + } } ctx = context; for (int i = int(numGenerics) - 1; 0 <= i; i--) - log.popSeen(lf->generics[i], rf->generics[i]); + { + if (FFlag::LuauUseCommittingTxnLog) + log.popSeen(superFunction->generics[i], subFunction->generics[i]); + else + DEPRECATED_log.popSeen(superFunction->generics[i], subFunction->generics[i]); + } } namespace @@ -1160,77 +1672,84 @@ struct Resetter } // namespace -void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { if (!FFlag::LuauTableSubtypingVariance2) - return DEPRECATED_tryUnifyTables(left, right, isIntersection); + return DEPRECATED_tryUnifyTables(subTy, superTy, isIntersection); - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) + TableTypeVar* superTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + if (!superTable || !subTable) ice("passed non-table types to unifyTables"); std::vector missingProperties; std::vector extraProperties; // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer && rt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && !subTable->indexer && subTable->state != TableState::Free) { - for (const auto& [propName, superProp] : lt->props) + for (const auto& [propName, superProp] : superTable->props) { - auto subIter = rt->props.find(propName); - if (subIter == rt->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) + auto subIter = subTable->props.find(propName); + if (subIter == subTable->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) missingProperties.push_back(propName); } if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } } // And vice versa if we're invariant - if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && - lt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && + superTable->state != TableState::Free) { - for (const auto& [propName, subProp] : rt->props) + for (const auto& [propName, subProp] : subTable->props) { - auto superIter = lt->props.find(propName); - if (superIter == lt->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) + auto superIter = superTable->props.find(propName); + if (superIter == superTable->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) extraProperties.push_back(propName); } if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } } - // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. - for (const auto& [name, prop] : lt->props) + for (const auto& [name, prop] : superTable->props) { - const auto& r = rt->props.find(name); - if (r != rt->props.end()) + const auto& r = subTable->props.find(name); + if (r != subTable->props.end()) { // TODO: read-only properties don't need invariance Resetter resetter{&variance}; variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, r->second.type); + innerState.tryUnify_(r->second.type, prop.type); - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } - else if (rt->indexer && isString(rt->indexer->indexType)) + else if (subTable->indexer && isString(subTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. @@ -1238,37 +1757,55 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, rt->indexer->indexResultType); + innerState.tryUnify_(subTable->indexer->indexResultType, prop.type); - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } else if (isOptional(prop.type) || get(follow(prop.type))) // TODO: this case is unsound, but without it our test suite fails. CLI-46031 // TODO: should isOptional(anyType) be true? { } - else if (rt->state == TableState::Free) + else if (subTable->state == TableState::Free) { - log(rt); - rt->props[name] = prop; + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* ttv = getMutable(pendingSub); + LUAU_ASSERT(ttv); + ttv->props[name] = prop; + } + else + { + DEPRECATED_log(subTy); + subTable->props[name] = prop; + } } else missingProperties.push_back(name); } - for (const auto& [name, prop] : rt->props) + for (const auto& [name, prop] : subTable->props) { - if (lt->props.count(name)) + if (superTable->props.count(name)) { // If both lt and rt contain the property, then // we're done since we already unified them above } - else if (lt->indexer && isString(lt->indexer->indexType)) + else if (superTable->indexer && isString(superTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. @@ -1276,24 +1813,42 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, lt->indexer->indexResultType); + innerState.tryUnify_(superTable->indexer->indexResultType, prop.type); - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } - else if (lt->state == TableState::Unsealed) + else if (superTable->state == TableState::Unsealed) { // TODO: this case is unsound when variance is Invariant, but without it lua-apps fails to typecheck. // TODO: file a JIRA // TODO: hopefully readonly/writeonly properties will fix this. Property clone = prop; clone.type = deeplyOptional(clone.type); - log(left); - lt->props[name] = clone; + + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = clone; + } + else + { + DEPRECATED_log(superTy); + superTable->props[name] = clone; + } } else if (variance == Covariant) { @@ -1303,61 +1858,93 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // TODO: should isOptional(anyType) be true? { } - else if (lt->state == TableState::Free) + else if (superTable->state == TableState::Free) { - log(left); - lt->props[name] = prop; + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = prop; + } + else + { + DEPRECATED_log(superTy); + superTable->props[name] = prop; + } } else extraProperties.push_back(name); } // Unify indexers - if (lt->indexer && rt->indexer) + if (superTable->indexer && subTable->indexer) { // TODO: read-only indexers don't need invariance Resetter resetter{&variance}; variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify(*lt->indexer, *rt->indexer); - checkChildUnifierTypeMismatch(innerState.errors, left, right); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } - else if (lt->indexer) + else if (superTable->indexer) { - if (rt->state == TableState::Unsealed || rt->state == TableState::Free) + if (subTable->state == TableState::Unsealed || subTable->state == TableState::Free) { // passing/assigning a table without an indexer to something that has one // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. // TODO: we only need to do this if the supertype's indexer is read/write // since that can add indexed elements. - log(right); - rt->indexer = lt->indexer; + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(subTy, superTable->indexer); + } + else + { + DEPRECATED_log(subTy); + subTable->indexer = superTable->indexer; + } } } - else if (rt->indexer && variance == Invariant) + else if (subTable->indexer && variance == Invariant) { // Symmetric if we are invariant - if (lt->state == TableState::Unsealed || lt->state == TableState::Free) + if (superTable->state == TableState::Unsealed || superTable->state == TableState::Free) { - log(left); - lt->indexer = rt->indexer; + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(superTy, subTable->indexer); + } + else + { + DEPRECATED_log(superTy); + superTable->indexer = subTable->indexer; + } } } if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } @@ -1369,18 +1956,32 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) * I believe this is guaranteed to terminate eventually because this will * only happen when a free table is bound to another table. */ - if (lt->boundTo || rt->boundTo) - return tryUnify_(left, right); + if (superTable->boundTo || subTable->boundTo) + return tryUnify_(subTy, superTy); - if (lt->state == TableState::Free) + if (superTable->state == TableState::Free) { - log(lt); - lt->boundTo = right; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(superTy, subTy); + } + else + { + DEPRECATED_log(superTable); + superTable->boundTo = subTy; + } } - else if (rt->state == TableState::Free) + else if (subTable->state == TableState::Free) { - log(rt); - rt->boundTo = left; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(subTy, superTy); + } + else + { + DEPRECATED_log(subTy); + subTable->boundTo = superTy; + } } } @@ -1406,99 +2007,129 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); } -void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); Resetter resetter{&variance}; variance = Invariant; - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) + TableTypeVar* superTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + + if (!superTable || !subTable) ice("passed non-table types to unifyTables"); - if (lt->state == TableState::Sealed && rt->state == TableState::Sealed) - return tryUnifySealedTables(left, right, isIntersection); - else if ((lt->state == TableState::Sealed && rt->state == TableState::Unsealed) || - (lt->state == TableState::Unsealed && rt->state == TableState::Sealed)) - return tryUnifySealedTables(left, right, isIntersection); - else if ((lt->state == TableState::Sealed && rt->state == TableState::Generic) || - (lt->state == TableState::Generic && rt->state == TableState::Sealed)) - errors.push_back(TypeError{location, TypeMismatch{left, right}}); - else if ((lt->state == TableState::Free) != (rt->state == TableState::Free)) // one table is free and the other is not + if (superTable->state == TableState::Sealed && subTable->state == TableState::Sealed) + return tryUnifySealedTables(subTy, superTy, isIntersection); + else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Unsealed) || + (superTable->state == TableState::Unsealed && subTable->state == TableState::Sealed)) + return tryUnifySealedTables(subTy, superTy, isIntersection); + else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Generic) || + (superTable->state == TableState::Generic && subTable->state == TableState::Sealed)) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + else if ((superTable->state == TableState::Free) != (subTable->state == TableState::Free)) // one table is free and the other is not { - TypeId freeTypeId = rt->state == TableState::Free ? right : left; - TypeId otherTypeId = rt->state == TableState::Free ? left : right; + TypeId freeTypeId = subTable->state == TableState::Free ? subTy : superTy; + TypeId otherTypeId = subTable->state == TableState::Free ? superTy : subTy; - return tryUnifyFreeTable(freeTypeId, otherTypeId); + return tryUnifyFreeTable(otherTypeId, freeTypeId); } - else if (lt->state == TableState::Free && rt->state == TableState::Free) + else if (superTable->state == TableState::Free && subTable->state == TableState::Free) { - tryUnifyFreeTable(left, right); + tryUnifyFreeTable(subTy, superTy); // avoid creating a cycle when the types are already pointing at each other - if (follow(left) != follow(right)) + if (follow(superTy) != follow(subTy)) { - log(lt); - lt->boundTo = right; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(superTy, subTy); + } + else + { + DEPRECATED_log(superTable); + superTable->boundTo = subTy; + } } return; } - else if (lt->state != TableState::Sealed && rt->state != TableState::Sealed) + else if (superTable->state != TableState::Sealed && subTable->state != TableState::Sealed) { // All free tables are checked in one of the branches above - LUAU_ASSERT(lt->state != TableState::Free); - LUAU_ASSERT(rt->state != TableState::Free); + LUAU_ASSERT(superTable->state != TableState::Free); + LUAU_ASSERT(subTable->state != TableState::Free); // Tables must have exactly the same props and their types must all unify // I honestly have no idea if this is remotely close to reasonable. - for (const auto& [name, prop] : lt->props) + for (const auto& [name, prop] : superTable->props) { - const auto& r = rt->props.find(name); - if (r == rt->props.end()) - errors.push_back(TypeError{location, UnknownProperty{right, name}}); + const auto& r = subTable->props.find(name); + if (r == subTable->props.end()) + errors.push_back(TypeError{location, UnknownProperty{subTy, name}}); else - tryUnify_(prop.type, r->second.type); + tryUnify_(r->second.type, prop.type); } - if (lt->indexer && rt->indexer) - tryUnify(*lt->indexer, *rt->indexer); - else if (lt->indexer) + if (superTable->indexer && subTable->indexer) + tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (superTable->indexer) { // passing/assigning a table without an indexer to something that has one // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. - if (rt->state == TableState::Unsealed) - rt->indexer = lt->indexer; + if (subTable->state == TableState::Unsealed) + { + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(subTy, superTable->indexer); + } + else + { + subTable->indexer = superTable->indexer; + } + } else - errors.push_back(TypeError{location, CannotExtendTable{right, CannotExtendTable::Indexer}}); + errors.push_back(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); } } - else if (lt->state == TableState::Sealed) + else if (superTable->state == TableState::Sealed) { // lt is sealed and so it must be possible for rt to have precisely the same shape // Verify that this is the case, then bind rt to lt. ice("unsealed tables are not working yet", location); } - else if (rt->state == TableState::Sealed) - return tryUnifyTables(right, left, isIntersection); + else if (subTable->state == TableState::Sealed) + return tryUnifyTables(superTy, subTy, isIntersection); else ice("tryUnifyTables"); } -void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) +void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { - TableTypeVar* freeTable = getMutable(freeTypeId); - TableTypeVar* otherTable = getMutable(otherTypeId); - if (!freeTable || !otherTable) + TableTypeVar* freeTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + freeTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + + if (!freeTable || !subTable) ice("passed non-table types to tryUnifyFreeTable"); // Any properties in freeTable must unify with those in otherTable. // Then bind freeTable to otherTable. for (const auto& [freeName, freeProp] : freeTable->props) { - if (auto otherProp = findTablePropertyRespectingMeta(otherTypeId, freeName)) + if (auto subProp = findTablePropertyRespectingMeta(subTy, freeName)) { - tryUnify_(*otherProp, freeProp.type); + tryUnify_(freeProp.type, *subProp); /* * TypeVars are commonly cyclic, so it is entirely possible @@ -1508,84 +2139,133 @@ void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) * I believe this is guaranteed to terminate eventually because this will * only happen when a free table is bound to another table. */ - if (!get(freeTypeId) || !get(otherTypeId)) - return tryUnify_(freeTypeId, otherTypeId); + if (FFlag::LuauUseCommittingTxnLog) + { + if (!log.getMutable(superTy) || !log.getMutable(subTy)) + return tryUnify_(subTy, superTy); - if (freeTable->boundTo) - return tryUnify_(freeTypeId, otherTypeId); + if (TableTypeVar* pendingFreeTtv = log.getMutable(superTy); pendingFreeTtv && pendingFreeTtv->boundTo) + return tryUnify_(subTy, superTy); + } + else + { + if (!get(superTy) || !get(subTy)) + return tryUnify_(subTy, superTy); + + if (freeTable->boundTo) + return tryUnify_(subTy, superTy); + } } else { // If the other table is also free, then we are learning that it has more // properties than we previously thought. Else, it is an error. - if (otherTable->state == TableState::Free) - otherTable->props.insert({freeName, freeProp}); + if (subTable->state == TableState::Free) + { + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* pendingSubTtv = getMutable(pendingSub); + LUAU_ASSERT(pendingSubTtv); + pendingSubTtv->props.insert({freeName, freeProp}); + } + else + { + subTable->props.insert({freeName, freeProp}); + } + } else - errors.push_back(TypeError{location, UnknownProperty{otherTypeId, freeName}}); + errors.push_back(TypeError{location, UnknownProperty{subTy, freeName}}); } } - if (freeTable->indexer && otherTable->indexer) + if (freeTable->indexer && subTable->indexer) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify(*freeTable->indexer, *otherTable->indexer); + innerState.tryUnifyIndexer(*subTable->indexer, *freeTable->indexer); - checkChildUnifierTypeMismatch(innerState.errors, freeTypeId, otherTypeId); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } + else if (subTable->state == TableState::Free && freeTable->indexer) + { + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(superTy, subTable->indexer); + } + else + { + freeTable->indexer = subTable->indexer; + } } - else if (otherTable->state == TableState::Free && freeTable->indexer) - freeTable->indexer = otherTable->indexer; - if (!freeTable->boundTo && otherTable->state != TableState::Free) + if (!freeTable->boundTo && subTable->state != TableState::Free) { - log(freeTable); - freeTable->boundTo = otherTypeId; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(superTy, subTy); + } + else + { + DEPRECATED_log(freeTable); + freeTable->boundTo = subTy; + } } } -void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection) +void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection) { - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) + TableTypeVar* superTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + + if (!superTable || !subTable) ice("passed non-table types to unifySealedTables"); Unifier innerState = makeChildUnifier(); std::vector missingPropertiesInSuper; - bool isUnnamedTable = rt->name == std::nullopt && rt->syntheticName == std::nullopt; + bool isUnnamedTable = subTable->name == std::nullopt && subTable->syntheticName == std::nullopt; bool errorReported = false; // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer) + if (FFlag::LuauTableUnificationEarlyTest && !subTable->indexer) { - for (const auto& [propName, superProp] : lt->props) + for (const auto& [propName, superProp] : superTable->props) { - auto subIter = rt->props.find(propName); - if (subIter == rt->props.end() && !isOptional(superProp.type)) + auto subIter = subTable->props.find(propName); + if (subIter == subTable->props.end() && !isOptional(superProp.type)) missingPropertiesInSuper.push_back(propName); } if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingPropertiesInSuper)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } } // Tables must have exactly the same props and their types must all unify - for (const auto& it : lt->props) + for (const auto& it : superTable->props) { - const auto& r = rt->props.find(it.first); - if (r == rt->props.end()) + const auto& r = subTable->props.find(it.first); + if (r == subTable->props.end()) { if (isOptional(it.second.type)) continue; missingPropertiesInSuper.push_back(it.first); - innerState.errors.push_back(TypeError{location, TypeMismatch{left, right}}); + innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } else { @@ -1594,7 +2274,7 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio size_t oldErrorSize = innerState.errors.size(); Location old = innerState.location; innerState.location = *r->second.location; - innerState.tryUnify_(it.second.type, r->second.type); + innerState.tryUnify_(r->second.type, it.second.type); innerState.location = old; if (oldErrorSize != innerState.errors.size() && !errorReported) @@ -1605,113 +2285,165 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio } else { - innerState.tryUnify_(it.second.type, r->second.type); + innerState.tryUnify_(r->second.type, it.second.type); } } } - if (lt->indexer || rt->indexer) + if (superTable->indexer || subTable->indexer) { - if (lt->indexer && rt->indexer) - innerState.tryUnify(*lt->indexer, *rt->indexer); - else if (rt->state == TableState::Unsealed) + if (FFlag::LuauUseCommittingTxnLog) { - if (lt->indexer && !rt->indexer) - rt->indexer = lt->indexer; - } - else if (lt->state == TableState::Unsealed) - { - if (rt->indexer && !lt->indexer) - lt->indexer = rt->indexer; + if (superTable->indexer && subTable->indexer) + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (subTable->state == TableState::Unsealed) + { + if (superTable->indexer && !subTable->indexer) + { + log.changeIndexer(subTy, superTable->indexer); + } + } + else if (superTable->state == TableState::Unsealed) + { + if (subTable->indexer && !superTable->indexer) + { + log.changeIndexer(superTy, subTable->indexer); + } + } + else if (superTable->indexer) + { + innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); + for (const auto& [name, type] : subTable->props) + { + const auto& it = superTable->props.find(name); + if (it == superTable->props.end()) + innerState.tryUnify_(type.type, superTable->indexer->indexResultType); + } + } + else + innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } - else if (lt->indexer) + else { - innerState.tryUnify_(lt->indexer->indexType, getSingletonTypes().stringType); - // We already try to unify properties in both tables. - // Skip those and just look for the ones remaining and see if they fit into the indexer. - for (const auto& [name, type] : rt->props) + if (superTable->indexer && subTable->indexer) + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (subTable->state == TableState::Unsealed) { - const auto& it = lt->props.find(name); - if (it == lt->props.end()) - innerState.tryUnify_(lt->indexer->indexResultType, type.type); + if (superTable->indexer && !subTable->indexer) + subTable->indexer = superTable->indexer; } + else if (superTable->state == TableState::Unsealed) + { + if (subTable->indexer && !superTable->indexer) + superTable->indexer = subTable->indexer; + } + else if (superTable->indexer) + { + innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); + // We already try to unify properties in both tables. + // Skip those and just look for the ones remaining and see if they fit into the indexer. + for (const auto& [name, type] : subTable->props) + { + const auto& it = superTable->props.find(name); + if (it == superTable->props.end()) + innerState.tryUnify_(type.type, superTable->indexer->indexResultType); + } + } + else + innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } - else - innerState.errors.push_back(TypeError{location, TypeMismatch{left, right}}); } - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (!errorReported) + log.concat(std::move(innerState.log)); + } + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); if (errorReported) return; if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingPropertiesInSuper)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } - // If the superTy/left is an immediate part of an intersection type, do not do extra-property check. + // If the superTy is an immediate part of an intersection type, do not do extra-property check. // Otherwise, we would falsely generate an extra-property-error for 's' in this code: // local a: {n: number} & {s: string} = {n=1, s=""} // When checking against the table '{n: number}'. - if (!isIntersection && lt->state != TableState::Unsealed && !lt->indexer) + if (!isIntersection && superTable->state != TableState::Unsealed && !superTable->indexer) { // Check for extra properties in the subTy std::vector extraPropertiesInSub; - for (const auto& it : rt->props) + for (const auto& [subKey, subProp] : subTable->props) { - const auto& r = lt->props.find(it.first); - if (r == lt->props.end()) + const auto& superIt = superTable->props.find(subKey); + if (superIt == superTable->props.end()) { - if (isOptional(it.second.type)) + if (isOptional(subProp.type)) continue; - extraPropertiesInSub.push_back(it.first); + extraPropertiesInSub.push_back(subKey); } } if (!extraPropertiesInSub.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraPropertiesInSub), MissingProperties::Extra}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); return; } } - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); } -void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed) +void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { - const MetatableTypeVar* lhs = get(metatable); - if (!lhs) + const MetatableTypeVar* superMetatable = get(superTy); + if (!superMetatable) ice("tryUnifyMetatable invoked with non-metatable TypeVar"); - TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other}}; + TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy}}; - if (const MetatableTypeVar* rhs = get(other)) + if (const MetatableTypeVar* subMetatable = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(lhs->table, rhs->table); - innerState.tryUnify_(lhs->metatable, rhs->metatable); + innerState.tryUnify_(subMetatable->table, superMetatable->table); + innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); if (auto e = hasUnificationTooComplex(innerState.errors)) errors.push_back(*e); else if (!innerState.errors.empty()) errors.push_back( - TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); + TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); } - else if (TableTypeVar* rhs = getMutable(other)) + else if (TableTypeVar* subTable = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : getMutable(subTy)) { - switch (rhs->state) + switch (subTable->state) { case TableState::Free: { - tryUnify_(lhs->table, other); - rhs->boundTo = metatable; + tryUnify_(subTy, superMetatable->table); + + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(subTy, superTy); + } + else + { + subTable->boundTo = superTy; + } break; } @@ -1722,7 +2454,8 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse errors.push_back(mismatchError); } } - else if (get(other) || get(other)) + else if (FFlag::LuauUseCommittingTxnLog ? (log.getMutable(subTy) || log.getMutable(subTy)) + : (get(subTy) || get(subTy))) { } else @@ -1732,7 +2465,7 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse } // Class unification is almost, but not quite symmetrical. We use the 'reversed' boolean to indicate which scenario we are evaluating. -void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) +void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { if (reversed) std::swap(superTy, subTy); @@ -1763,7 +2496,7 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) } ice("Illegal variance setting!"); } - else if (TableTypeVar* table = getMutable(subTy)) + else if (TableTypeVar* subTable = getMutable(subTy)) { /** * A free table is something whose shape we do not exactly know yet. @@ -1775,12 +2508,12 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) * * Tables that are not free are known to be actual tables. */ - if (table->state != TableState::Free) + if (subTable->state != TableState::Free) return fail(); bool ok = true; - for (const auto& [propName, prop] : table->props) + for (const auto& [propName, prop] : subTable->props) { const Property* classProp = lookupClassProp(superClass, propName); if (!classProp) @@ -1791,23 +2524,37 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) else { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, classProp->type); + innerState.tryUnify_(classProp->type, prop.type); checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - if (innerState.errors.empty()) + if (FFlag::LuauUseCommittingTxnLog) { - log.concat(std::move(innerState.log)); + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); + } + else + { + ok = false; + } } else { - ok = false; - innerState.log.rollback(); + if (innerState.errors.empty()) + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } + else + { + ok = false; + innerState.DEPRECATED_log.rollback(); + } } } } - if (table->indexer) + if (subTable->indexer) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; @@ -1817,17 +2564,24 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) if (!ok) return; - log(table); - table->boundTo = superTy; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(subTy, superTy); + } + else + { + DEPRECATED_log(subTable); + subTable->boundTo = superTy; + } } else return fail(); } -void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer) +void Unifier::tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer) { - tryUnify_(superIndexer.indexType, subIndexer.indexType); - tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); + tryUnify_(subIndexer.indexType, superIndexer.indexType); + tryUnify_(subIndexer.indexResultType, superIndexer.indexResultType); } static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) @@ -1840,54 +2594,85 @@ static void queueTypePack(std::vector& queue, DenseHashSet& break; seenTypePacks.insert(a); - if (get(a)) + if (FFlag::LuauUseCommittingTxnLog) { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; + if (state.log.getMutable(a)) + { + state.log.replace(a, Unifiable::Bound{anyTypePack}); + } + else if (auto tp = state.log.getMutable(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } } - else if (auto tp = get(a)) + else { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; + if (get(a)) + { + state.DEPRECATED_log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + else if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } } } } -void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool reversed, int subOffset) +void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool reversed, int subOffset) { - const VariadicTypePack* lv = get(superTp); - if (!lv) + const VariadicTypePack* superVariadic = get(superTp); + + if (FFlag::LuauUseCommittingTxnLog) + { + superVariadic = log.getMutable(superTp); + } + + if (!superVariadic) ice("passed non-variadic pack to tryUnifyVariadics"); - if (const VariadicTypePack* rv = get(subTp)) - tryUnify_(reversed ? rv->ty : lv->ty, reversed ? lv->ty : rv->ty); + if (const VariadicTypePack* subVariadic = get(subTp)) + tryUnify_(reversed ? superVariadic->ty : subVariadic->ty, reversed ? subVariadic->ty : superVariadic->ty); else if (get(subTp)) { - TypePackIterator rIter = begin(subTp); - TypePackIterator rEnd = end(subTp); + TypePackIterator subIter = begin(subTp, &log); + TypePackIterator subEnd = end(subTp); - std::advance(rIter, subOffset); + std::advance(subIter, subOffset); - while (rIter != rEnd) + while (subIter != subEnd) { - tryUnify_(reversed ? *rIter : lv->ty, reversed ? lv->ty : *rIter); - ++rIter; + tryUnify_(reversed ? superVariadic->ty : *subIter, reversed ? *subIter : superVariadic->ty); + ++subIter; } - if (std::optional maybeTail = rIter.tail()) + if (std::optional maybeTail = subIter.tail()) { TypePackId tail = follow(*maybeTail); if (get(tail)) { - log(tail); - *asMutable(tail) = BoundTypePack{superTp}; + if (FFlag::LuauUseCommittingTxnLog) + { + log.replace(tail, BoundTypePack(superTp)); + } + else + { + DEPRECATED_log(tail); + *asMutable(tail) = BoundTypePack{superTp}; + } } else if (const VariadicTypePack* vtp = get(tail)) { - tryUnify_(lv->ty, vtp->ty); + tryUnify_(vtp->ty, superVariadic->ty); } else if (get(tail)) { @@ -1914,65 +2699,113 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas { while (!queue.empty()) { - TypeId ty = follow(queue.back()); - queue.pop_back(); - if (seen.find(ty)) - continue; - seen.insert(ty); - - if (get(ty)) - { - state.log(ty); - *asMutable(ty) = BoundTypeVar{anyType}; - } - else if (auto fun = get(ty)) - { - queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); - } - else if (auto table = get(ty)) + if (FFlag::LuauUseCommittingTxnLog) { - for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); + TypeId ty = state.log.follow(queue.back()); + queue.pop_back(); + if (seen.find(ty)) + continue; + seen.insert(ty); - if (table->indexer) + if (state.log.getMutable(ty)) { - queue.push_back(table->indexer->indexType); - queue.push_back(table->indexer->indexResultType); + state.log.replace(ty, BoundTypeVar{anyType}); } + else if (auto fun = state.log.getMutable(ty)) + { + queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = state.log.getMutable(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = state.log.getMutable(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (state.log.getMutable(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = state.log.getMutable(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = state.log.getMutable(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. } - else if (auto mt = get(ty)) - { - queue.push_back(mt->table); - queue.push_back(mt->metatable); - } - else if (get(ty)) - { - // ClassTypeVars never contain free typevars. - } - else if (auto union_ = get(ty)) - queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = get(ty)) - queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); else { - } // Primitives, any, errors, and generics are left untouched. + TypeId ty = follow(queue.back()); + queue.pop_back(); + if (seen.find(ty)) + continue; + seen.insert(ty); + + if (get(ty)) + { + state.DEPRECATED_log(ty); + *asMutable(ty) = BoundTypeVar{anyType}; + } + else if (auto fun = get(ty)) + { + queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = get(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = get(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (get(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = get(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = get(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. + } } } -void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) +void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) { - LUAU_ASSERT(get(any) || get(any)); + LUAU_ASSERT(get(anyTy) || get(anyTy)); // These types are not visited in general loop below - if (get(ty) || get(ty) || get(ty)) + if (get(subTy) || get(subTy) || get(subTy)) return; const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); - const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + const TypePackId anyTP = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - std::vector queue = {ty}; + std::vector queue = {subTy}; sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); @@ -1980,9 +2813,9 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, getSingletonTypes().anyType, anyTP); } -void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) +void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) { - LUAU_ASSERT(get(any)); + LUAU_ASSERT(get(anyTp)); const TypeId anyTy = getSingletonTypes().errorRecoveryType(); @@ -1991,9 +2824,9 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, subTy, anyTp); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, anyTp); } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) @@ -2012,54 +2845,105 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays { RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - needle = follow(needle); - haystack = follow(haystack); + auto check = [&](TypeId tv) { + occursCheck(seen, needle, tv); + }; - if (seen.find(haystack)) - return; + if (FFlag::LuauUseCommittingTxnLog) + { + needle = log.follow(needle); + haystack = log.follow(haystack); - seen.insert(haystack); + if (seen.find(haystack)) + return; - if (get(needle)) - return; + seen.insert(haystack); - if (!get(needle)) - ice("Expected needle to be free"); + if (log.getMutable(needle)) + return; - if (needle == haystack) - { - errors.push_back(TypeError{location, OccursCheckFailed{}}); - log(needle); - *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); - return; - } + if (!log.getMutable(needle)) + ice("Expected needle to be free"); - auto check = [&](TypeId tv) { - occursCheck(seen, needle, tv); - }; + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + log.replace(needle, *getSingletonTypes().errorRecoveryType()); - if (get(haystack)) - return; - else if (auto a = get(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + return; + } + + if (log.getMutable(haystack)) + return; + else if (auto a = log.getMutable(haystack)) { - for (TypeId ty : a->argTypes) - check(ty); + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (TypePackIterator it(a->argTypes, &log); it != end(a->argTypes); ++it) + check(*it); - for (TypeId ty : a->retType) + for (TypePackIterator it(a->retType, &log); it != end(a->retType); ++it) + check(*it); + } + } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->options) + check(ty); + } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->parts) check(ty); } } - else if (auto a = get(haystack)) - { - for (TypeId ty : a->options) - check(ty); - } - else if (auto a = get(haystack)) + else { - for (TypeId ty : a->parts) - check(ty); + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + + seen.insert(haystack); + + if (get(needle)) + return; + + if (!get(needle)) + ice("Expected needle to be free"); + + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + DEPRECATED_log(needle); + *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); + return; + } + + if (get(haystack)) + return; + else if (auto a = get(haystack)) + { + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (TypeId ty : a->argTypes) + check(ty); + + for (TypeId ty : a->retType) + check(ty); + } + } + else if (auto a = get(haystack)) + { + for (TypeId ty : a->options) + check(ty); + } + else if (auto a = get(haystack)) + { + for (TypeId ty : a->parts) + check(ty); + } } } @@ -2072,59 +2956,115 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) { - needle = follow(needle); - haystack = follow(haystack); + if (FFlag::LuauUseCommittingTxnLog) + { + needle = log.follow(needle); + haystack = log.follow(haystack); - if (seen.find(haystack)) - return; + if (seen.find(haystack)) + return; - seen.insert(haystack); + seen.insert(haystack); - if (get(needle)) - return; + if (log.getMutable(needle)) + return; - if (!get(needle)) - ice("Expected needle pack to be free"); + if (!get(needle)) + ice("Expected needle pack to be free"); - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - while (!get(haystack)) - { - if (needle == haystack) + while (!log.getMutable(haystack)) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); - log(needle); - *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); - return; - } + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); - if (auto a = get(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + return; + } + + if (auto a = get(haystack)) { for (const auto& ty : a->head) { - if (auto f = get(follow(ty))) + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); + if (auto f = log.getMutable(log.follow(ty))) + { + occursCheck(seen, needle, f->argTypes); + occursCheck(seen, needle, f->retType); + } } } + + if (a->tail) + { + haystack = follow(*a->tail); + continue; + } } + break; + } + } + else + { + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + + seen.insert(haystack); + + if (get(needle)) + return; - if (a->tail) + if (!get(needle)) + ice("Expected needle pack to be free"); + + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + + while (!get(haystack)) + { + if (needle == haystack) { - haystack = follow(*a->tail); - continue; + errors.push_back(TypeError{location, OccursCheckFailed{}}); + DEPRECATED_log(needle); + *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); } + + if (auto a = get(haystack)) + { + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (const auto& ty : a->head) + { + if (auto f = get(follow(ty))) + { + occursCheck(seen, needle, f->argTypes); + occursCheck(seen, needle, f->retType); + } + } + } + + if (a->tail) + { + haystack = follow(*a->tail); + continue; + } + } + break; } - break; } } Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState}; + if (FFlag::LuauUseCommittingTxnLog) + return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, &log}; + else + return Unifier{types, mode, globalScope, DEPRECATED_log.sharedSeen, location, variance, sharedState, &log}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/Common.h b/Ast/include/Luau/Common.h index 63cd3df49..fbb03a9e9 100644 --- a/Ast/include/Luau/Common.h +++ b/Ast/include/Luau/Common.h @@ -29,7 +29,7 @@ namespace Luau { -using AssertHandler = int (*)(const char* expression, const char* file, int line); +using AssertHandler = int (*)(const char* expression, const char* file, int line, const char* function); inline AssertHandler& assertHandler() { @@ -37,10 +37,10 @@ inline AssertHandler& assertHandler() return handler; } -inline int assertCallHandler(const char* expression, const char* file, int line) +inline int assertCallHandler(const char* expression, const char* file, int line, const char* function) { if (AssertHandler handler = assertHandler()) - return handler(expression, file, line); + return handler(expression, file, line, function); return 1; } @@ -48,7 +48,7 @@ inline int assertCallHandler(const char* expression, const char* file, int line) } // namespace Luau #if !defined(NDEBUG) || defined(LUAU_ENABLE_ASSERT) -#define LUAU_ASSERT(expr) ((void)(!!(expr) || (Luau::assertCallHandler(#expr, __FILE__, __LINE__) && (LUAU_DEBUGBREAK(), 0)))) +#define LUAU_ASSERT(expr) ((void)(!!(expr) || (Luau::assertCallHandler(#expr, __FILE__, __LINE__, __FUNCTION__) && (LUAU_DEBUGBREAK(), 0)))) #define LUAU_ASSERTENABLED #else #define LUAU_ASSERT(expr) (void)sizeof(!!(expr)) diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index aecb619a1..54a9a26fa 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -107,7 +107,7 @@ static void displayHelp(const char* argv0) printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); return 1; diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 35c02f2c8..26d4333a9 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -235,11 +235,14 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, while (lua_next(L, -2) != 0) { - // table, key, value - std::string_view key = lua_tostring(L, -2); + if (lua_type(L, -2) == LUA_TSTRING) + { + // table, key, value + std::string_view key = lua_tostring(L, -2); - if (!key.empty() && Luau::startsWith(key, prefix)) - completions.push_back(editBuffer + std::string(key.substr(prefix.size()))); + if (!key.empty() && Luau::startsWith(key, prefix)) + completions.push_back(editBuffer + std::string(key.substr(prefix.size()))); + } lua_pop(L, 1); } @@ -253,7 +256,7 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, lua_rawget(L, -2); lua_remove(L, -2); - if (lua_isnil(L, -1)) + if (!lua_istable(L, -1)) break; lookup.remove_prefix(dot + 1); @@ -266,7 +269,7 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, static void completeRepl(lua_State* L, const char* editBuffer, std::vector& completions) { size_t start = strlen(editBuffer); - while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.')) + while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.' || editBuffer[start - 1] == '_')) start--; // look the value up in current global table first @@ -278,6 +281,34 @@ static void completeRepl(lua_State* L, const char* editBuffer, std::vector globalState(luaL_newstate(), lua_close); @@ -292,6 +323,7 @@ static void runRepl() }); std::string buffer; + LinenoiseScopedHistory scopedHistory; for (;;) { @@ -457,7 +489,7 @@ static void displayHelp(const char* argv0) printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); return 1; diff --git a/CLI/Web.cpp b/CLI/Web.cpp index cf5c831e9..416a79f2b 100644 --- a/CLI/Web.cpp +++ b/CLI/Web.cpp @@ -53,30 +53,38 @@ static std::string runCode(lua_State* L, const std::string& source) lua_insert(T, 1); lua_pcall(T, n, 0, 0); } + + lua_pop(L, 1); // pop T + return std::string(); } else { std::string error; + lua_Debug ar; + if (lua_getinfo(L, 0, "sln", &ar)) + { + error += ar.short_src; + error += ':'; + error += std::to_string(ar.currentline); + error += ": "; + } + if (status == LUA_YIELD) { - error = "thread yielded unexpectedly"; + error += "thread yielded unexpectedly"; } else if (const char* str = lua_tostring(T, -1)) { - error = str; + error += str; } error += "\nstack backtrace:\n"; error += lua_debugtrace(T); - error = "Error:" + error; - - fprintf(stdout, "%s", error.c_str()); + lua_pop(L, 1); // pop T + return error; } - - lua_pop(L, 1); - return std::string(); } extern "C" const char* executeScript(const char* source) diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 71631d101..d9694d7d0 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -377,6 +377,7 @@ enum LuauBytecodeTag { // Bytecode version LBC_VERSION = 1, + LBC_VERSION_FUTURE = 2, // TODO: This will be removed in favor of LBC_VERSION with LuauBytecodeV2Force // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index d4ebad6bf..287bf4eed 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -74,6 +74,7 @@ class BytecodeBuilder void expandJumps(); void setDebugFunctionName(StringRef name); + void setDebugFunctionLineDefined(int line); void setDebugLine(int line); void pushDebugLocal(StringRef name, uint8_t reg, uint32_t startpc, uint32_t endpc); void pushDebugUpval(StringRef name); @@ -162,6 +163,7 @@ class BytecodeBuilder bool isvararg = false; unsigned int debugname = 0; + int debuglinedefined = 0; std::string dump; std::string dumpname; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 3280c8a40..2d31c409c 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Write, false) + namespace Luau { @@ -81,6 +83,52 @@ static int getOpLength(LuauOpcode op) } } +inline bool isJumpD(LuauOpcode op) +{ + switch (op) + { + case LOP_JUMP: + case LOP_JUMPIF: + case LOP_JUMPIFNOT: + case LOP_JUMPIFEQ: + case LOP_JUMPIFLE: + case LOP_JUMPIFLT: + case LOP_JUMPIFNOTEQ: + case LOP_JUMPIFNOTLE: + case LOP_JUMPIFNOTLT: + case LOP_FORNPREP: + case LOP_FORNLOOP: + case LOP_FORGLOOP: + case LOP_FORGPREP_INEXT: + case LOP_FORGLOOP_INEXT: + case LOP_FORGPREP_NEXT: + case LOP_FORGLOOP_NEXT: + case LOP_JUMPBACK: + case LOP_JUMPIFEQK: + case LOP_JUMPIFNOTEQK: + return true; + + default: + return false; + } +} + +inline bool isSkipC(LuauOpcode op) +{ + switch (op) + { + case LOP_LOADB: + case LOP_FASTCALL: + case LOP_FASTCALL1: + case LOP_FASTCALL2: + case LOP_FASTCALL2K: + return true; + + default: + return false; + } +} + bool BytecodeBuilder::StringRef::operator==(const StringRef& other) const { return (data && other.data) ? (length == other.length && memcmp(data, other.data, length) == 0) : (data == other.data); @@ -365,13 +413,7 @@ bool BytecodeBuilder::patchJumpD(size_t jumpLabel, size_t targetLabel) unsigned int jumpInsn = insns[jumpLabel]; (void)jumpInsn; - LUAU_ASSERT(LUAU_INSN_OP(jumpInsn) == LOP_JUMP || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIF || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOT || - LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFEQ || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFLE || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFLT || - LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTEQ || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTLE || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTLT || - LUAU_INSN_OP(jumpInsn) == LOP_FORNPREP || LUAU_INSN_OP(jumpInsn) == LOP_FORNLOOP || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP || - LUAU_INSN_OP(jumpInsn) == LOP_FORGPREP_INEXT || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP_INEXT || - LUAU_INSN_OP(jumpInsn) == LOP_FORGPREP_NEXT || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP_NEXT || - LUAU_INSN_OP(jumpInsn) == LOP_JUMPBACK || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFEQK || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTEQK); + LUAU_ASSERT(isJumpD(LuauOpcode(LUAU_INSN_OP(jumpInsn)))); LUAU_ASSERT(LUAU_INSN_D(jumpInsn) == 0); LUAU_ASSERT(targetLabel <= insns.size()); @@ -403,8 +445,7 @@ bool BytecodeBuilder::patchSkipC(size_t jumpLabel, size_t targetLabel) unsigned int jumpInsn = insns[jumpLabel]; (void)jumpInsn; - LUAU_ASSERT(LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL || LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL1 || LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL2 || - LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL2K); + LUAU_ASSERT(isSkipC(LuauOpcode(LUAU_INSN_OP(jumpInsn)))); LUAU_ASSERT(LUAU_INSN_C(jumpInsn) == 0); int offset = int(targetLabel) - int(jumpLabel) - 1; @@ -428,6 +469,11 @@ void BytecodeBuilder::setDebugFunctionName(StringRef name) functions[currentFunction].dumpname = std::string(name.data, name.length); } +void BytecodeBuilder::setDebugFunctionLineDefined(int line) +{ + functions[currentFunction].debuglinedefined = line; +} + void BytecodeBuilder::setDebugLine(int line) { debugLine = line; @@ -464,7 +510,7 @@ uint32_t BytecodeBuilder::getDebugPC() const void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); - bytecode = char(LBC_VERSION); + bytecode = char(FFlag::LuauBytecodeV2Write ? LBC_VERSION_FUTURE : LBC_VERSION); writeStringTable(bytecode); @@ -565,6 +611,9 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const writeVarInt(ss, child); // debug info + if (FFlag::LuauBytecodeV2Write) + writeVarInt(ss, func.debuglinedefined); + writeVarInt(ss, func.debugname); bool hasLines = true; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 8f74ffedd..6ae490273 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) -LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) namespace Luau { @@ -179,6 +178,8 @@ struct Compiler if (options.optimizationLevel >= 1 && options.debugLevel >= 2) gatherConstUpvals(func); + bytecode.setDebugFunctionLineDefined(func->location.begin.line + 1); + if (options.debugLevel >= 1 && func->debugname.value) bytecode.setDebugFunctionName(sref(func->debugname)); @@ -3626,9 +3627,9 @@ struct Compiler return LBF_BIT32_RROTATE; if (builtin.method == "rshift") return LBF_BIT32_RSHIFT; - if (builtin.method == "countlz" && FFlag::LuauBit32CountBuiltin) + if (builtin.method == "countlz") return LBF_BIT32_COUNTLZ; - if (builtin.method == "countrz" && FFlag::LuauBit32CountBuiltin) + if (builtin.method == "countrz") return LBF_BIT32_COUNTRZ; } diff --git a/Sources.cmake b/Sources.cmake index a7153eb37..5dd486aaa 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -125,6 +125,7 @@ target_sources(Luau.VM PRIVATE VM/src/linit.cpp VM/src/lmathlib.cpp VM/src/lmem.cpp + VM/src/lnumprint.cpp VM/src/lobject.cpp VM/src/loslib.cpp VM/src/lperf.cpp diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index a01a14819..7e0832e7c 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -138,10 +138,6 @@ /* }================================================================== */ -/* Default number printing format and the string length limit */ -#define LUA_NUMBER_FMT "%.14g" -#define LUAI_MAXNUMBER2STR 32 /* 16 digits, sign, point, and \0 */ - /* @@ LUAI_USER_ALIGNMENT_T is a type that requires maximum alignment. ** CHANGE it if your system requires alignments larger than double. (For diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index a65b03253..c98b95908 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,8 +14,6 @@ #include -LUAU_FASTFLAG(LuauActivateBeforeExec) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -939,21 +937,7 @@ void lua_call(lua_State* L, int nargs, int nresults) checkresults(L, nargs, nresults); func = L->top - (nargs + 1); - if (FFlag::LuauActivateBeforeExec) - { - luaD_call(L, func, nresults); - } - else - { - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); - - luaD_call(L, func, nresults); - - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } + luaD_call(L, func, nresults); adjustresults(L, nresults); return; @@ -994,21 +978,7 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) c.func = L->top - (nargs + 1); /* function to be called */ c.nresults = nresults; - if (FFlag::LuauActivateBeforeExec) - { - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); - } - else - { - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); - - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); - - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } + status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); adjustresults(L, nresults); return status; diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 7ed2a62ee..71975a520 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -7,9 +7,12 @@ #include "lstring.h" #include "lapi.h" #include "lgc.h" +#include "lnumutils.h" #include +LUAU_FASTFLAG(LuauSchubfach) + /* convert a stack index to positive */ #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -477,7 +480,17 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) switch (lua_type(L, idx)) { case LUA_TNUMBER: - lua_pushstring(L, lua_tostring(L, idx)); + if (FFlag::LuauSchubfach) + { + double n = lua_tonumber(L, idx); + char s[LUAI_MAXNUM2STR]; + char* e = luai_num2str(s, n); + lua_pushlstring(L, s, e - s); + } + else + { + lua_pushstring(L, lua_tostring(L, idx)); + } break; case LUA_TSTRING: lua_pushvalue(L, idx); @@ -491,11 +504,30 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) case LUA_TVECTOR: { const float* v = lua_tovector(L, idx); + + if (FFlag::LuauSchubfach) + { + char s[LUAI_MAXNUM2STR * LUA_VECTOR_SIZE]; + char* e = s; + for (int i = 0; i < LUA_VECTOR_SIZE; ++i) + { + if (i != 0) + { + *e++ = ','; + *e++ = ' '; + } + e = luai_num2str(e, v[i]); + } + lua_pushlstring(L, s, e - s); + } + else + { #if LUA_VECTOR_SIZE == 4 - lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); + lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); #else - lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); + lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); #endif + } break; } default: diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 8b511edf0..093400f2e 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -5,8 +5,6 @@ #include "lcommon.h" #include "lnumutils.h" -LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) - #define ALLONES ~0u #define NBITS int(8 * sizeof(unsigned)) @@ -182,9 +180,6 @@ static int b_replace(lua_State* L) static int b_countlz(lua_State* L) { - if (!FFlag::LuauBit32Count) - luaL_error(L, "bit32.countlz isn't enabled"); - b_uint v = luaL_checkunsigned(L, 1); b_uint r = NBITS; @@ -201,9 +196,6 @@ static int b_countlz(lua_State* L) static int b_countrz(lua_State* L) { - if (!FFlag::LuauBit32Count) - luaL_error(L, "bit32.countrz isn't enabled"); - b_uint v = luaL_checkunsigned(L, 1); b_uint r = NBITS; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 9fe1885fb..2b5382bba 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,6 +12,9 @@ #include #include +LUAU_FASTFLAG(LuauBytecodeV2Read) +LUAU_FASTFLAG(LuauBytecodeV2Force) + static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -89,6 +92,16 @@ const char* lua_setlocal(lua_State* L, int level, int n) return name; } +static int getlinedefined(Proto* p) +{ + if (FFlag::LuauBytecodeV2Force) + return p->linedefined; + else if (FFlag::LuauBytecodeV2Read && p->linedefined >= 0) + return p->linedefined; + else + return luaG_getline(p, 0); +} + static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, CallInfo* ci) { int status = 1; @@ -108,7 +121,7 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, { ar->source = getstr(f->l.p->source); ar->what = "Lua"; - ar->linedefined = luaG_getline(f->l.p, 0); + ar->linedefined = getlinedefined(f->l.p); } luaO_chunkid(ar->short_src, ar->source, LUA_IDSIZE); break; @@ -121,7 +134,7 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, } else { - ar->currentline = f->isC ? -1 : luaG_getline(f->l.p, 0); + ar->currentline = f->isC ? -1 : getlinedefined(f->l.p); } break; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 62bbdb7c9..eb47971a8 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -19,7 +19,6 @@ LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAG(LuauCoroutineClose) -LUAU_FASTFLAGVARIABLE(LuauActivateBeforeExec, false) /* ** {====================================================== @@ -228,21 +227,14 @@ void luaD_call(lua_State* L, StkId func, int nResults) { /* is a Lua function? */ L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */ - if (FFlag::LuauActivateBeforeExec) - { - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); - luau_execute(L); /* call it */ + luau_execute(L); /* call it */ - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } - else - { - luau_execute(L); /* call it */ - } + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); } L->nCcalls--; luaC_checkGC(L); @@ -549,12 +541,9 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e status = LUA_ERRERR; } - if (FFlag::LuauActivateBeforeExec) - { - // since the call failed with an error, we might have to reset the 'active' thread state - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } + // since the call failed with an error, we might have to reset the 'active' thread state + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); if (FFlag::LuauCcallRestoreFix) { diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 7393fc74f..76ef7a06b 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -12,8 +12,6 @@ #include -LUAU_FASTFLAG(LuauArrayBoundary) - #define GC_SWEEPMAX 40 #define GC_SWEEPCOST 10 diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index f6f7a878f..c66de9c1d 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauArrayBoundary) - static void validateobjref(global_State* g, GCObject* f, GCObject* t) { LUAU_ASSERT(!isdead(g, t)); @@ -38,10 +36,7 @@ static void validatetable(global_State* g, Table* h) { int sizenode = 1 << h->lsizenode; - if (FFlag::LuauArrayBoundary) - LUAU_ASSERT(h->lastfree <= sizenode); - else - LUAU_ASSERT(h->lastfree >= 0 && h->lastfree <= sizenode); + LUAU_ASSERT(h->lastfree <= sizenode); if (h->metatable) validateobjref(g, obj2gco(h), obj2gco(h->metatable)); diff --git a/VM/src/lnumprint.cpp b/VM/src/lnumprint.cpp new file mode 100644 index 000000000..2fd0f1bbd --- /dev/null +++ b/VM/src/lnumprint.cpp @@ -0,0 +1,375 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "luaconf.h" +#include "lnumutils.h" + +#include "lcommon.h" + +#include +#include // TODO: Remove with LuauSchubfach + +#ifdef _MSC_VER +#include +#endif + +// This work is based on: +// Raffaello Giulietti. The Schubfach way to render doubles. 2021 +// https://drive.google.com/file/d/1IEeATSVnEE6TkrHlCYNY2GjaraBjOT4f/edit + +// The code uses the notation from the paper for local variables where appropriate, and refers to paper sections/figures/results. + +LUAU_FASTFLAGVARIABLE(LuauSchubfach, false) + +// 9.8.2. Precomputed table for 128-bit overestimates of powers of 10 (see figure 3 for table bounds) +// To avoid storing 616 128-bit numbers directly we use a technique inspired by Dragonbox implementation and store 16 consecutive +// powers using a 128-bit baseline and a bitvector with 1-bit scale and 3-bit offset for the delta between each entry and base*5^k +static const int kPow10TableMin = -292; +static const int kPow10TableMax = 324; + +// clang-format off +static const uint64_t kPow5Table[16] = { + 0x8000000000000000, 0xa000000000000000, 0xc800000000000000, 0xfa00000000000000, 0x9c40000000000000, 0xc350000000000000, + 0xf424000000000000, 0x9896800000000000, 0xbebc200000000000, 0xee6b280000000000, 0x9502f90000000000, 0xba43b74000000000, + 0xe8d4a51000000000, 0x9184e72a00000000, 0xb5e620f480000000, 0xe35fa931a0000000, +}; +static const uint64_t kPow10Table[(kPow10TableMax - kPow10TableMin + 1 + 15) / 16][3] = { + {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b, 0x333443443333443b}, {0x8dd01fad907ffc3b, 0xae3da7d97f6792e4, 0xbbb3ab3cb3ba3cbc}, + {0x9d71ac8fada6c9b5, 0x6f773fc3603db4aa, 0x4ba4bc4bb4bb4bcc}, {0xaecc49914078536d, 0x58fae9f773886e19, 0x3ba3bc33b43b43bb}, + {0xc21094364dfb5636, 0x985915fc12f542e5, 0x33b43b43a33b33cb}, {0xd77485cb25823ac7, 0x7d633293366b828c, 0x34b44c444343443c}, + {0xef340a98172aace4, 0x86fb897116c87c35, 0x333343333343334b}, {0x84c8d4dfd2c63f3b, 0x29ecd9f40041e074, 0xccaccbbcbcbb4bbc}, + {0x936b9fcebb25c995, 0xcab10dd900beec35, 0x3ab3ab3ab3bb3bbb}, {0xa3ab66580d5fdaf5, 0xc13e60d0d2e0ebbb, 0x4cc3dc4db4db4dbb}, + {0xb5b5ada8aaff80b8, 0x0d819992132456bb, 0x33b33a34c33b34ab}, {0xc9bcff6034c13052, 0xfc89b393dd02f0b6, 0x33c33b44b43c34bc}, + {0xdff9772470297ebd, 0x59787e2b93bc56f8, 0x43b444444443434c}, {0xf8a95fcf88747d94, 0x75a44c6397ce912b, 0x443334343443343b}, + {0x8a08f0f8bf0f156b, 0x1b8e9ecb641b5900, 0xbbabab3aa3ab4ccc}, {0x993fe2c6d07b7fab, 0xe546a8038efe402a, 0x4cb4bc4db4db4bcc}, + {0xaa242499697392d2, 0xdde50bd1d5d0b9ea, 0x3ba3ba3bb33b33bc}, {0xbce5086492111aea, 0x88f4bb1ca6bcf585, 0x44b44c44c44c43cb}, + {0xd1b71758e219652b, 0xd3c36113404ea4a9, 0x44c44c44c444443b}, {0xe8d4a51000000000, 0x0000000000000000, 0x444444444444444c}, + {0x813f3978f8940984, 0x4000000000000000, 0xcccccccccccccccc}, {0x8f7e32ce7bea5c6f, 0xe4820023a2000000, 0xbba3bc4cc4cc4ccc}, + {0x9f4f2726179a2245, 0x01d762422c946591, 0x4aa3bb3aa3ba3bab}, {0xb0de65388cc8ada8, 0x3b25a55f43294bcc, 0x3ca33b33b44b43bc}, + {0xc45d1df942711d9a, 0x3ba5d0bd324f8395, 0x44c44c34c44b44cb}, {0xda01ee641a708de9, 0xe80e6f4820cc9496, 0x33b33b343333333c}, + {0xf209787bb47d6b84, 0xc0678c5dbd23a49b, 0x443444444443443b}, {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3, 0xdbccbcccb4cb3bbb}, + {0x952ab45cfa97a0b2, 0xdd945a747bf26184, 0x3bc4bb4ab3ca3cbc}, {0xa59bc234db398c25, 0x43fab9837e699096, 0x3bb3ac3ab3bb33ac}, + {0xb7dcbf5354e9bece, 0x0c11ed6d538aeb30, 0x33b43b43b34c34dc}, {0xcc20ce9bd35c78a5, 0x31ec038df7b441f5, 0x34c44c43c44b44cb}, + {0xe2a0b5dc971f303a, 0x2e44ae64840fd61e, 0x333333333333333c}, {0xfb9b7cd9a4a7443c, 0x169840ef017da3b2, 0x433344443333344c}, + {0x8bab8eefb6409c1a, 0x1ad089b6c2f7548f, 0xdcbdcc3cc4cc4bcb}, {0x9b10a4e5e9913128, 0xca7cf2b4191c8327, 0x3ab3cb3bc3bb4bbb}, + {0xac2820d9623bf429, 0x546345fa9fbdcd45, 0x3bb3cc43c43c43cb}, {0xbf21e44003acdd2c, 0xe0470a63e6bd56c4, 0x44b34a43b44c44bc}, + {0xd433179d9c8cb841, 0x5fa60692a46151ec, 0x43a33a33a333333c}, +}; +// clang-format on + +static const char kDigitTable[] = "0001020304050607080910111213141516171819202122232425262728293031323334353637383940414243444546474849" + "5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899"; + +// x*y => 128-bit product (lo+hi) +inline uint64_t mul128(uint64_t x, uint64_t y, uint64_t* hi) +{ +#if defined(_MSC_VER) && defined(_M_X64) + return _umul128(x, y, hi); +#elif defined(__SIZEOF_INT128__) + unsigned __int128 r = x; + r *= y; + *hi = uint64_t(r >> 64); + return uint64_t(r); +#else + uint32_t x0 = uint32_t(x), x1 = uint32_t(x >> 32); + uint32_t y0 = uint32_t(y), y1 = uint32_t(y >> 32); + uint64_t p11 = uint64_t(x1) * y1, p01 = uint64_t(x0) * y1; + uint64_t p10 = uint64_t(x1) * y0, p00 = uint64_t(x0) * y0; + uint64_t mid = p10 + (p00 >> 32) + uint32_t(p01); + uint64_t r0 = (mid << 32) | uint32_t(p00); + uint64_t r1 = p11 + (mid >> 32) + (p01 >> 32); + *hi = r1; + return r0; +#endif +} + +// (x*y)>>64 => 128-bit product (lo+hi) +inline uint64_t mul192hi(uint64_t xhi, uint64_t xlo, uint64_t y, uint64_t* hi) +{ + uint64_t z2; + uint64_t z1 = mul128(xhi, y, &z2); + + uint64_t z1c; + uint64_t z0 = mul128(xlo, y, &z1c); + (void)z0; + + z1 += z1c; + z2 += (z1 < z1c); + + *hi = z2; + return z1; +} + +// 9.3. Rounding to odd (+ figure 8 + result 23) +inline uint64_t roundodd(uint64_t ghi, uint64_t glo, uint64_t cp) +{ + uint64_t xhi; + uint64_t xlo = mul128(glo, cp, &xhi); + (void)xlo; + + uint64_t yhi; + uint64_t ylo = mul128(ghi, cp, &yhi); + + uint64_t z = ylo + xhi; + return (yhi + (z < xhi)) | (z > 1); +} + +struct Decimal +{ + uint64_t s; + int k; +}; + +static Decimal schubfach(int exponent, uint64_t fraction) +{ + // Extract c & q such that c*2^q == |v| + uint64_t c = fraction; + int q = exponent - 1023 - 51; + + if (exponent != 0) // normal numbers have implicit leading 1 + { + c |= (1ull << 52); + q--; + } + + // 8.3. Fast path for integers + if (unsigned(-q) < 53 && (c & ((1ull << (-q)) - 1)) == 0) + return {c >> (-q), 0}; + + // 5. Rounding interval + int irr = (c == (1ull << 52) && q != -1074); // Qmin + int out = int(c & 1); + + // 9.8.1. Boundaries for c + uint64_t cbl = 4 * c - 2 + irr; + uint64_t cb = 4 * c; + uint64_t cbr = 4 * c + 2; + + // 9.1. Computing k and h + const int Q = 20; + const int C = 315652; // floor(2^Q * log10(2)) + const int A = -131008; // floor(2^Q * log10(3/4)) + const int C2 = 3483294; // floor(2^Q * log2(10)) + int k = (q * C + (irr ? A : 0)) >> Q; + int h = q + ((-k * C2) >> Q) + 1; // see (9) in 9.9 + + // 9.8.2. Overestimates of powers of 10 + // Recover 10^-k fraction using compact tables generated by tools/numutils.py + // The 128-bit fraction is encoded as 128-bit baseline * power-of-5 * scale + offset + LUAU_ASSERT(-k >= kPow10TableMin && -k <= kPow10TableMax); + int gtoff = -k - kPow10TableMin; + const uint64_t* gt = kPow10Table[gtoff >> 4]; + + uint64_t ghi; + uint64_t glo = mul192hi(gt[0], gt[1], kPow5Table[gtoff & 15], &ghi); + + // Apply 1-bit scale + 3-bit offset; note, offset is intentionally applied without carry, numutils.py validates that this is sufficient + int gterr = (gt[2] >> ((gtoff & 15) * 4)) & 15; + int gtscale = gterr >> 3; + + ghi <<= gtscale; + ghi += (glo >> 63) & gtscale; + glo <<= gtscale; + glo -= (gterr & 7) - 4; + + // 9.9. Boundaries for v + uint64_t vbl = roundodd(ghi, glo, cbl << h); + uint64_t vb = roundodd(ghi, glo, cb << h); + uint64_t vbr = roundodd(ghi, glo, cbr << h); + + // Main algorithm; see figure 7 + figure 9 + uint64_t s = vb / 4; + + if (s >= 10) + { + uint64_t sp = s / 10; + + bool upin = vbl + out <= 40 * sp; + bool wpin = vbr >= 40 * sp + 40 + out; + + if (upin != wpin) + return {sp + wpin, k + 1}; + } + + // Figure 7 contains the algorithm to select between u (s) and w (s+1) + // rup computes the last 4 conditions in that algorithm + // rup is only used when uin == win, but since these branches predict poorly we use branchless selects + bool uin = vbl + out <= 4 * s; + bool win = 4 * s + 4 + out <= vbr; + bool rup = vb >= 4 * s + 2 + 1 - (s & 1); + + return {s + (uin != win ? win : rup), k}; +} + +static char* printspecial(char* buf, int sign, uint64_t fraction) +{ + if (fraction == 0) + { + memcpy(buf, ("-inf") + (1 - sign), 4); + return buf + 3 + sign; + } + else + { + memcpy(buf, "nan", 4); + return buf + 3; + } +} + +static char* printunsignedrev(char* end, uint64_t num) +{ + while (num >= 10000) + { + unsigned int tail = unsigned(num % 10000); + + memcpy(end - 4, &kDigitTable[int(tail / 100) * 2], 2); + memcpy(end - 2, &kDigitTable[int(tail % 100) * 2], 2); + num /= 10000; + end -= 4; + } + + unsigned int rest = unsigned(num); + + while (rest >= 10) + { + memcpy(end - 2, &kDigitTable[int(rest % 100) * 2], 2); + rest /= 100; + end -= 2; + } + + if (rest) + { + end[-1] = '0' + int(rest); + end -= 1; + } + + return end; +} + +static char* printexp(char* buf, int num) +{ + *buf++ = 'e'; + *buf++ = num < 0 ? '-' : '+'; + + int v = num < 0 ? -num : num; + + if (v >= 100) + { + *buf++ = '0' + (v / 100); + v %= 100; + } + + memcpy(buf, &kDigitTable[v * 2], 2); + return buf + 2; +} + +inline char* trimzero(char* end) +{ + while (end[-1] == '0') + end--; + + return end; +} + +// We use fixed-length memcpy/memset since they lower to fast SIMD+scalar writes; the target buffers should have padding space +#define fastmemcpy(dst, src, size, sizefast) check_exp((size) <= sizefast, memcpy(dst, src, sizefast)) +#define fastmemset(dst, val, size, sizefast) check_exp((size) <= sizefast, memset(dst, val, sizefast)) + +char* luai_num2str(char* buf, double n) +{ + if (!FFlag::LuauSchubfach) + { + snprintf(buf, LUAI_MAXNUM2STR, LUA_NUMBER_FMT, n); + return buf + strlen(buf); + } + + // IEEE-754 + union + { + double v; + uint64_t bits; + } v = {n}; + int sign = int(v.bits >> 63); + int exponent = int(v.bits >> 52) & 2047; + uint64_t fraction = v.bits & ((1ull << 52) - 1); + + // specials + if (LUAU_UNLIKELY(exponent == 0x7ff)) + return printspecial(buf, sign, fraction); + + // sign bit + *buf = '-'; + buf += sign; + + // zero + if (exponent == 0 && fraction == 0) + { + buf[0] = '0'; + return buf + 1; + } + + // convert binary to decimal using Schubfach + Decimal d = schubfach(exponent, fraction); + LUAU_ASSERT(d.s < uint64_t(1e17)); + + // print the decimal to a temporary buffer; we'll need to insert the decimal point and figure out the format + char decbuf[40]; + char* decend = decbuf + 20; // significand needs at most 17 digits; the rest of the buffer may be copied using fixed length memcpy + char* dec = printunsignedrev(decend, d.s); + + int declen = int(decend - dec); + LUAU_ASSERT(declen <= 17); + + int dot = declen + d.k; + + // the limits are somewhat arbitrary but changing them may require changing fastmemset/fastmemcpy sizes below + if (dot >= -5 && dot <= 21) + { + // fixed point format + if (dot <= 0) + { + buf[0] = '0'; + buf[1] = '.'; + + fastmemset(buf + 2, '0', -dot, 5); + fastmemcpy(buf + 2 + (-dot), dec, declen, 17); + + return trimzero(buf + 2 + (-dot) + declen); + } + else if (dot == declen) + { + // no dot + fastmemcpy(buf, dec, dot, 17); + + return buf + dot; + } + else if (dot < declen) + { + // dot in the middle + fastmemcpy(buf, dec, dot, 16); + + buf[dot] = '.'; + + fastmemcpy(buf + dot + 1, dec + dot, declen - dot, 16); + + return trimzero(buf + declen + 1); + } + else + { + // no dot, zero padding + fastmemcpy(buf, dec, declen, 17); + fastmemset(buf + declen, '0', dot - declen, 8); + + return buf + dot; + } + } + else + { + // scientific format + buf[0] = dec[0]; + buf[1] = '.'; + fastmemcpy(buf + 2, dec + 1, declen - 1, 16); + + char* exp = trimzero(buf + declen + 1); + + return printexp(exp, dot - 1); + } +} diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index 67f832dc5..fba07bc3b 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -3,7 +3,6 @@ #pragma once #include -#include #define luai_numadd(a, b) ((a) + (b)) #define luai_numsub(a, b) ((a) - (b)) @@ -56,5 +55,9 @@ LUAU_FASTMATH_END #define luai_num2unsigned(i, n) ((i) = (unsigned)(long long)(n)) #endif -#define luai_num2str(s, n) snprintf((s), sizeof(s), LUA_NUMBER_FMT, (n)) +#define LUA_NUMBER_FMT "%.14g" /* TODO: Remove with LuauSchubfach */ +#define LUAI_MAXNUM2STR 48 + +LUAI_FUNC char* luai_num2str(char* buf, double n); + #define luai_str2num(s, p) strtod((s), (p)) diff --git a/VM/src/lobject.h b/VM/src/lobject.h index fd0a15b75..b642cf787 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -289,6 +289,7 @@ typedef struct Proto int sizek; int sizelineinfo; int linegaplog2; + int linedefined; uint8_t nups; /* number of upvalues */ diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 0b55fceac..83b59f3fa 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -24,8 +24,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false) - // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -222,7 +220,7 @@ int luaH_next(lua_State* L, Table* t, StkId key) #define maybesetaboundary(t, boundary) \ { \ - if (FFlag::LuauArrayBoundary && t->aboundary <= 0) \ + if (t->aboundary <= 0) \ t->aboundary = -int(boundary); \ } @@ -705,7 +703,7 @@ int luaH_getn(Table* t) { int boundary = getaboundary(t); - if (FFlag::LuauArrayBoundary && boundary > 0) + if (boundary > 0) { if (!ttisnil(&t->array[t->sizearray - 1]) && t->node == dummynode) return t->sizearray; /* fast-path: the end of the array in `t' already refers to a boundary */ diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index add3588d4..cdb276c06 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,6 +13,9 @@ #include +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, false) +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) + // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template struct TempBuffer @@ -146,15 +149,19 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size uint8_t version = read(data, size, offset); // 0 means the rest of the bytecode is the error message - if (version == 0 || version != LBC_VERSION) + if (version == 0) { char chunkid[LUA_IDSIZE]; luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + lua_pushfstring(L, "%s%.*s", chunkid, int(size - offset), data + offset); + return 1; + } - if (version == 0) - lua_pushfstring(L, "%s%.*s", chunkid, int(size - offset), data + offset); - else - lua_pushfstring(L, "%s: bytecode version mismatch", chunkid); + if (FFlag::LuauBytecodeV2Force ? (version != LBC_VERSION_FUTURE) : FFlag::LuauBytecodeV2Read ? (version != LBC_VERSION && version != LBC_VERSION_FUTURE) : (version != LBC_VERSION)) + { + char chunkid[LUA_IDSIZE]; + luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, FFlag::LuauBytecodeV2Force ? LBC_VERSION_FUTURE : LBC_VERSION, version); return 1; } @@ -285,6 +292,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->p[j] = protos[fid]; } + if (FFlag::LuauBytecodeV2Force || (FFlag::LuauBytecodeV2Read && version == LBC_VERSION_FUTURE)) + p->linedefined = readVarInt(data, size, offset); + else + p->linedefined = -1; + p->debugname = readString(strings, data, size, offset); uint8_t lineinfo = read(data, size, offset); @@ -307,11 +319,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->lineinfo[j] = lastoffset; } - int lastLine = 0; + int lastline = 0; for (int j = 0; j < intervals; ++j) { - lastLine += read(data, size, offset); - p->abslineinfo[j] = lastLine; + lastline += read(data, size, offset); + p->abslineinfo[j] = lastline; } } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 5d802277a..31dd59c86 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -34,10 +34,11 @@ int luaV_tostring(lua_State* L, StkId obj) return 0; else { - char s[LUAI_MAXNUMBER2STR]; + char s[LUAI_MAXNUM2STR]; double n = nvalue(obj); - luai_num2str(s, n); - setsvalue2s(L, obj, luaS_new(L, s)); + char* e = luai_num2str(s, n); + LUAU_ASSERT(e < s + sizeof(s)); + setsvalue2s(L, obj, luaS_newlstr(L, s, e - s)); return 1; } } diff --git a/fuzz/number.cpp b/fuzz/number.cpp new file mode 100644 index 000000000..704474096 --- /dev/null +++ b/fuzz/number.cpp @@ -0,0 +1,35 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" + +#include +#include +#include +#include + +LUAU_FASTFLAG(LuauSchubfach); + +#define LUAI_MAXNUM2STR 48 + +char* luai_num2str(char* buf, double n); + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) +{ + if (Size < 8) + return 0; + + FFlag::LuauSchubfach.value = true; + + double num; + memcpy(&num, Data, 8); + + char buf[LUAI_MAXNUM2STR]; + char* end = luai_num2str(buf, num); + LUAU_ASSERT(end < buf + sizeof(buf)); + + *end = 0; + + double rec = strtod(buf, nullptr); + + LUAU_ASSERT(rec == num || (rec != rec && num != num)); + return 0; +} diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 2090b0148..41b553b55 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -83,8 +83,6 @@ TEST_SUITE_BEGIN("AstQuery"); TEST_CASE_FIXTURE(Fixture, "last_argument_function_call_type") { - ScopedFastFlag luauTailArgumentTypeInfo{"LuauTailArgumentTypeInfo", true}; - check(R"( local function foo() return 2 end local function bar(a: number) return -a end diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 8ca09c0e0..210db7eea 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) +LUAU_FASTFLAG(LuauUseCommittingTxnLog) using namespace Luau; @@ -1911,11 +1912,14 @@ local bar: @1= foo CHECK(!ac.entryMap.count("foo")); } -TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") +// Switch back to TEST_CASE_FIXTURE with regular ACFixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("type_correct_function_no_parenthesis") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag sff_LuauUseCommittingTxnLog = ScopedFastFlag("LuauUseCommittingTxnLog", true); + ACFixture fix; - check(R"( + fix.check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end local function bar2(a: string) return a .. 'x' end @@ -1923,7 +1927,7 @@ local function bar2(a: string) return a .. 'x' end return target(b@1 )"); - auto ac = autocomplete('1'); + auto ac = fix.autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); @@ -1976,11 +1980,14 @@ local fp: @1= f CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") +// Switch back to TEST_CASE_FIXTURE with regular ACFixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("type_correct_keywords") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag sff_LuauUseCommittingTxnLog = ScopedFastFlag("LuauUseCommittingTxnLog", true); + ACFixture fix; - check(R"( + fix.check(R"( local function a(x: boolean) end local function b(x: number?) end local function c(x: (number) -> string) end @@ -1997,26 +2004,26 @@ local dc = d(f@4) local ec = e(f@5) )"); - auto ac = autocomplete('1'); + auto ac = fix.autocomplete('1'); CHECK(ac.entryMap.count("tru")); CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('2'); + ac = fix.autocomplete('2'); CHECK(ac.entryMap.count("ni")); CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('3'); + ac = fix.autocomplete('3'); CHECK(ac.entryMap.count("false")); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('4'); + ac = fix.autocomplete('4'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('5'); + ac = fix.autocomplete('5'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); } @@ -2507,21 +2514,23 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(UnfrozenFixture, "autocomplete_documentation_symbols") +TEST_CASE("autocomplete_documentation_symbols") { - loadDefinition(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + fix.loadDefinition(R"( declare y: { x: number, } )"); - fileResolver.source["Module/A"] = R"( + fix.fileResolver.source["Module/A"] = R"( local a = y. )"; - frontend.check("Module/A"); + fix.frontend.check("Module/A"); - auto ac = autocomplete(frontend, "Module/A", Position{1, 21}, nullCallback); + auto ac = autocomplete(fix.frontend, "Module/A", Position{1, 21}, nullCallback); REQUIRE(ac.entryMap.count("x")); CHECK_EQ(ac.entryMap["x"].documentationSymbol, "@test/global/y.x"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index b055a38e4..663b329ee 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -13,8 +13,11 @@ #include "ScopedFlags.h" #include +#include #include +extern bool verbose; + static int lua_collectgarbage(lua_State* L) { static const char* const opts[] = {"stop", "restart", "collect", "count", "isrunning", "step", "setgoal", "setstepmul", "setstepsize", nullptr}; @@ -146,15 +149,21 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n luaL_openlibs(L); // Register a few global functions for conformance tests - static const luaL_Reg funcs[] = { + std::vector funcs = { {"collectgarbage", lua_collectgarbage}, {"loadstring", lua_loadstring}, - {"print", lua_silence}, // Disable print() by default; comment this out to enable debug prints in tests - {nullptr, nullptr}, }; + if (!verbose) + { + funcs.push_back({"print", lua_silence}); + } + + // "null" terminate the list of functions to register + funcs.push_back({nullptr, nullptr}); + lua_pushvalue(L, LUA_GLOBALSINDEX); - luaL_register(L, nullptr, funcs); + luaL_register(L, nullptr, funcs.data()); lua_pop(L, 1); // In some configurations we have a larger C stack consumption which trips some conformance tests @@ -312,8 +321,6 @@ TEST_CASE("GC") TEST_CASE("Bitwise") { - ScopedFastFlag sff("LuauBit32Count", true); - runConformance("bitwise.lua"); } @@ -491,6 +498,9 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { + ScopedFastFlag sffr("LuauBytecodeV2Read", true); + ScopedFastFlag sffw("LuauBytecodeV2Write", true); + runConformance("debug.lua"); } @@ -738,8 +748,6 @@ TEST_CASE("ApiFunctionCalls") // lua_equal with a sleeping thread wake up { - ScopedFastFlag luauActivateBeforeExec("LuauActivateBeforeExec", true); - lua_State* L2 = lua_newthread(L); lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); @@ -913,4 +921,11 @@ TEST_CASE("Coverage") nullptr, nullptr, &copts); } +TEST_CASE("StringConversion") +{ + ScopedFastFlag sff{"LuauSchubfach", true}; + + runConformance("strconv.lua"); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 36d6f5612..ca4281a0b 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -103,11 +103,6 @@ Fixture::~Fixture() Luau::resetPrintLine(); } -UnfrozenFixture::UnfrozenFixture() - : Fixture(false) -{ -} - AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& parseOptions) { sourceModule.reset(new SourceModule); diff --git a/tests/Fixture.h b/tests/Fixture.h index de2b7381e..e01632eab 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -152,15 +152,6 @@ struct Fixture LoadDefinitionFileResult loadDefinition(const std::string& source); }; -// Disables arena freezing for a given test case. -// Do not use this in new tests. If you are running into access violations, you -// are violating Luau's memory model - the fix is not to use UnfrozenFixture. -// Related: CLI-45692 -struct UnfrozenFixture : Fixture -{ - UnfrozenFixture(); -}; - ModuleName fromString(std::string_view name); template diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 51fcd3d6c..405f26e07 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -914,6 +914,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") TEST_CASE_FIXTURE(FrontendFixture, "imported_table_modification_2") { + ScopedFastFlag sffs("LuauSealExports", true); + frontend.options.retainFullTypeGraphs = false; fileResolver.source["Module/A"] = R"( @@ -927,7 +929,7 @@ return a; --!nonstrict local a = require(script.Parent.A) local b = {} -function a:b() end -- this should error, but doesn't +function a:b() end -- this should error, since A doesn't define a:b() return b )"; @@ -942,8 +944,7 @@ a:b() -- this should error, since A doesn't define a:b() LUAU_REQUIRE_NO_ERRORS(resultA); CheckResult resultB = frontend.check("Module/B"); - // TODO (CLI-45592): this should error, since we shouldn't be adding properties to objects from other modules - LUAU_REQUIRE_NO_ERRORS(resultB); + LUAU_REQUIRE_ERRORS(resultB); CheckResult resultC = frontend.check("Module/C"); LUAU_REQUIRE_ERRORS(resultC); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 71ff4e1b8..275782b30 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -620,7 +620,7 @@ struct AssertionCatcher { tripped = 0; oldhook = Luau::assertHandler(); - Luau::assertHandler() = [](const char* expr, const char* file, int line) -> int { + Luau::assertHandler() = [](const char* expr, const char* file, int line, const char* function) -> int { ++tripped; return 0; }; diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 506279b94..5e08654a7 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -11,6 +11,8 @@ LUAU_FASTFLAG(LuauFixTonumberReturnType) using namespace Luau; +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + TEST_SUITE_BEGIN("BuiltinTests"); TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") @@ -444,19 +446,28 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") CHECK_EQ(*typeChecker.numberType, *requireType("n3")); } -TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("thread_is_a_type") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( local co = coroutine.create(function() end) )"); - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.threadType, *requireType("co")); + // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() == 0); + CHECK_EQ(*fix.typeChecker.threadType, *fix.requireType("co")); } -TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("coroutine_resume_anything_goes") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( local function nifty(x, y) print(x, y) local z = coroutine.yield(1, 2) @@ -469,12 +480,17 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") local answer = coroutine.resume(co, 3) )"); - LUAU_REQUIRE_NO_ERRORS(result); + // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() == 0); } -TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("coroutine_wrap_anything_goes") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( --!nonstrict local function nifty(x, y) print(x, y) @@ -488,7 +504,8 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") local answer = f(3) )"); - LUAU_REQUIRE_NO_ERRORS(result); + // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() == 0); } TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index b62044fa9..114679e34 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -629,8 +629,6 @@ return exports TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names") { - ScopedFastFlag luauFunctionArgumentNameSize{"LuauFunctionArgumentNameSize", true}; - CheckResult result = check(R"( local function f(a: T, ...: U...) end diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 76324556a..f70f3b1c8 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4803,8 +4803,6 @@ local bar = foo.nutrition + 100 TEST_CASE_FIXTURE(Fixture, "require_failed_module") { - ScopedFastFlag luauModuleRequireErrorPack{"LuauModuleRequireErrorPack", true}; - fileResolver.source["game/A"] = R"( return unfortunately() )"; diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index f55b46a40..1e790eba9 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -12,6 +12,8 @@ LUAU_FASTFLAG(LuauQuantifyInPlace2); using namespace Luau; +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + struct TryUnifyFixture : Fixture { TypeArena arena; @@ -28,7 +30,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TypeVar numberOne{TypeVariant{PrimitiveTypeVar{PrimitiveTypeVar::Number}}}; TypeVar numberTwo = numberOne; - state.tryUnify(&numberOne, &numberTwo); + state.tryUnify(&numberTwo, &numberOne); CHECK(state.errors.empty()); } @@ -41,9 +43,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TypeVar functionTwo{TypeVariant{ FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; - state.tryUnify(&functionOne, &functionTwo); + state.tryUnify(&functionTwo, &functionOne); CHECK(state.errors.empty()); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + CHECK_EQ(functionOne, functionTwo); } @@ -61,7 +66,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TypeVar functionTwoSaved = functionTwo; - state.tryUnify(&functionOne, &functionTwo); + state.tryUnify(&functionTwo, &functionOne); CHECK(!state.errors.empty()); CHECK_EQ(functionOne, functionOneSaved); @@ -80,10 +85,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); - state.tryUnify(&tableOne, &tableTwo); + state.tryUnify(&tableTwo, &tableOne); CHECK(state.errors.empty()); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -101,11 +109,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); - state.tryUnify(&tableOne, &tableTwo); + state.tryUnify(&tableTwo, &tableOne); CHECK_EQ(1, state.errors.size()); - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -170,7 +179,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") TypePackVar testPack{TypePack{{typeChecker.numberType, typeChecker.stringType}, std::nullopt}}; TypePackVar variadicPack{VariadicTypePack{typeChecker.numberType}}; - state.tryUnify(&variadicPack, &testPack); + state.tryUnify(&testPack, &variadicPack); CHECK(!state.errors.empty()); } @@ -180,7 +189,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") TypePackVar a{TypePack{{typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType, typeChecker.booleanType}}}; TypePackVar b{TypePack{{typeChecker.numberType, typeChecker.stringType}, &variadicPack}}; - state.tryUnify(&a, &b); + state.tryUnify(&b, &a); CHECK(state.errors.empty()); } @@ -214,32 +223,41 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unifica CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } -TEST_CASE_FIXTURE(TryUnifyFixture, "undo_new_prop_on_unsealed_table") +TEST_CASE("undo_new_prop_on_unsealed_table") { ScopedFastFlag flags[] = { {"LuauTableSubtypingVariance2", true}, + // This test makes no sense with a committing TxnLog. + {"LuauUseCommittingTxnLog", false}, }; // I am not sure how to make this happen in Luau code. - TypeId unsealedTable = arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); - TypeId sealedTable = arena.addType(TableTypeVar{ - {{"prop", Property{getSingletonTypes().numberType}}}, - std::nullopt, - TypeLevel{}, - TableState::Sealed - }); + TryUnifyFixture fix; + + TypeId unsealedTable = fix.arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); + TypeId sealedTable = + fix.arena.addType(TableTypeVar{{{"prop", Property{getSingletonTypes().numberType}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); const TableTypeVar* ttv = get(unsealedTable); REQUIRE(ttv); - state.tryUnify(unsealedTable, sealedTable); + fix.state.tryUnify(sealedTable, unsealedTable); // To be honest, it's really quite spooky here that we're amending an unsealed table in this case. CHECK(!ttv->props.empty()); - state.log.rollback(); + fix.state.DEPRECATED_log.rollback(); CHECK(ttv->props.empty()); } +TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") +{ + TypePackId threeNumbers = arena.addTypePack(TypePack{{typeChecker.numberType, typeChecker.numberType, typeChecker.numberType}, std::nullopt}); + TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); + + ErrorVec unifyErrors = state.canUnify(numberAndFreeTail, threeNumbers); + CHECK(unifyErrors.size() == 0); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 3f4420cda..5d37b032a 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -10,6 +10,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") @@ -263,10 +265,13 @@ TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") CHECK_EQ(toString(requireType("foo")), "(...number) -> ()"); } -// CLI-45791 -TEST_CASE_FIXTURE(UnfrozenFixture, "type_pack_hidden_free_tail_infinite_growth") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("type_pack_hidden_free_tail_infinite_growth") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( --!nonstrict if _ then _[function(l0)end],l0 = _ @@ -278,7 +283,8 @@ elseif _ then end )"); - LUAU_REQUIRE_ERRORS(result); + // Switch back to LUAU_REQUIRE_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() > 0); } TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 2357869e9..b54ba9962 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -8,6 +8,7 @@ #include "doctest.h" LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauUseCommittingTxnLog) using namespace Luau; @@ -282,16 +283,19 @@ local c = b:foo(1, 2) CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } -TEST_CASE_FIXTURE(UnfrozenFixture, "optional_union_follow") +TEST_CASE("optional_union_follow") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( local y: number? = 2 local x = y local function f(a: number, b: typeof(x), c: typeof(x)) return -a end return f() )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE_EQ(result.errors.size(), 1); + // LUAU_REQUIRE_ERROR_COUNT(1, result); auto acm = get(result.errors[0]); REQUIRE(acm); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 13db923ed..2e0d149ec 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -185,6 +185,8 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_empty_union") TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { + ScopedFastFlag sff{"LuauSealExports", true}; + TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; TypePackVar tp24{TypePack{{&ftv11}}}; @@ -261,7 +263,7 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); - CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result)); + CHECK_EQ("{| f: t1 |} where t1 = () -> {| f: () -> {| f: ({| f: t1 |}) -> (), signal: {| f: (any) -> () |} |} |}", toString(result)); } TEST_CASE("tagging_tables") diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index 9cf3c7423..8c96ab335 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -98,4 +98,13 @@ assert(quuz(function(...) end) == "0 true") assert(quuz(function(a, b) end) == "2 false") assert(quuz(function(a, b, ...) end) == "2 true") +-- info linedefined & line +function testlinedefined() + local line = debug.info(1, "l") + local linedefined = debug.info(testlinedefined, "l") + assert(linedefined + 1 == line) +end + +testlinedefined() + return 'OK' diff --git a/tests/conformance/strconv.lua b/tests/conformance/strconv.lua new file mode 100644 index 000000000..85ad0295e --- /dev/null +++ b/tests/conformance/strconv.lua @@ -0,0 +1,51 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print("testing string-number conversion") + +-- zero +assert(tostring(0) == "0") +assert(tostring(0/-1) == "-0") + +-- specials +assert(tostring(1/0) == "inf") +assert(tostring(-1/0) == "-inf") +assert(tostring(0/0) == "nan") + +-- integers +assert(tostring(1) == "1") +assert(tostring(42) == "42") +assert(tostring(-4294967296) == "-4294967296") +assert(tostring(9007199254740991) == "9007199254740991") + +-- decimals +assert(tostring(0.5) == "0.5") +assert(tostring(0.1) == "0.1") +assert(tostring(-0.17) == "-0.17") +assert(tostring(math.pi) == "3.141592653589793") + +-- fuzzing corpus +assert(tostring(5.4536123983019448e-311) == "5.453612398302e-311") +assert(tostring(5.4834368411298348e-311) == "5.48343684113e-311") +assert(tostring(4.4154895841930002e-305) == "4.415489584193e-305") +assert(tostring(1125968630513728) == "1125968630513728") +assert(tostring(3.3951932655938423e-313) == "3.3951932656e-313") +assert(tostring(1.625) == "1.625") +assert(tostring(4.9406564584124654e-324) == "5.e-324") +assert(tostring(2.0049288280105384) == "2.0049288280105384") +assert(tostring(3.0517578125e-05) == "0.000030517578125") +assert(tostring(1.383544921875) == "1.383544921875") +assert(tostring(3.0053350932691001) == "3.0053350932691") +assert(tostring(0.0001373291015625) == "0.0001373291015625") +assert(tostring(-1.9490628022799998e+289) == "-1.94906280228e+289") +assert(tostring(-0.00610404721867928) == "-0.00610404721867928") +assert(tostring(0.00014495849609375) == "0.00014495849609375") +assert(tostring(0.453125) == "0.453125") +assert(tostring(-4.2375343999999997e+73) == "-4.2375344e+73") +assert(tostring(1.3202313930270133e-192) == "1.3202313930270133e-192") +assert(tostring(3.6984408976312836e+19) == "36984408976312840000") +assert(tostring(2.0563000527063302) == "2.05630005270633") +assert(tostring(4.8970527433648997e-260) == "4.8970527433649e-260") +assert(tostring(1.62890625) == "1.62890625") +assert(tostring(1.1295093211933533e+65) == "1.1295093211933533e+65") + +return "OK" diff --git a/tests/main.cpp b/tests/main.cpp index ed17070c6..cd24e100f 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -2,6 +2,8 @@ #include "Luau/Common.h" #define DOCTEST_CONFIG_IMPLEMENT +// Our calls to parseOption/parseFlag don't provide a prefix so set the prefix to the empty string. +#define DOCTEST_CONFIG_OPTIONS_PREFIX "" #include "doctest.h" #ifdef _WIN32 @@ -18,6 +20,10 @@ #include +// Indicates if verbose output is enabled. +// Currently, this enables output from lua's 'print', but other verbose output could be enabled eventually. +bool verbose = false; + static bool skipFastFlag(const char* flagName) { if (strncmp(flagName, "Test", 4) == 0) @@ -46,7 +52,7 @@ static bool debuggerPresent() #endif } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { if (debuggerPresent()) LUAU_DEBUGBREAK(); @@ -235,6 +241,11 @@ int main(int argc, char** argv) return 0; } + if (doctest::parseFlag(argc, argv, "--verbose")) + { + verbose = true; + } + if (std::vector flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags)) setFastFlags(flags); @@ -261,7 +272,15 @@ int main(int argc, char** argv) } } - return context.run(); + int result = context.run(); + if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h")) + { + printf("Additional command line options:\n"); + printf(" --verbose Enables verbose output (e.g. lua 'print' statements)\n"); + printf(" --fflags= Sets specified fast flags\n"); + printf(" --list-fflags List all fast flags\n"); + } + return result; } diff --git a/tools/numprint.py b/tools/numprint.py new file mode 100644 index 000000000..47ad36d9c --- /dev/null +++ b/tools/numprint.py @@ -0,0 +1,82 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# This code can be used to generate power tables for Schubfach algorithm (see lnumprint.cpp) + +import math +import sys + +(_, pow10min, pow10max, compact) = sys.argv +pow10min = int(pow10min) +pow10max = int(pow10max) +compact = compact == "True" + +# extract high 128 bits of the value +def high128(n, roundup): + L = math.ceil(math.log2(n)) + + r = 0 + for i in range(L - 128, L): + if i >= 0 and (n & (1 << i)) != 0: + r |= (1 << (i - L + 128)) + + return r + (1 if roundup else 0) + +def pow10approx(n): + if n == 0: + return 1 << 127 + elif n > 0: + return high128(10**n, 5**n >= 2**128) + else: + # 10^-n is a binary fraction that can't be represented in floating point + # we need to extract top 128 bits of the fraction starting from the first 1 + # to get there, we need to divide 2^k by 10^n for a sufficiently large k and repeat the extraction process + p = 10**-n + k = 2**128 * 16**-n # this guarantees that the fraction has more than 128 extra bits + return high128(k // p, True) + +def pow5_64(n): + assert(n >= 0) + if n == 0: + return 1 << 63 + else: + return high128(5**n, False) >> 64 + +if not compact: + print("// kPow10Table", pow10min, "..", pow10max) + print("{") + for p in range(pow10min, pow10max + 1): + h = hex(pow10approx(p))[2:] + assert(len(h) == 32) + print(" {0x%s, 0x%s}," % (h[0:16].upper(), h[16:32].upper())) + print("}") +else: + print("// kPow5Table") + print("{") + for i in range(16): + print(" " + hex(pow5_64(i)) + ",") + print("}") + print("// kPow10Table", pow10min, "..", pow10max) + print("{") + for p in range(pow10min, pow10max + 1, 16): + base = pow10approx(p) + errw = 0 + for i in range(16): + real = pow10approx(p + i) + appr = (base * pow5_64(i)) >> 64 + scale = 1 if appr < (1 << 127) else 0 # 1-bit scale + + offset = (appr << scale) - real + assert(offset >= -4 and offset <= 3) # 3-bit offset + assert((appr << scale) >> 64 == real >> 64) # offset only affects low half + assert((appr << scale) - offset == real) # validate full reconstruction + + err = (scale << 3) | (offset + 4) + errw |= err << (i * 4) + + hbase = hex(base)[2:] + assert(len(hbase) == 32) + assert(errw < 1 << 64) + + print(" {0x%s, 0x%s, 0x%16x}," % (hbase[0:16], hbase[16:32], errw)) + print("}") From d189bd9b1a3c6ecc882c793c7382a1c886635900 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 6 Jan 2022 17:37:50 -0800 Subject: [PATCH 13/14] Enable V2Read flag early --- VM/src/lvmload.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index cdb276c06..7839c68c2 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,7 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, false) +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, true) LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens From 80d5c0000ee34767f8fec7ec24c02a438174a667 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 14 Jan 2022 08:06:31 -0800 Subject: [PATCH 14/14] Sync to upstream/release/510 --- Analysis/include/Luau/TypeInfer.h | 12 +- Analysis/include/Luau/TypeVar.h | 20 +- Analysis/src/Autocomplete.cpp | 79 +-- Analysis/src/Error.cpp | 12 +- Analysis/src/IostreamHelpers.cpp | 8 +- Analysis/src/JsonEncoder.cpp | 38 ++ Analysis/src/Module.cpp | 26 +- Analysis/src/ToString.cpp | 107 +++- Analysis/src/Transpiler.cpp | 65 +- Analysis/src/TypeAttach.cpp | 12 +- Analysis/src/TypeInfer.cpp | 310 ++++++++-- Analysis/src/TypeVar.cpp | 4 + Analysis/src/Unifier.cpp | 17 +- Ast/include/Luau/Ast.h | 49 +- Ast/include/Luau/Parser.h | 6 +- Ast/src/Ast.cpp | 32 +- Ast/src/Parser.cpp | 128 ++-- CLI/Analyze.cpp | 18 +- CLI/Repl.cpp | 131 ++-- Compiler/src/Builtins.cpp | 197 ++++++ Compiler/src/Builtins.h | 41 ++ Compiler/src/BytecodeBuilder.cpp | 6 +- Compiler/src/Compiler.cpp | 894 ++------------------------- Compiler/src/ConstantFolding.cpp | 394 ++++++++++++ Compiler/src/ConstantFolding.h | 48 ++ Compiler/src/TableShape.cpp | 129 ++++ Compiler/src/TableShape.h | 21 + Compiler/src/ValueTracking.cpp | 103 +++ Compiler/src/ValueTracking.h | 42 ++ Sources.cmake | 8 + VM/src/ldo.cpp | 12 +- VM/src/lfunc.cpp | 8 +- VM/src/lstrlib.cpp | 20 +- bench/tests/sunspider/3d-cube.lua | 32 +- tests/Autocomplete.test.cpp | 61 +- tests/Compiler.test.cpp | 93 +-- tests/Conformance.test.cpp | 8 - tests/Fixture.cpp | 5 +- tests/Fixture.h | 2 +- tests/Parser.test.cpp | 53 +- tests/ToString.test.cpp | 22 +- tests/Transpiler.test.cpp | 14 +- tests/TypeInfer.annotations.test.cpp | 2 +- tests/TypeInfer.refinements.test.cpp | 28 +- tests/TypeInfer.singletons.test.cpp | 26 +- tests/TypeInfer.tables.test.cpp | 101 ++- tests/TypeInfer.test.cpp | 150 ++++- tests/TypeInfer.typePacks.cpp | 324 ++++++++++ 48 files changed, 2649 insertions(+), 1269 deletions(-) create mode 100644 Compiler/src/Builtins.cpp create mode 100644 Compiler/src/Builtins.h create mode 100644 Compiler/src/ConstantFolding.cpp create mode 100644 Compiler/src/ConstantFolding.h create mode 100644 Compiler/src/TableShape.cpp create mode 100644 Compiler/src/TableShape.h create mode 100644 Compiler/src/ValueTracking.cpp create mode 100644 Compiler/src/ValueTracking.h diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 312283b05..aa0900140 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -97,6 +97,12 @@ struct ApplyTypeFunction : Substitution TypePackId clean(TypePackId tp) override; }; +struct GenericTypeDefinitions +{ + std::vector genericTypes; + std::vector genericPacks; +}; + // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -146,7 +152,7 @@ struct TypeChecker ExprResult checkExpr(const ScopePtr& scope, const AstExprBinary& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprError& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType); @@ -336,8 +342,8 @@ struct TypeChecker const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. - std::pair, std::vector> createGenericTypes(const ScopePtr& scope, std::optional levelOpt, - const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, + const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index d6e177142..fd2c2afa7 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -181,6 +181,18 @@ const T* get(const SingletonTypeVar* stv) return nullptr; } +struct GenericTypeDefinition +{ + TypeId ty; + std::optional defaultValue; +}; + +struct GenericTypePackDefinition +{ + TypePackId tp; + std::optional defaultValue; +}; + struct FunctionArgument { Name name; @@ -358,8 +370,8 @@ struct ClassTypeVar struct TypeFun { // These should all be generic - std::vector typeParams; - std::vector typePackParams; + std::vector typeParams; + std::vector typePackParams; /** The underlying type. * @@ -369,13 +381,13 @@ struct TypeFun TypeId type; TypeFun() = default; - TypeFun(std::vector typeParams, TypeId type) + TypeFun(std::vector typeParams, TypeId type) : typeParams(std::move(typeParams)) , type(type) { } - TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) : typeParams(std::move(typeParams)) , typePackParams(std::move(typePackParams)) , type(type) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 67ebd0755..7a801f970 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,11 +13,10 @@ #include LUAU_FASTFLAG(LuauUseCommittingTxnLog) -LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); +LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -291,51 +290,23 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ expectedType = follow(*it); } - if (FFlag::LuauAutocompletePreferToCallFunctions) + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty)) { - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty)) - { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) - { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - } - } - - return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - } - else - { - if (canUnify(ty, expectedType)) - return TypeCorrectKind::Correct; - - // We also want to suggest functions that return compatible result - const FunctionTypeVar* ftv = get(ty); - - if (!ftv) - return TypeCorrectKind::None; - auto [retHead, retTail] = flatten(ftv->retType); - if (!retHead.empty()) - return canUnify(retHead.front(), expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + return TypeCorrectKind::CorrectFunctionResult; // We might only have a variadic tail pack, check if the element is compatible if (retTail) { - if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(vtp->ty, expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return TypeCorrectKind::CorrectFunctionResult; } - - return TypeCorrectKind::None; } + + return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } enum class PropIndexType @@ -435,13 +406,28 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { - if (get(indexIt->second.type) || get(indexIt->second.type)) - autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); - else if (auto indexFunction = get(indexIt->second.type)) + if (FFlag::LuauMissingFollowACMetatables) { - std::optional indexFunctionResult = first(indexFunction->retType); - if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + TypeId followed = follow(indexIt->second.type); + if (get(followed) || get(followed)) + autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen); + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + } + } + else + { + if (get(indexIt->second.type) || get(indexIt->second.type)) + autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); + else if (auto indexFunction = get(indexIt->second.type)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + } } } } @@ -1224,7 +1210,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul if (auto it = module.astTypes.find(node->asExpr())) autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result); } - else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) + else if (autocompleteIfElseExpression(node, ancestry, position, result)) return; else if (node->is()) return; @@ -1261,8 +1247,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - if (FFlag::LuauIfElseExpressionAnalysisSupport) - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index ce832c6b3..88069f1f5 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -190,24 +190,24 @@ struct ErrorConverter { name += "<"; bool first = true; - for (TypeId t : e.typeFun.typeParams) + for (auto param : e.typeFun.typeParams) { if (first) first = false; else name += ", "; - name += toString(t); + name += toString(param.ty); } - for (TypePackId t : e.typeFun.typePackParams) + for (auto param : e.typeFun.typePackParams) { if (first) first = false; else name += ", "; - name += toString(t); + name += toString(param.tp); } name += ">"; @@ -544,13 +544,13 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC for (size_t i = 0; i < typeFun.typeParams.size(); ++i) { - if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) + if (typeFun.typeParams[i].ty != rhs.typeFun.typeParams[i].ty) return false; } for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) { - if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + if (typeFun.typePackParams[i].tp != rhs.typeFun.typePackParams[i].tp) return false; } diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 5bc76ade5..19c2ddabd 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -96,24 +96,24 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "<"; bool first = true; - for (TypeId t : error.typeFun.typeParams) + for (auto param : error.typeFun.typeParams) { if (first) first = false; else stream << ", "; - stream << toString(t); + stream << toString(param.ty); } - for (TypePackId t : error.typeFun.typePackParams) + for (auto param : error.typeFun.typePackParams) { if (first) first = false; else stream << ", "; - stream << toString(t); + stream << toString(param.tp); } stream << ">"; diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 23491a5a1..8dd597e17 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -5,6 +5,8 @@ #include "Luau/StringUtils.h" #include "Luau/Common.h" +LUAU_FASTFLAG(LuauTypeAliasDefaults) + namespace Luau { @@ -337,6 +339,42 @@ struct AstJsonEncoder : public AstVisitor writeRaw("}"); } + void write(const AstGenericType& genericType) + { + if (FFlag::LuauTypeAliasDefaults) + { + writeRaw("{"); + bool c = pushComma(); + write("name", genericType.name); + if (genericType.defaultValue) + write("type", genericType.defaultValue); + popComma(c); + writeRaw("}"); + } + else + { + write(genericType.name); + } + } + + void write(const AstGenericTypePack& genericTypePack) + { + if (FFlag::LuauTypeAliasDefaults) + { + writeRaw("{"); + bool c = pushComma(); + write("name", genericTypePack.name); + if (genericTypePack.defaultValue) + write("type", genericTypePack.defaultValue); + popComma(c); + writeRaw("}"); + } + else + { + write(genericTypePack.name); + } + } + void write(AstExprTable::Item::Kind kind) { switch (kind) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index cff85897c..9f352f4b3 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAG(LuauTypeAliasDefaults) namespace Luau { @@ -447,11 +448,28 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { TypeFun result; - for (TypeId ty : typeFun.typeParams) - result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - for (TypePackId tp : typeFun.typePackParams) - result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, cloneState)); + for (auto param : typeFun.typeParams) + { + TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typeParams.push_back({ty, defaultValue}); + } + + for (auto param : typeFun.typePackParams) + { + TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typePackParams.push_back({tp, defaultValue}); + } result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 889dd6dc5..4b898d3a6 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) +LUAU_FASTFLAG(LuauTypeAliasDefaults) /* * Prefix generic typenames with gen- @@ -209,6 +210,14 @@ struct StringifierState result.name += s; } + + void emit(const char* s) + { + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + return; + + result.name += s; + } }; struct TypeVarStringifier @@ -280,13 +289,28 @@ struct TypeVarStringifier else first = false; - if (!singleTp) - state.emit("("); + if (FFlag::LuauTypeAliasDefaults) + { + bool wrap = !singleTp && get(follow(tp)); - stringify(tp); + if (wrap) + state.emit("("); - if (!singleTp) - state.emit(")"); + stringify(tp); + + if (wrap) + state.emit(")"); + } + else + { + if (!singleTp) + state.emit("("); + + stringify(tp); + + if (!singleTp) + state.emit(")"); + } } if (types.size() || typePacks.size()) @@ -1086,7 +1110,7 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } -std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +std::string toStringNamedFunction_DEPRECATED(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) { std::string s = prefix; @@ -1175,6 +1199,77 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV return s; } +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +{ + if (!FFlag::LuauTypeAliasDefaults) + return toStringNamedFunction_DEPRECATED(prefix, ftv, opts); + + ToStringResult result; + StringifierState state(opts, result, opts.nameMap); + TypeVarStringifier tvs{state}; + + state.emit(prefix); + + if (!opts.hideNamedFunctionTypeParameters) + tvs.stringify(ftv.generics, ftv.genericPacks); + + state.emit("("); + + auto argPackIter = begin(ftv.argTypes); + auto argNameIter = ftv.argNames.begin(); + + bool first = true; + while (argPackIter != end(ftv.argTypes)) + { + if (!first) + state.emit(", "); + first = false; + + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) + { + state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": "); + ++argNameIter; + } + else + { + state.emit("_: "); + } + + tvs.stringify(*argPackIter); + ++argPackIter; + } + + if (argPackIter.tail()) + { + if (!first) + state.emit(", "); + + state.emit("...: "); + + if (auto vtp = get(*argPackIter.tail())) + tvs.stringify(vtp->ty); + else + tvs.stringify(*argPackIter.tail()); + } + + state.emit("): "); + + size_t retSize = size(ftv.retType); + bool hasTail = !finite(ftv.retType); + bool wrap = get(follow(ftv.retType)) && (hasTail ? retSize != 0 : retSize != 1); + + if (wrap) + state.emit("("); + + tvs.stringify(ftv.retType); + + if (wrap) + state.emit(")"); + + return result.name; +} + std::string dump(TypeId ty) { ToStringOptions opts; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 8e13ea5be..f59086834 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_FASTFLAG(LuauTypeAliasDefaults) + namespace { bool isIdentifierStartChar(char c) @@ -793,14 +795,47 @@ struct Printer for (auto o : a->generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + { + writer.advance(o.location.begin); + writer.identifier(o.name.value); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*o.defaultValue); + } + } + else + { + writer.identifier(o.name.value); + } } for (auto o : a->genericPacks) { comma(); - writer.identifier(o.value); - writer.symbol("..."); + + if (FFlag::LuauTypeAliasDefaults) + { + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypePackAnnotation(*o.defaultValue, false); + } + } + else + { + writer.identifier(o.name.value); + writer.symbol("..."); + } } writer.symbol(">"); @@ -846,12 +881,20 @@ struct Printer for (const auto& o : func.generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); } for (const auto& o : func.genericPacks) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); writer.symbol("..."); } writer.symbol(">"); @@ -979,12 +1022,20 @@ struct Printer for (const auto& o : a->generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); } for (const auto& o : a->genericPacks) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); writer.symbol("..."); } writer.symbol(">"); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 9e61c7924..2ec020937 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -212,24 +212,24 @@ class TypeRehydrationVisitor if (hasSeen(&ftv)) return allocator->alloc(Location(), std::nullopt, AstName("")); - AstArray generics; + AstArray generics; generics.size = ftv.generics.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); + generics.data = static_cast(allocator->allocate(sizeof(AstGenericType) * generics.size)); size_t numGenerics = 0; for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { if (auto gtv = get(*it)) - generics.data[numGenerics++] = AstName(gtv->name.c_str()); + generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } - AstArray genericPacks; + AstArray genericPacks; genericPacks.size = ftv.genericPacks.size(); - genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { if (auto gtv = get(*it)) - genericPacks.data[numGenericPacks++] = AstName(gtv->name.c_str()); + genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } AstArray argTypes; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 1689a5c3d..bedcc0227 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -15,8 +15,8 @@ #include "Luau/TypeVar.h" #include "Luau/TimeTrace.h" -#include #include +#include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) @@ -24,25 +24,30 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) +LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) +LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) +LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) +LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) +LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau @@ -279,6 +284,14 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); } + if (FFlag::LuauPerModuleUnificationCache) + { + // Clear unifier cache since it's keyed off internal types that get deallocated + // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. + unifierState.cachedUnify.clear(); + unifierState.skipCacheForType.clear(); + } + return std::move(currentModule); } @@ -1213,18 +1226,18 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ScopePtr aliasScope = childScope(scope, typealias.location); aliasScope->level = scope->level.incr(); - for (TypeId ty : binding->typeParams) + for (auto param : binding->typeParams) { - auto generic = get(ty); + auto generic = get(param.ty); LUAU_ASSERT(generic); - aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; + aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, param.ty}; } - for (TypePackId tp : binding->typePackParams) + for (auto param : binding->typePackParams) { - auto generic = get(tp); + auto generic = get(param.tp); LUAU_ASSERT(generic); - aliasScope->privateTypePackBindings[generic->name] = tp; + aliasScope->privateTypePackBindings[generic->name] = param.tp; } TypeId ty = resolveType(aliasScope, *typealias.type); @@ -1233,9 +1246,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // If the table is already named and we want to rename the type function, we have to bind new alias to a copy if (ttv->name) { + bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), + binding->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), + binding->typePackParams.begin(), binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + // Copy can be skipped if this is an identical alias - if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || - ttv->instantiatedTypePackParams != binding->typePackParams) + if (ttv->name != name || !sameTys || !sameTps) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1243,8 +1264,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; - clone.instantiatedTypeParams = binding->typeParams; - clone.instantiatedTypePackParams = binding->typePackParams; + + for (auto param : binding->typeParams) + clone.instantiatedTypeParams.push_back(param.ty); + + for (auto param : binding->typePackParams) + clone.instantiatedTypePackParams.push_back(param.tp); ty = addType(std::move(clone)); } @@ -1252,8 +1277,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias else { ttv->name = name; - ttv->instantiatedTypeParams = binding->typeParams; - ttv->instantiatedTypePackParams = binding->typePackParams; + + ttv->instantiatedTypeParams.clear(); + for (auto param : binding->typeParams) + ttv->instantiatedTypeParams.push_back(param.ty); + + ttv->instantiatedTypePackParams.clear(); + for (auto param : binding->typePackParams) + ttv->instantiatedTypePackParams.push_back(param.tp); } } else if (auto mtv = getMutable(follow(ty))) @@ -1367,9 +1398,21 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); - TypeId fnType = addType(FunctionTypeVar{funScope->level, generics, genericPacks, argPack, retPack}); + TypeId fnType = addType(FunctionTypeVar{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); FunctionTypeVar* ftv = getMutable(fnType); ftv->argNames.reserve(global.paramNames.size); @@ -1394,7 +1437,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& ExprResult result; if (auto a = expr.as()) - result = checkExpr(scope, *a->expr); + result = checkExpr(scope, *a->expr, FFlag::LuauGroupExpectedType ? expectedType : std::nullopt); else if (expr.is()) result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) @@ -1438,21 +1481,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - { - if (FFlag::LuauIfElseExpressionAnalysisSupport) - { - result = checkExpr(scope, *a); - } - else - { - // Note: When the fast flag is disabled we can't skip the handling of AstExprIfElse - // because we would generate an ICE. We also can't use the default value - // of result, because it will lead to a compiler crash. - // Note: LuauIfElseExpressionBaseSupport can be used to disable parser support - // for if-else expressions which will mean this node type is never created. - result = {anyType}; - } - } + result = checkExpr(scope, *a, FFlag::LuauIfElseExpectedType2 ? expectedType : std::nullopt); else ice("Unhandled AstExpr?"); @@ -1895,7 +1924,7 @@ TypeId TypeChecker::checkExprTable( } } - TableState state = (expr.items.size == 0 || isNonstrictMode()) ? TableState::Unsealed : TableState::Sealed; + TableState state = (expr.items.size == 0 || isNonstrictMode() || FFlag::LuauUnsealedTableLiteral) ? TableState::Unsealed : TableState::Sealed; TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; table.definitionModuleName = currentModuleName; return addType(table); @@ -2549,23 +2578,34 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { ExprResult result = checkExpr(scope, *expr.condition); ScopePtr trueScope = childScope(scope, expr.trueExpr->location); reportErrors(resolve(result.predicates, trueScope, true)); - ExprResult trueType = checkExpr(trueScope, *expr.trueExpr); + ExprResult trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); // Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope. resolve(result.predicates, falseScope, false); - ExprResult falseType = checkExpr(falseScope, *expr.falseExpr); + ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); - unify(falseType.type, trueType.type, expr.location); + if (FFlag::LuauIfElseBranchTypeUnion) + { + if (falseType.type == trueType.type) + return {trueType.type}; + + std::vector types = reduceUnion({trueType.type, falseType.type}); + return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; + } + else + { + unify(falseType.type, trueType.type, expr.location); - // TODO: normalize(UnionTypeVar{{trueType, falseType}}) - // For now both trueType and falseType must be the same type. - return {trueType.type}; + // TODO: normalize(UnionTypeVar{{trueType, falseType}}) + // For now both trueType and falseType must be the same type. + return {trueType.type}; + } } TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) @@ -3032,7 +3072,20 @@ std::pair TypeChecker::checkFunctionSignature( defn.varargLocation = expr.vararg ? std::make_optional(expr.varargLocation) : std::nullopt; defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); - TypeId funTy = addType(FunctionTypeVar(funScope->level, generics, genericPacks, argPack, retPack, std::move(defn), bool(expr.self))); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + + TypeId funTy = + addType(FunctionTypeVar(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self))); FunctionTypeVar* ftv = getMutable(funTy); @@ -4848,11 +4901,38 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) return tf->type; - if (!lit->hasParameterList && !tf->typePackParams.empty()) + bool hasDefaultTypes = false; + bool hasDefaultPacks = false; + bool parameterCountErrorReported = false; + + if (FFlag::LuauTypeAliasDefaults) { - reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); + hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + + if (!lit->hasParameterList) + { + if ((!tf->typeParams.empty() && !hasDefaultTypes) || (!tf->typePackParams.empty() && !hasDefaultPacks)) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + parameterCountErrorReported = true; + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); + } + } + } + else + { + if (!lit->hasParameterList && !tf->typePackParams.empty()) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); + } } std::vector typeParams; @@ -4892,14 +4972,89 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (typePackParams.empty() && !extraTypes.empty()) typePackParams.push_back(addTypePack(extraTypes)); + if (FFlag::LuauTypeAliasDefaults) + { + size_t typesProvided = typeParams.size(); + size_t typesRequired = tf->typeParams.size(); + + size_t packsProvided = typePackParams.size(); + size_t packsRequired = tf->typePackParams.size(); + + bool notEnoughParameters = + (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); + bool hasDefaultParameters = hasDefaultTypes || hasDefaultPacks; + + // Add default type and type pack parameters if that's required and it's possible + if (notEnoughParameters && hasDefaultParameters) + { + // 'applyTypeFunction' is used to substitute default types that reference previous generic types + applyTypeFunction.typeArguments.clear(); + applyTypeFunction.typePackArguments.clear(); + applyTypeFunction.currentModule = currentModule; + applyTypeFunction.level = scope->level; + applyTypeFunction.encounteredForwardedType = false; + + for (size_t i = 0; i < typesProvided; ++i) + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; + + if (typesProvided < typesRequired) + { + for (size_t i = typesProvided; i < typesRequired; ++i) + { + TypeId defaultTy = tf->typeParams[i].defaultValue.value_or(nullptr); + + if (!defaultTy) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTy); + + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryType(scope); + } + + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = *maybeInstantiated; + typeParams.push_back(*maybeInstantiated); + } + } + + for (size_t i = 0; i < packsProvided; ++i) + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = typePackParams[i]; + + if (packsProvided < packsRequired) + { + for (size_t i = packsProvided; i < packsRequired; ++i) + { + TypePackId defaultTp = tf->typePackParams[i].defaultValue.value_or(nullptr); + + if (!defaultTp) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTp); + + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryTypePack(scope); + } + + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = *maybeInstantiated; + typePackParams.push_back(*maybeInstantiated); + } + } + } + } + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) typePackParams.push_back(addTypePack({})); if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) { - reportError( - TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + if (!parameterCountErrorReported) + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); if (FFlag::LuauErrorRecoveryType) { @@ -4913,11 +5068,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return errorRecoveryType(scope); } - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + if (FFlag::LuauRecursiveTypeParameterRestriction) { + bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal( + typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + // If the generic parameters and the type arguments are the same, we are about to // perform an identity substitution, which we can just short-circuit. - return tf->type; + if (sameTys && sameTps) + return tf->type; } return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); @@ -4948,7 +5112,19 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation TypePackId argTypes = resolveTypePack(funcScope, func->argTypes); TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes); - TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(generics), std::move(genericPacks), argTypes, retTypes}); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + + TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes}); FunctionTypeVar* ftv = getMutable(fnType); @@ -5137,11 +5313,11 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) - applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; + applyTypeFunction.typeArguments[tf.typeParams[i].ty] = typeParams[i]; applyTypeFunction.typePackArguments.clear(); for (size_t i = 0; i < tf.typePackParams.size(); ++i) - applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; + applyTypeFunction.typePackArguments[tf.typePackParams[i].tp] = typePackParams[i]; applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; @@ -5213,17 +5389,23 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, return instantiated; } -std::pair, std::vector> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, - const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, + const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; - std::vector generics; - for (const AstName& generic : genericNames) + std::vector generics; + + for (const AstGenericType& generic : genericNames) { - Name n = generic.value; + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && generic.defaultValue) + defaultValue = resolveType(scope, *generic.defaultValue); + + Name n = generic.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic typevars have the same name. @@ -5246,14 +5428,20 @@ std::pair, std::vector> TypeChecker::createGener g = addType(Unifiable::Generic{level, n}); } - generics.push_back(g); + generics.push_back({g, defaultValue}); scope->privateTypeBindings[n] = TypeFun{{}, g}; } - std::vector genericPacks; - for (const AstName& genericPack : genericPackNames) + std::vector genericPacks; + + for (const AstGenericTypePack& genericPack : genericPackNames) { - Name n = genericPack.value; + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && genericPack.defaultValue) + defaultValue = resolveTypePack(scope, *genericPack.defaultValue); + + Name n = genericPack.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic typevars have the same name. @@ -5276,7 +5464,7 @@ std::pair, std::vector> TypeChecker::createGener g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); } - genericPacks.push_back(g); + genericPacks.push_back({g, defaultValue}); scope->privateTypePackBindings[n] = g; } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4cab79c8a..ac2b25410 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) +LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(DebugLuauFreezeArena) @@ -453,6 +454,9 @@ bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs) static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs) { + if (FFlag::LuauMetatableAreEqualRecursion && areSeen(seen, &lhs, &rhs)) + return true; + return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable); } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 393a84a70..6873c657c 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); +LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) @@ -1170,9 +1171,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { + if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + break; + } + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) + if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) tryUnify_(*subTpv->tail, *superTpv->tail); else if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); @@ -1370,9 +1377,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { + if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + break; + } + const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) + if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) tryUnify_(*subTpv->tail, *superTpv->tail); else if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 5b4bfa033..573850a5a 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -334,6 +334,20 @@ struct AstTypeList using AstArgumentName = std::pair; // TODO: remove and replace when we get a common struct for this pair instead of AstName +struct AstGenericType +{ + AstName name; + Location location; + AstType* defaultValue = nullptr; +}; + +struct AstGenericTypePack +{ + AstName name; + Location location; + AstTypePack* defaultValue = nullptr; +}; + extern int gAstRttiIndex; template @@ -569,15 +583,15 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, - const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, - std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, + AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + AstLocal* self, const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, std::optional argLocation = std::nullopt); void visit(AstVisitor* visitor) override; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstLocal* self; AstArray args; bool hasReturnAnnotation; @@ -942,14 +956,14 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, - AstType* type, bool exported); + AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported); void visit(AstVisitor* visitor) override; AstName name; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -972,14 +986,15 @@ class AstStatDeclareFunction : public AstStat public: LUAU_RTTI(AstStatDeclareFunction) - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes); + AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, + const AstTypeList& retTypes); void visit(AstVisitor* visitor) override; AstName name; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList params; AstArray paramNames; AstTypeList retTypes; @@ -1077,13 +1092,13 @@ class AstTypeFunction : public AstType public: LUAU_RTTI(AstTypeFunction) - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, - const AstArray>& argNames, const AstTypeList& returnTypes); + AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); void visit(AstVisitor* visitor) override; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 87ebc48b5..40ecdcdd5 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -219,7 +219,7 @@ class Parser AstTableIndexer* parseTableIndexerAnnotation(); AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); - AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); AstType* parseTableTypeAnnotation(); @@ -281,7 +281,7 @@ class Parser Name parseIndexName(const char* context, const Position& previous); // `<' namelist `>' - std::pair, AstArray> parseGenericTypeList(); + std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); @@ -418,6 +418,8 @@ class Parser std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; + std::vector scratchGenericTypes; + std::vector scratchGenericTypePacks; std::vector> scratchOptArgName; std::string scratchData; }; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index e709894d9..9b5bc0c71 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -158,9 +158,10 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, - const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, - std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, std::optional argLocation) +AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + AstLocal* self, const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, + std::optional argLocation) : AstExpr(ClassIndex(), location) , generics(generics) , genericPacks(genericPacks) @@ -641,8 +642,8 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported) : AstStat(ClassIndex(), location) , name(name) , generics(generics) @@ -655,7 +656,21 @@ AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name void AstStatTypeAlias::visit(AstVisitor* visitor) { if (visitor->visit(this)) + { + for (const AstGenericType& el : generics) + { + if (el.defaultValue) + el.defaultValue->visit(visitor); + } + + for (const AstGenericTypePack& el : genericPacks) + { + if (el.defaultValue) + el.defaultValue->visit(visitor); + } + type->visit(visitor); + } } AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) @@ -671,8 +686,9 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor) type->visit(visitor); } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, + const AstTypeList& retTypes) : AstStat(ClassIndex(), location) , name(name) , generics(generics) @@ -778,7 +794,7 @@ void AstTypeTable::visit(AstVisitor* visitor) } } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, +AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) : AstType(ClassIndex(), location) , generics(generics) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 72f616497..77787cb1c 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,10 +10,10 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) -LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) +LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) namespace Luau { @@ -394,23 +394,13 @@ AstStat* Parser::parseIf() if (lexer.current().type == Lexeme::ReservedElseif) { - if (FFlag::LuauIfStatementRecursionGuard) - { - unsigned int recursionCounterOld = recursionCounter; - incrementRecursionCounter("elseif"); - elseLocation = lexer.current().location; - elsebody = parseIf(); - end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - recursionCounter = recursionCounterOld; - } - else - { - elseLocation = lexer.current().location; - elsebody = parseIf(); - end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - } + unsigned int recursionCounterOld = recursionCounter; + incrementRecursionCounter("elseif"); + elseLocation = lexer.current().location; + elsebody = parseIf(); + end = elsebody->location; + hasEnd = elsebody->as()->hasEnd; + recursionCounter = recursionCounterOld; } else { @@ -772,7 +762,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ FFlag::LuauParseTypeAliasDefaults); expectAndConsume('=', "type alias"); @@ -788,8 +778,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() Name fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; generics.size = 0; generics.data = nullptr; genericPacks.size = 0; @@ -849,7 +839,7 @@ AstStat* Parser::parseDeclaration(const Location& start) nextLexeme(); Name globalName = parseName("global function name"); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme matchParen = lexer.current(); @@ -991,7 +981,7 @@ std::pair Parser::parseFunctionBody( { Location start = matchFunction.location; - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme matchParen = lexer.current(); expectAndConsume('(', "function"); @@ -1228,8 +1218,8 @@ std::pair Parser::parseReturnTypeAnnotation() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstArray generics{nullptr, 0}; - AstArray genericPacks{nullptr, 0}; + AstArray generics{nullptr, 0}; + AstArray genericPacks{nullptr, 0}; AstArray types = copy(result); AstArray> names = copy(resultNames); @@ -1363,7 +1353,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) Lexeme begin = lexer.current(); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme parameterStart = lexer.current(); @@ -1401,7 +1391,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, +AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation) { @@ -1448,7 +1438,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isUnion = true; } else if (c == '?') @@ -1461,7 +1451,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isIntersection = true; } else @@ -1498,7 +1488,7 @@ AstTypeOrPack Parser::parseTypeOrPackAnnotation() TempVector parts(scratchAnnotation); - auto [type, typePack] = parseSimpleTypeAnnotation(true); + auto [type, typePack] = parseSimpleTypeAnnotation(/* allowPack= */ true); if (typePack) { @@ -1521,7 +1511,7 @@ AstType* Parser::parseTypeAnnotation() Location begin = lexer.current().location; TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); recursionCounter = oldRecursionCount; @@ -2121,7 +2111,7 @@ AstExpr* Parser::parseSimpleExpr() { return parseTableConstructor(); } - else if (FFlag::LuauIfElseExpressionBaseSupport && lexer.current().type == Lexeme::ReservedIf) + else if (lexer.current().type == Lexeme::ReservedIf) { return parseIfElseExpr(); } @@ -2341,10 +2331,10 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou return Name(nameError, location); } -std::pair, AstArray> Parser::parseGenericTypeList() +std::pair, AstArray> Parser::parseGenericTypeList(bool withDefaultValues) { - TempVector names{scratchName}; - TempVector namePacks{scratchPackName}; + TempVector names{scratchGenericTypes}; + TempVector namePacks{scratchGenericTypePacks}; if (lexer.current().type == '<') { @@ -2352,21 +2342,73 @@ std::pair, AstArray> Parser::parseGenericTypeList() nextLexeme(); bool seenPack = false; + bool seenDefault = false; + while (true) { + Location nameLocation = lexer.current().location; AstName name = parseName().name; - if (lexer.current().type == Lexeme::Dot3) + if (lexer.current().type == Lexeme::Dot3 || (FFlag::LuauParseRecoverTypePackEllipsis && seenPack)) { seenPack = true; - nextLexeme(); - namePacks.push_back(name); + + if (FFlag::LuauParseRecoverTypePackEllipsis && lexer.current().type != Lexeme::Dot3) + report(lexer.current().location, "Generic types come before generic type packs"); + else + nextLexeme(); + + if (withDefaultValues && lexer.current().type == '=') + { + seenDefault = true; + nextLexeme(); + + Lexeme packBegin = lexer.current(); + + if (shouldParseTypePackAnnotation(lexer)) + { + auto typePack = parseTypePackAnnotation(); + + namePacks.push_back({name, nameLocation, typePack}); + } + else if (lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (type) + report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); + + namePacks.push_back({name, nameLocation, typePack}); + } + } + else + { + if (seenDefault) + report(lexer.current().location, "Expected default type pack after type pack name"); + + namePacks.push_back({name, nameLocation, nullptr}); + } } else { - if (seenPack) + if (!FFlag::LuauParseRecoverTypePackEllipsis && seenPack) report(lexer.current().location, "Generic types come before generic type packs"); - names.push_back(name); + if (withDefaultValues && lexer.current().type == '=') + { + seenDefault = true; + nextLexeme(); + + AstType* defaultType = parseTypeAnnotation(); + + names.push_back({name, nameLocation, defaultType}); + } + else + { + if (seenDefault) + report(lexer.current().location, "Expected default type after type name"); + + names.push_back({name, nameLocation, nullptr}); + } } if (lexer.current().type == ',') @@ -2378,8 +2420,8 @@ std::pair, AstArray> Parser::parseGenericTypeList() expectMatchAndConsume('>', begin); } - AstArray generics = copy(names); - AstArray genericPacks = copy(namePacks); + AstArray generics = copy(names); + AstArray genericPacks = copy(namePacks); return {generics, genericPacks}; } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 54a9a26fa..e0dc3e0fe 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -8,6 +8,8 @@ #include "FileUtils.h" +LUAU_FASTFLAG(DebugLuauTimeTracing) + enum class ReportFormat { Default, @@ -105,6 +107,7 @@ static void displayHelp(const char* argv0) printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -213,7 +216,17 @@ int main(int argc, char** argv) format = ReportFormat::Gnu; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; + else if (strcmp(argv[i], "--timetrace") == 0) + FFlag::DebugLuauTimeTracing.value = true; + } + +#if !defined(LUAU_ENABLE_TIME_TRACE) + if (FFlag::DebugLuauTimeTracing) + { + printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; } +#endif Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; @@ -240,7 +253,10 @@ int main(int argc, char** argv) fprintf(stderr, "%s: %s\n", pair.first.c_str(), pair.second.c_str()); } - return (format == ReportFormat::Luacheck) ? 0 : failed; + if (format == ReportFormat::Luacheck) + return 0; + else + return failed ? 1 : 0; } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 26d4333a9..36747f487 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -19,6 +19,16 @@ #include #endif +LUAU_FASTFLAG(DebugLuauTimeTracing) + +enum class CliMode +{ + Unknown, + Repl, + Compile, + RunSourceFiles +}; + enum class CompileFormat { Text, @@ -485,8 +495,10 @@ static void displayHelp(const char* argv0) printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); printf("\n"); printf("Available options:\n"); + printf(" -h, --help: Display this usage message.\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -503,71 +515,112 @@ int main(int argc, char** argv) if (strncmp(flag->name, "Luau", 4) == 0) flag->value = true; - if (argc == 1) + CliMode mode = CliMode::Unknown; + CompileFormat compileFormat{}; + int profile = 0; + bool coverage = false; + + // Set the mode if the user has explicitly specified one. + int argStart = 1; + if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { - runRepl(); - return 0; + argStart++; + mode = CliMode::Compile; + if (strcmp(argv[1], "--compile") == 0) + { + compileFormat = CompileFormat::Text; + } + else if (strcmp(argv[1], "--compile=binary") == 0) + { + compileFormat = CompileFormat::Binary; + } + else if (strcmp(argv[1], "--compile=text") == 0) + { + compileFormat = CompileFormat::Text; + } + else + { + fprintf(stdout, "Error: Unrecognized value for '--compile' specified.\n"); + return -1; + } } - if (argc >= 2 && strcmp(argv[1], "--help") == 0) + for (int i = argStart; i < argc; i++) { - displayHelp(argv[0]); - return 0; - } + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (strcmp(argv[i], "--profile") == 0) + { + profile = 10000; // default to 10 KHz + } + else if (strncmp(argv[i], "--profile=", 10) == 0) + { + profile = atoi(argv[i] + 10); + } + else if (strcmp(argv[i], "--coverage") == 0) + { + coverage = true; + } + else if (strcmp(argv[i], "--timetrace") == 0) + { + FFlag::DebugLuauTimeTracing.value = true; +#if !defined(LUAU_ENABLE_TIME_TRACE) + printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; +#endif + } + else if (argv[i][0] == '-') + { + fprintf(stdout, "Error: Unrecognized option '%s'.\n\n", argv[i]); + displayHelp(argv[0]); + return 1; + } + } - if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) + const std::vector files = getSourceFiles(argc, argv); + if (mode == CliMode::Unknown) { - CompileFormat format = CompileFormat::Text; - - if (strcmp(argv[1], "--compile=binary") == 0) - format = CompileFormat::Binary; + mode = files.empty() ? CliMode::Repl : CliMode::RunSourceFiles; + } + switch (mode) + { + case CliMode::Compile: + { #ifdef _WIN32 - if (format == CompileFormat::Binary) + if (compileFormat == CompileFormat::Binary) _setmode(_fileno(stdout), _O_BINARY); #endif - std::vector files = getSourceFiles(argc, argv); - int failed = 0; for (const std::string& path : files) - failed += !compileFile(path.c_str(), format); + failed += !compileFile(path.c_str(), compileFormat); - return failed; + return failed ? 1 : 0; } - + case CliMode::Repl: + { + runRepl(); + return 0; + } + case CliMode::RunSourceFiles: { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); setupState(L); - int profile = 0; - bool coverage = false; - - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] != '-') - continue; - - if (strcmp(argv[i], "--profile") == 0) - profile = 10000; // default to 10 KHz - else if (strncmp(argv[i], "--profile=", 10) == 0) - profile = atoi(argv[i] + 10); - else if (strcmp(argv[i], "--coverage") == 0) - coverage = true; - } - if (profile) profilerStart(L, profile); if (coverage) coverageInit(L); - std::vector files = getSourceFiles(argc, argv); - int failed = 0; for (const std::string& path : files) @@ -582,7 +635,11 @@ int main(int argc, char** argv) if (coverage) coverageDump("coverage.out"); - return failed; + return failed ? 1 : 0; + } + case CliMode::Unknown: + default: + LUAU_ASSERT(!"Unhandled cli mode."); } } diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp new file mode 100644 index 000000000..e344eb917 --- /dev/null +++ b/Compiler/src/Builtins.cpp @@ -0,0 +1,197 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Builtins.h" + +#include "Luau/Bytecode.h" +#include "Luau/Compiler.h" + +namespace Luau +{ +namespace Compile +{ + +Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, const DenseHashMap& variables) +{ + if (AstExprLocal* expr = node->as()) + { + const Variable* v = variables.find(expr->local); + + return v && !v->written && v->init ? getBuiltin(v->init, globals, variables) : Builtin(); + } + else if (AstExprIndexName* expr = node->as()) + { + if (AstExprGlobal* object = expr->expr->as()) + { + return getGlobalState(globals, object->name) == Global::Default ? Builtin{object->name, expr->index} : Builtin(); + } + else + { + return Builtin(); + } + } + else if (AstExprGlobal* expr = node->as()) + { + return getGlobalState(globals, expr->name) == Global::Default ? Builtin{AstName(), expr->name} : Builtin(); + } + else + { + return Builtin(); + } +} + +int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) +{ + if (builtin.empty()) + return -1; + + if (builtin.isGlobal("assert")) + return LBF_ASSERT; + + if (builtin.isGlobal("type")) + return LBF_TYPE; + + if (builtin.isGlobal("typeof")) + return LBF_TYPEOF; + + if (builtin.isGlobal("rawset")) + return LBF_RAWSET; + if (builtin.isGlobal("rawget")) + return LBF_RAWGET; + if (builtin.isGlobal("rawequal")) + return LBF_RAWEQUAL; + + if (builtin.isGlobal("unpack")) + return LBF_TABLE_UNPACK; + + if (builtin.object == "math") + { + if (builtin.method == "abs") + return LBF_MATH_ABS; + if (builtin.method == "acos") + return LBF_MATH_ACOS; + if (builtin.method == "asin") + return LBF_MATH_ASIN; + if (builtin.method == "atan2") + return LBF_MATH_ATAN2; + if (builtin.method == "atan") + return LBF_MATH_ATAN; + if (builtin.method == "ceil") + return LBF_MATH_CEIL; + if (builtin.method == "cosh") + return LBF_MATH_COSH; + if (builtin.method == "cos") + return LBF_MATH_COS; + if (builtin.method == "deg") + return LBF_MATH_DEG; + if (builtin.method == "exp") + return LBF_MATH_EXP; + if (builtin.method == "floor") + return LBF_MATH_FLOOR; + if (builtin.method == "fmod") + return LBF_MATH_FMOD; + if (builtin.method == "frexp") + return LBF_MATH_FREXP; + if (builtin.method == "ldexp") + return LBF_MATH_LDEXP; + if (builtin.method == "log10") + return LBF_MATH_LOG10; + if (builtin.method == "log") + return LBF_MATH_LOG; + if (builtin.method == "max") + return LBF_MATH_MAX; + if (builtin.method == "min") + return LBF_MATH_MIN; + if (builtin.method == "modf") + return LBF_MATH_MODF; + if (builtin.method == "pow") + return LBF_MATH_POW; + if (builtin.method == "rad") + return LBF_MATH_RAD; + if (builtin.method == "sinh") + return LBF_MATH_SINH; + if (builtin.method == "sin") + return LBF_MATH_SIN; + if (builtin.method == "sqrt") + return LBF_MATH_SQRT; + if (builtin.method == "tanh") + return LBF_MATH_TANH; + if (builtin.method == "tan") + return LBF_MATH_TAN; + if (builtin.method == "clamp") + return LBF_MATH_CLAMP; + if (builtin.method == "sign") + return LBF_MATH_SIGN; + if (builtin.method == "round") + return LBF_MATH_ROUND; + } + + if (builtin.object == "bit32") + { + if (builtin.method == "arshift") + return LBF_BIT32_ARSHIFT; + if (builtin.method == "band") + return LBF_BIT32_BAND; + if (builtin.method == "bnot") + return LBF_BIT32_BNOT; + if (builtin.method == "bor") + return LBF_BIT32_BOR; + if (builtin.method == "bxor") + return LBF_BIT32_BXOR; + if (builtin.method == "btest") + return LBF_BIT32_BTEST; + if (builtin.method == "extract") + return LBF_BIT32_EXTRACT; + if (builtin.method == "lrotate") + return LBF_BIT32_LROTATE; + if (builtin.method == "lshift") + return LBF_BIT32_LSHIFT; + if (builtin.method == "replace") + return LBF_BIT32_REPLACE; + if (builtin.method == "rrotate") + return LBF_BIT32_RROTATE; + if (builtin.method == "rshift") + return LBF_BIT32_RSHIFT; + if (builtin.method == "countlz") + return LBF_BIT32_COUNTLZ; + if (builtin.method == "countrz") + return LBF_BIT32_COUNTRZ; + } + + if (builtin.object == "string") + { + if (builtin.method == "byte") + return LBF_STRING_BYTE; + if (builtin.method == "char") + return LBF_STRING_CHAR; + if (builtin.method == "len") + return LBF_STRING_LEN; + if (builtin.method == "sub") + return LBF_STRING_SUB; + } + + if (builtin.object == "table") + { + if (builtin.method == "insert") + return LBF_TABLE_INSERT; + if (builtin.method == "unpack") + return LBF_TABLE_UNPACK; + } + + if (options.vectorCtor) + { + if (options.vectorLib) + { + if (builtin.isMethod(options.vectorLib, options.vectorCtor)) + return LBF_VECTOR; + } + else + { + if (builtin.isGlobal(options.vectorCtor)) + return LBF_VECTOR; + } + } + + return -1; +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h new file mode 100644 index 000000000..60df53a18 --- /dev/null +++ b/Compiler/src/Builtins.h @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "ValueTracking.h" + +namespace Luau +{ +struct CompileOptions; +} + +namespace Luau +{ +namespace Compile +{ + +struct Builtin +{ + AstName object; + AstName method; + + bool empty() const + { + return object == AstName() && method == AstName(); + } + + bool isGlobal(const char* name) const + { + return object == AstName() && method == name; + } + + bool isMethod(const char* table, const char* name) const + { + return object == table && method == name; + } +}; + +Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, const DenseHashMap& variables); +int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 2d31c409c..e6d024546 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -714,7 +714,7 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const // third pass: write resulting data int logspan = log2(span); - writeByte(ss, logspan); + writeByte(ss, uint8_t(logspan)); uint8_t lastOffset = 0; @@ -723,8 +723,8 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const int delta = lines[i] - baseline[i >> logspan]; LUAU_ASSERT(delta >= 0 && delta <= 255); - writeByte(ss, delta - lastOffset); - lastOffset = delta; + writeByte(ss, uint8_t(delta) - lastOffset); + lastOffset = uint8_t(delta); } int lastLine = 0; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 6ae490273..9758c4a9a 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -6,15 +6,20 @@ #include "Luau/Common.h" #include "Luau/TimeTrace.h" +#include "Builtins.h" +#include "ConstantFolding.h" +#include "TableShape.h" +#include "ValueTracking.h" + #include #include #include -LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) - namespace Luau { +using namespace Luau::Compile; + static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; @@ -62,7 +67,6 @@ static BytecodeBuilder::StringRef sref(AstArray data) struct Compiler { - struct Constant; struct RegScope; Compiler(BytecodeBuilder& bytecode, const CompileOptions& options) @@ -71,8 +75,9 @@ struct Compiler , functions(nullptr) , locals(nullptr) , globals(AstName()) + , variables(nullptr) , constants(nullptr) - , predictedTableSize(nullptr) + , tableShapes(nullptr) { } @@ -96,8 +101,10 @@ struct Compiler local->location, "Out of upvalue registers when trying to allocate %s: exceeded limit %d", local->name.value, kMaxUpvalueCount); // mark local as captured so that closeLocals emits LOP_CLOSEUPVALS accordingly - Local& l = locals[local]; - l.captured = true; + Variable* v = variables.find(local); + + if (v && v->written) + locals[local].captured = true; upvals.push_back(local); @@ -273,8 +280,8 @@ struct Compiler if (options.optimizationLevel >= 1) { - Builtin builtin = getBuiltin(expr->func); - bfid = getBuiltinFunctionId(builtin); + Builtin builtin = getBuiltin(expr->func, globals, variables); + bfid = getBuiltinFunctionId(builtin, options); } if (expr->self) @@ -364,12 +371,12 @@ struct Compiler else { args[i] = uint8_t(regs + 1 + i); - compileExprTempTop(expr->args.data[i], args[i]); + compileExprTempTop(expr->args.data[i], uint8_t(args[i])); } } fastcallLabel = bytecode.emitLabel(); - bytecode.emitABC(opc, uint8_t(bfid), args[0], 0); + bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); if (opc != LOP_FASTCALL1) bytecode.emitAux(args[1]); @@ -385,7 +392,7 @@ struct Compiler } if (args[i] != regs + 1 + i) - bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), args[i], 0); + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); } } else @@ -424,8 +431,10 @@ struct Compiler for (AstLocal* uv : f->upvals) { - Local* ul = locals.find(uv); - LUAU_ASSERT(ul); + Variable* ul = variables.find(uv); + + if (!ul) + return false; if (ul->written) return false; @@ -437,10 +446,11 @@ struct Compiler // this will only deoptimize (outside of fenv changes) if top level code is executed twice with different results. if (uv->functionDepth != 0 || uv->loopDepth != 0) { - if (!ul->func) + AstExprFunction* uf = ul->init ? ul->init->as() : nullptr; + if (!uf) return false; - if (ul->func != func && !shouldShareClosure(ul->func)) + if (uf != func && !shouldShareClosure(uf)) return false; } } @@ -471,7 +481,7 @@ struct Compiler if (cid >= 0 && cid < 32768) { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); + bytecode.emitAD(LOP_DUPCLOSURE, target, int16_t(cid)); shared = true; } } @@ -483,17 +493,15 @@ struct Compiler { LUAU_ASSERT(uv->functionDepth < expr->functionDepth); - Local* ul = locals.find(uv); - LUAU_ASSERT(ul); - - bool immutable = !ul->written; + Variable* ul = variables.find(uv); + bool immutable = !ul || !ul->written; if (uv->functionDepth == expr->functionDepth - 1) { // get local variable uint8_t reg = getLocal(uv); - bytecode.emitABC(LOP_CAPTURE, immutable ? LCT_VAL : LCT_REF, reg, 0); + bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), reg, 0); } else { @@ -635,7 +643,7 @@ struct Compiler if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe) { - bytecode.emitAD(opc, rr, 0); + bytecode.emitAD(opc, uint8_t(rr), 0); bytecode.emitAux(rl); } else @@ -687,7 +695,7 @@ struct Compiler break; case Constant::Type_String: - cid = bytecode.addConstantString(sref(c->valueString)); + cid = bytecode.addConstantString(sref(c->getString())); break; default: @@ -1066,10 +1074,10 @@ struct Compiler // Optimization: if the table is empty, we can compute it directly into the target if (expr->items.size == 0) { - auto [hashSize, arraySize] = predictedTableSize[expr]; + TableShape shape = tableShapes[expr]; - bytecode.emitABC(LOP_NEWTABLE, target, encodeHashSize(hashSize), 0); - bytecode.emitAux(arraySize); + bytecode.emitABC(LOP_NEWTABLE, target, encodeHashSize(shape.hashSize), 0); + bytecode.emitAux(shape.arraySize); return; } @@ -1144,7 +1152,7 @@ struct Compiler } else { - bytecode.emitABC(LOP_NEWTABLE, reg, encodedHashSize, 0); + bytecode.emitABC(LOP_NEWTABLE, reg, uint8_t(encodedHashSize), 0); bytecode.emitAux(0); } } @@ -1157,7 +1165,7 @@ struct Compiler bool trailingVarargs = last && last->kind == AstExprTable::Item::List && last->value->is(); LUAU_ASSERT(!trailingVarargs || arraySize > 0); - bytecode.emitABC(LOP_NEWTABLE, reg, encodedHashSize, 0); + bytecode.emitABC(LOP_NEWTABLE, reg, uint8_t(encodedHashSize), 0); bytecode.emitAux(arraySize - trailingVarargs + indexSize); } @@ -1252,16 +1260,12 @@ struct Compiler bool canImport(AstExprGlobal* expr) { - const Global* global = globals.find(expr->name); - - return options.optimizationLevel >= 1 && (!global || !global->written); + return options.optimizationLevel >= 1 && getGlobalState(globals, expr->name) != Global::Written; } bool canImportChain(AstExprGlobal* expr) { - const Global* global = globals.find(expr->name); - - return options.optimizationLevel >= 1 && (!global || (!global->written && !global->writable)); + return options.optimizationLevel >= 1 && getGlobalState(globals, expr->name) == Global::Default; } void compileExprIndexName(AstExprIndexName* expr, uint8_t target) @@ -1341,7 +1345,7 @@ struct Compiler { uint8_t rt = compileExprAuto(expr->expr, rs); - BytecodeBuilder::StringRef iname = sref(cv->valueString); + BytecodeBuilder::StringRef iname = sref(cv->getString()); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); @@ -1427,7 +1431,7 @@ struct Compiler case Constant::Type_String: { - int32_t cid = bytecode.addConstantString(sref(cv->valueString)); + int32_t cid = bytecode.addConstantString(sref(cv->getString())); if (cid < 0) CompileError::raise(node->location, "Exceeded constant limit; simplify the code to compile"); @@ -1546,7 +1550,7 @@ struct Compiler { compileExpr(expr->expr, target, targetTemp); } - else if (AstExprIfElse* expr = node->as(); FFlag::LuauIfElseExpressionBaseSupport && expr) + else if (AstExprIfElse* expr = node->as()) { compileExprIfElse(expr, target, targetTemp); } @@ -1711,7 +1715,7 @@ struct Compiler { LValue result = {LValue::Kind_IndexName}; result.reg = compileExprAuto(expr->expr, rs); - result.name = sref(cv->valueString); + result.name = sref(cv->getString()); result.location = node->location; return result; @@ -1796,9 +1800,8 @@ struct Compiler return false; Local* l = locals.find(le->local); - LUAU_ASSERT(l); - return l->allocated; + return l && l->allocated; } bool isStatBreak(AstStat* node) @@ -2040,9 +2043,9 @@ struct Compiler for (AstLocal* local : stat->vars) { - Local* l = locals.find(local); + Variable* v = variables.find(local); - if (!l || l->constant.type == Constant::Type_Unknown) + if (!v || !v->constant) return false; } @@ -2082,9 +2085,7 @@ struct Compiler // through) uint8_t varreg = regs + 2; - Local* il = locals.find(stat->var); - - if (il && il->written) + if (Variable* il = variables.find(stat->var); il && il->written) varreg = allocReg(stat, 1); compileExprTemp(stat->from, uint8_t(regs + 2)); @@ -2164,7 +2165,7 @@ struct Compiler { if (stat->values.size == 1 && stat->values.data[0]->is()) { - Builtin builtin = getBuiltin(stat->values.data[0]->as()->func); + Builtin builtin = getBuiltin(stat->values.data[0]->as()->func, globals, variables); if (builtin.isGlobal("ipairs")) // for .. in ipairs(t) { @@ -2179,7 +2180,7 @@ struct Compiler } else if (stat->values.size == 2) { - Builtin builtin = getBuiltin(stat->values.data[0]); + Builtin builtin = getBuiltin(stat->values.data[0], globals, variables); if (builtin.isGlobal("next")) // for .. in next,t { @@ -2594,7 +2595,7 @@ struct Compiler Local* l = locals.find(localStack[i]); LUAU_ASSERT(l); - if (l->captured && l->written) + if (l->captured) return true; } @@ -2613,7 +2614,7 @@ struct Compiler Local* l = locals.find(localStack[i]); LUAU_ASSERT(l); - if (l->captured && l->written) + if (l->captured) { captured = true; captureReg = std::min(captureReg, l->reg); @@ -2728,519 +2729,6 @@ struct Compiler return !node->is() && !node->is(); } - struct AssignmentVisitor : AstVisitor - { - struct Hasher - { - size_t operator()(const std::pair& p) const - { - return std::hash()(p.first) ^ std::hash()(p.second); - } - }; - - DenseHashMap localToTable; - DenseHashSet, Hasher> fields; - - AssignmentVisitor(Compiler* self) - : localToTable(nullptr) - , fields(std::pair()) - , self(self) - { - } - - void assignField(AstExpr* expr, AstName index) - { - if (AstExprLocal* lv = expr->as()) - { - if (AstExprTable** table = localToTable.find(lv->local)) - { - std::pair field = {*table, index}; - - if (!fields.contains(field)) - { - fields.insert(field); - self->predictedTableSize[*table].first += 1; - } - } - } - } - - void assignField(AstExpr* expr, AstExpr* index) - { - AstExprLocal* lv = expr->as(); - AstExprConstantNumber* number = index->as(); - - if (lv && number) - { - if (AstExprTable** table = localToTable.find(lv->local)) - { - unsigned int& arraySize = self->predictedTableSize[*table].second; - - if (number->value == double(arraySize + 1)) - arraySize += 1; - } - } - } - - void assign(AstExpr* var) - { - if (AstExprLocal* lv = var->as()) - { - self->locals[lv->local].written = true; - } - else if (AstExprGlobal* gv = var->as()) - { - self->globals[gv->name].written = true; - } - else if (AstExprIndexName* index = var->as()) - { - assignField(index->expr, index->index); - - var->visit(this); - } - else if (AstExprIndexExpr* index = var->as()) - { - assignField(index->expr, index->index); - - var->visit(this); - } - else - { - // we need to be able to track assignments in all expressions, including crazy ones like t[function() t = nil end] = 5 - var->visit(this); - } - } - - AstExprTable* getTableHint(AstExpr* expr) - { - // unadorned table literal - if (AstExprTable* table = expr->as()) - return table; - - // setmetatable(table literal, ...) - if (AstExprCall* call = expr->as(); call && !call->self && call->args.size == 2) - if (AstExprGlobal* func = call->func->as(); func && func->name == "setmetatable") - if (AstExprTable* table = call->args.data[0]->as()) - return table; - - return nullptr; - } - - bool visit(AstStatLocal* node) override - { - // track local -> table association so that we can update table size prediction in assignField - if (node->vars.size == 1 && node->values.size == 1) - if (AstExprTable* table = getTableHint(node->values.data[0]); table && table->items.size == 0) - localToTable[node->vars.data[0]] = table; - - return true; - } - - bool visit(AstStatAssign* node) override - { - for (size_t i = 0; i < node->vars.size; ++i) - assign(node->vars.data[i]); - - for (size_t i = 0; i < node->values.size; ++i) - node->values.data[i]->visit(this); - - return false; - } - - bool visit(AstStatCompoundAssign* node) override - { - assign(node->var); - node->value->visit(this); - - return false; - } - - bool visit(AstStatFunction* node) override - { - assign(node->name); - node->func->visit(this); - - return false; - } - - Compiler* self; - }; - - struct ConstantVisitor : AstVisitor - { - ConstantVisitor(Compiler* self) - : self(self) - { - } - - void analyzeUnary(Constant& result, AstExprUnary::Op op, const Constant& arg) - { - switch (op) - { - case AstExprUnary::Not: - if (arg.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = !arg.isTruthful(); - } - break; - - case AstExprUnary::Minus: - if (arg.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = -arg.valueNumber; - } - break; - - case AstExprUnary::Len: - if (arg.type == Constant::Type_String) - { - result.type = Constant::Type_Number; - result.valueNumber = double(arg.valueString.size); - } - break; - - default: - LUAU_ASSERT(!"Unexpected unary operation"); - } - } - - bool constantsEqual(const Constant& la, const Constant& ra) - { - LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown); - - switch (la.type) - { - case Constant::Type_Nil: - return ra.type == Constant::Type_Nil; - - case Constant::Type_Boolean: - return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean; - - case Constant::Type_Number: - return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber; - - case Constant::Type_String: - return ra.type == Constant::Type_String && la.valueString.size == ra.valueString.size && - memcmp(la.valueString.data, ra.valueString.data, la.valueString.size) == 0; - - default: - LUAU_ASSERT(!"Unexpected constant type in comparison"); - return false; - } - } - - void analyzeBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra) - { - switch (op) - { - case AstExprBinary::Add: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber + ra.valueNumber; - } - break; - - case AstExprBinary::Sub: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber - ra.valueNumber; - } - break; - - case AstExprBinary::Mul: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber * ra.valueNumber; - } - break; - - case AstExprBinary::Div: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber / ra.valueNumber; - } - break; - - case AstExprBinary::Mod: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber; - } - break; - - case AstExprBinary::Pow: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = pow(la.valueNumber, ra.valueNumber); - } - break; - - case AstExprBinary::Concat: - break; - - case AstExprBinary::CompareNe: - if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = !constantsEqual(la, ra); - } - break; - - case AstExprBinary::CompareEq: - if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = constantsEqual(la, ra); - } - break; - - case AstExprBinary::CompareLt: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber < ra.valueNumber; - } - break; - - case AstExprBinary::CompareLe: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber <= ra.valueNumber; - } - break; - - case AstExprBinary::CompareGt: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber > ra.valueNumber; - } - break; - - case AstExprBinary::CompareGe: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber >= ra.valueNumber; - } - break; - - case AstExprBinary::And: - if (la.type != Constant::Type_Unknown) - { - result = la.isTruthful() ? ra : la; - } - break; - - case AstExprBinary::Or: - if (la.type != Constant::Type_Unknown) - { - result = la.isTruthful() ? la : ra; - } - break; - - default: - LUAU_ASSERT(!"Unexpected binary operation"); - } - } - - Constant analyze(AstExpr* node) - { - Constant result; - result.type = Constant::Type_Unknown; - - if (AstExprGroup* expr = node->as()) - { - result = analyze(expr->expr); - } - else if (node->is()) - { - result.type = Constant::Type_Nil; - } - else if (AstExprConstantBool* expr = node->as()) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = expr->value; - } - else if (AstExprConstantNumber* expr = node->as()) - { - result.type = Constant::Type_Number; - result.valueNumber = expr->value; - } - else if (AstExprConstantString* expr = node->as()) - { - result.type = Constant::Type_String; - result.valueString = expr->value; - } - else if (AstExprLocal* expr = node->as()) - { - const Local* l = self->locals.find(expr->local); - - if (l && l->constant.type != Constant::Type_Unknown) - { - LUAU_ASSERT(!l->written); - result = l->constant; - } - } - else if (node->is()) - { - // nope - } - else if (node->is()) - { - // nope - } - else if (AstExprCall* expr = node->as()) - { - analyze(expr->func); - - for (size_t i = 0; i < expr->args.size; ++i) - analyze(expr->args.data[i]); - } - else if (AstExprIndexName* expr = node->as()) - { - analyze(expr->expr); - } - else if (AstExprIndexExpr* expr = node->as()) - { - analyze(expr->expr); - analyze(expr->index); - } - else if (AstExprFunction* expr = node->as()) - { - // this is necessary to propagate constant information in all child functions - expr->body->visit(this); - } - else if (AstExprTable* expr = node->as()) - { - for (size_t i = 0; i < expr->items.size; ++i) - { - const AstExprTable::Item& item = expr->items.data[i]; - - if (item.key) - analyze(item.key); - - analyze(item.value); - } - } - else if (AstExprUnary* expr = node->as()) - { - Constant arg = analyze(expr->expr); - - analyzeUnary(result, expr->op, arg); - } - else if (AstExprBinary* expr = node->as()) - { - Constant la = analyze(expr->left); - Constant ra = analyze(expr->right); - - analyzeBinary(result, expr->op, la, ra); - } - else if (AstExprTypeAssertion* expr = node->as()) - { - Constant arg = analyze(expr->expr); - - result = arg; - } - else if (AstExprIfElse* expr = node->as(); FFlag::LuauIfElseExpressionBaseSupport && expr) - { - Constant cond = analyze(expr->condition); - Constant trueExpr = analyze(expr->trueExpr); - Constant falseExpr = analyze(expr->falseExpr); - if (cond.type != Constant::Type_Unknown) - { - result = cond.isTruthful() ? trueExpr : falseExpr; - } - } - else - { - LUAU_ASSERT(!"Unknown expression type"); - } - - if (result.type != Constant::Type_Unknown) - self->constants[node] = result; - - return result; - } - - bool visit(AstExpr* node) override - { - // note: we short-circuit the visitor traversal through any expression trees by returning false - // recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression - analyze(node); - - return false; - } - - bool visit(AstStatLocal* node) override - { - // for values that match 1-1 we record the initializing expression for future analysis - for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) - { - Local& l = self->locals[node->vars.data[i]]; - - l.init = node->values.data[i]; - } - - // all values that align wrt indexing are simple - we just match them 1-1 - for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) - { - Constant arg = analyze(node->values.data[i]); - - if (arg.type != Constant::Type_Unknown) - { - Local& l = self->locals[node->vars.data[i]]; - - // note: we rely on AssignmentVisitor to have been run before us - if (!l.written) - l.constant = arg; - } - } - - if (node->vars.size > node->values.size) - { - // if we have trailing variables, then depending on whether the last value is capable of returning multiple values - // (aka call or varargs), we either don't know anything about these vars, or we know they're nil - AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr; - bool multRet = last && (last->is() || last->is()); - - for (size_t i = node->values.size; i < node->vars.size; ++i) - { - if (!multRet) - { - Local& l = self->locals[node->vars.data[i]]; - - // note: we rely on AssignmentVisitor to have been run before us - if (!l.written) - { - l.constant.type = Constant::Type_Nil; - } - } - } - } - else - { - // we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside - // them - for (size_t i = node->vars.size; i < node->values.size; ++i) - analyze(node->values.data[i]); - } - - return false; - } - - Compiler* self; - }; - struct FenvVisitor : AstVisitor { bool& getfenvUsed; @@ -3283,14 +2771,6 @@ struct Compiler return false; } - - bool visit(AstStatLocalFunction* node) override - { - // record local->function association for some optimizations - self->locals[node->name].func = node->func; - - return true; - } }; struct UndefinedLocalVisitor : AstVisitor @@ -3397,70 +2877,12 @@ struct Compiler std::vector upvals; }; - struct Constant - { - enum Type - { - Type_Unknown, - Type_Nil, - Type_Boolean, - Type_Number, - Type_String, - }; - - Type type = Type_Unknown; - - union - { - bool valueBoolean; - double valueNumber; - AstArray valueString = {}; - }; - - bool isTruthful() const - { - LUAU_ASSERT(type != Type_Unknown); - return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); - } - }; - struct Local { uint8_t reg = 0; bool allocated = false; bool captured = false; - bool written = false; - AstExpr* init = nullptr; uint32_t debugpc = 0; - Constant constant; - AstExprFunction* func = nullptr; - }; - - struct Global - { - bool writable = false; - bool written = false; - }; - - struct Builtin - { - AstName object; - AstName method; - - bool empty() const - { - return object == AstName() && method == AstName(); - } - - bool isGlobal(const char* name) const - { - return object == AstName() && method == name; - } - - bool isMethod(const char* table, const char* name) const - { - return object == table && method == name; - } }; struct LoopJump @@ -3482,194 +2904,6 @@ struct Compiler AstExpr* untilCondition; }; - Builtin getBuiltin(AstExpr* node) - { - if (AstExprLocal* expr = node->as()) - { - Local* l = locals.find(expr->local); - - return l && !l->written && l->init ? getBuiltin(l->init) : Builtin(); - } - else if (AstExprIndexName* expr = node->as()) - { - if (AstExprGlobal* object = expr->expr->as()) - { - Global* g = globals.find(object->name); - - return !g || (!g->writable && !g->written) ? Builtin{object->name, expr->index} : Builtin(); - } - else - { - return Builtin(); - } - } - else if (AstExprGlobal* expr = node->as()) - { - Global* g = globals.find(expr->name); - - return !g || !g->written ? Builtin{AstName(), expr->name} : Builtin(); - } - else - { - return Builtin(); - } - } - - int getBuiltinFunctionId(const Builtin& builtin) - { - if (builtin.empty()) - return -1; - - if (builtin.isGlobal("assert")) - return LBF_ASSERT; - - if (builtin.isGlobal("type")) - return LBF_TYPE; - - if (builtin.isGlobal("typeof")) - return LBF_TYPEOF; - - if (builtin.isGlobal("rawset")) - return LBF_RAWSET; - if (builtin.isGlobal("rawget")) - return LBF_RAWGET; - if (builtin.isGlobal("rawequal")) - return LBF_RAWEQUAL; - - if (builtin.isGlobal("unpack")) - return LBF_TABLE_UNPACK; - - if (builtin.object == "math") - { - if (builtin.method == "abs") - return LBF_MATH_ABS; - if (builtin.method == "acos") - return LBF_MATH_ACOS; - if (builtin.method == "asin") - return LBF_MATH_ASIN; - if (builtin.method == "atan2") - return LBF_MATH_ATAN2; - if (builtin.method == "atan") - return LBF_MATH_ATAN; - if (builtin.method == "ceil") - return LBF_MATH_CEIL; - if (builtin.method == "cosh") - return LBF_MATH_COSH; - if (builtin.method == "cos") - return LBF_MATH_COS; - if (builtin.method == "deg") - return LBF_MATH_DEG; - if (builtin.method == "exp") - return LBF_MATH_EXP; - if (builtin.method == "floor") - return LBF_MATH_FLOOR; - if (builtin.method == "fmod") - return LBF_MATH_FMOD; - if (builtin.method == "frexp") - return LBF_MATH_FREXP; - if (builtin.method == "ldexp") - return LBF_MATH_LDEXP; - if (builtin.method == "log10") - return LBF_MATH_LOG10; - if (builtin.method == "log") - return LBF_MATH_LOG; - if (builtin.method == "max") - return LBF_MATH_MAX; - if (builtin.method == "min") - return LBF_MATH_MIN; - if (builtin.method == "modf") - return LBF_MATH_MODF; - if (builtin.method == "pow") - return LBF_MATH_POW; - if (builtin.method == "rad") - return LBF_MATH_RAD; - if (builtin.method == "sinh") - return LBF_MATH_SINH; - if (builtin.method == "sin") - return LBF_MATH_SIN; - if (builtin.method == "sqrt") - return LBF_MATH_SQRT; - if (builtin.method == "tanh") - return LBF_MATH_TANH; - if (builtin.method == "tan") - return LBF_MATH_TAN; - if (builtin.method == "clamp") - return LBF_MATH_CLAMP; - if (builtin.method == "sign") - return LBF_MATH_SIGN; - if (builtin.method == "round") - return LBF_MATH_ROUND; - } - - if (builtin.object == "bit32") - { - if (builtin.method == "arshift") - return LBF_BIT32_ARSHIFT; - if (builtin.method == "band") - return LBF_BIT32_BAND; - if (builtin.method == "bnot") - return LBF_BIT32_BNOT; - if (builtin.method == "bor") - return LBF_BIT32_BOR; - if (builtin.method == "bxor") - return LBF_BIT32_BXOR; - if (builtin.method == "btest") - return LBF_BIT32_BTEST; - if (builtin.method == "extract") - return LBF_BIT32_EXTRACT; - if (builtin.method == "lrotate") - return LBF_BIT32_LROTATE; - if (builtin.method == "lshift") - return LBF_BIT32_LSHIFT; - if (builtin.method == "replace") - return LBF_BIT32_REPLACE; - if (builtin.method == "rrotate") - return LBF_BIT32_RROTATE; - if (builtin.method == "rshift") - return LBF_BIT32_RSHIFT; - if (builtin.method == "countlz") - return LBF_BIT32_COUNTLZ; - if (builtin.method == "countrz") - return LBF_BIT32_COUNTRZ; - } - - if (builtin.object == "string") - { - if (builtin.method == "byte") - return LBF_STRING_BYTE; - if (builtin.method == "char") - return LBF_STRING_CHAR; - if (builtin.method == "len") - return LBF_STRING_LEN; - if (builtin.method == "sub") - return LBF_STRING_SUB; - } - - if (builtin.object == "table") - { - if (builtin.method == "insert") - return LBF_TABLE_INSERT; - if (builtin.method == "unpack") - return LBF_TABLE_UNPACK; - } - - if (options.vectorCtor) - { - if (options.vectorLib) - { - if (builtin.isMethod(options.vectorLib, options.vectorCtor)) - return LBF_VECTOR; - } - else - { - if (builtin.isGlobal(options.vectorCtor)) - return LBF_VECTOR; - } - } - - return -1; - } - BytecodeBuilder& bytecode; CompileOptions options; @@ -3677,8 +2911,9 @@ struct Compiler DenseHashMap functions; DenseHashMap locals; DenseHashMap globals; + DenseHashMap variables; DenseHashMap constants; - DenseHashMap> predictedTableSize; + DenseHashMap tableShapes; unsigned int regTop = 0; unsigned int stackSize = 0; @@ -3699,23 +2934,18 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables - if (AstName name = names.get("_G"); name.value) - compiler.globals[name].writable = true; + assignMutable(compiler.globals, names, options.mutableGlobals); - if (options.mutableGlobals) - for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) - if (AstName name = names.get(*ptr); name.value) - compiler.globals[name].writable = true; + // this pass analyzes mutability of locals/globals and associates locals with their initial values + trackValues(compiler.globals, compiler.variables, root); - // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written - Compiler::AssignmentVisitor assignmentVisitor(&compiler); - root->visit(&assignmentVisitor); - - // this visitor traverses the AST to analyze constantness of expressions, filling constants[] and Local::constant/Local::init if (options.optimizationLevel >= 1) { - Compiler::ConstantVisitor constantVisitor(&compiler); - root->visit(&constantVisitor); + // this pass analyzes constantness of expressions + foldConstants(compiler.constants, compiler.variables, root); + + // this pass analyzes table assignments to estimate table shapes for initially empty tables + predictTableShapes(compiler.tableShapes, root); } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found @@ -3734,8 +2964,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName for (AstExprFunction* expr : functions) compiler.compileFunction(expr); - AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), /* self= */ nullptr, - AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); + AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), + /* self= */ nullptr, AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main); bytecode.setMainFunction(mainid); diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp new file mode 100644 index 000000000..60a7c1692 --- /dev/null +++ b/Compiler/src/ConstantFolding.cpp @@ -0,0 +1,394 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ConstantFolding.h" + +#include + +namespace Luau +{ +namespace Compile +{ + +static bool constantsEqual(const Constant& la, const Constant& ra) +{ + LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown); + + switch (la.type) + { + case Constant::Type_Nil: + return ra.type == Constant::Type_Nil; + + case Constant::Type_Boolean: + return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean; + + case Constant::Type_Number: + return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber; + + case Constant::Type_String: + return ra.type == Constant::Type_String && la.stringLength == ra.stringLength && memcmp(la.valueString, ra.valueString, la.stringLength) == 0; + + default: + LUAU_ASSERT(!"Unexpected constant type in comparison"); + return false; + } +} + +static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg) +{ + switch (op) + { + case AstExprUnary::Not: + if (arg.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = !arg.isTruthful(); + } + break; + + case AstExprUnary::Minus: + if (arg.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = -arg.valueNumber; + } + break; + + case AstExprUnary::Len: + if (arg.type == Constant::Type_String) + { + result.type = Constant::Type_Number; + result.valueNumber = double(arg.stringLength); + } + break; + + default: + LUAU_ASSERT(!"Unexpected unary operation"); + } +} + +static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra) +{ + switch (op) + { + case AstExprBinary::Add: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber + ra.valueNumber; + } + break; + + case AstExprBinary::Sub: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber - ra.valueNumber; + } + break; + + case AstExprBinary::Mul: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber * ra.valueNumber; + } + break; + + case AstExprBinary::Div: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber / ra.valueNumber; + } + break; + + case AstExprBinary::Mod: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber; + } + break; + + case AstExprBinary::Pow: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = pow(la.valueNumber, ra.valueNumber); + } + break; + + case AstExprBinary::Concat: + break; + + case AstExprBinary::CompareNe: + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = !constantsEqual(la, ra); + } + break; + + case AstExprBinary::CompareEq: + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = constantsEqual(la, ra); + } + break; + + case AstExprBinary::CompareLt: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber < ra.valueNumber; + } + break; + + case AstExprBinary::CompareLe: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber <= ra.valueNumber; + } + break; + + case AstExprBinary::CompareGt: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber > ra.valueNumber; + } + break; + + case AstExprBinary::CompareGe: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber >= ra.valueNumber; + } + break; + + case AstExprBinary::And: + if (la.type != Constant::Type_Unknown) + { + result = la.isTruthful() ? ra : la; + } + break; + + case AstExprBinary::Or: + if (la.type != Constant::Type_Unknown) + { + result = la.isTruthful() ? la : ra; + } + break; + + default: + LUAU_ASSERT(!"Unexpected binary operation"); + } +} + +struct ConstantVisitor : AstVisitor +{ + DenseHashMap& constants; + DenseHashMap& variables; + + DenseHashMap locals; + + ConstantVisitor(DenseHashMap& constants, DenseHashMap& variables) + : constants(constants) + , variables(variables) + , locals(nullptr) + { + } + + Constant analyze(AstExpr* node) + { + Constant result; + result.type = Constant::Type_Unknown; + + if (AstExprGroup* expr = node->as()) + { + result = analyze(expr->expr); + } + else if (node->is()) + { + result.type = Constant::Type_Nil; + } + else if (AstExprConstantBool* expr = node->as()) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = expr->value; + } + else if (AstExprConstantNumber* expr = node->as()) + { + result.type = Constant::Type_Number; + result.valueNumber = expr->value; + } + else if (AstExprConstantString* expr = node->as()) + { + result.type = Constant::Type_String; + result.valueString = expr->value.data; + result.stringLength = unsigned(expr->value.size); + } + else if (AstExprLocal* expr = node->as()) + { + const Constant* l = locals.find(expr->local); + + if (l) + result = *l; + } + else if (node->is()) + { + // nope + } + else if (node->is()) + { + // nope + } + else if (AstExprCall* expr = node->as()) + { + analyze(expr->func); + + for (size_t i = 0; i < expr->args.size; ++i) + analyze(expr->args.data[i]); + } + else if (AstExprIndexName* expr = node->as()) + { + analyze(expr->expr); + } + else if (AstExprIndexExpr* expr = node->as()) + { + analyze(expr->expr); + analyze(expr->index); + } + else if (AstExprFunction* expr = node->as()) + { + // this is necessary to propagate constant information in all child functions + expr->body->visit(this); + } + else if (AstExprTable* expr = node->as()) + { + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + + if (item.key) + analyze(item.key); + + analyze(item.value); + } + } + else if (AstExprUnary* expr = node->as()) + { + Constant arg = analyze(expr->expr); + + if (arg.type != Constant::Type_Unknown) + foldUnary(result, expr->op, arg); + } + else if (AstExprBinary* expr = node->as()) + { + Constant la = analyze(expr->left); + Constant ra = analyze(expr->right); + + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + foldBinary(result, expr->op, la, ra); + } + else if (AstExprTypeAssertion* expr = node->as()) + { + Constant arg = analyze(expr->expr); + + result = arg; + } + else if (AstExprIfElse* expr = node->as()) + { + Constant cond = analyze(expr->condition); + Constant trueExpr = analyze(expr->trueExpr); + Constant falseExpr = analyze(expr->falseExpr); + + if (cond.type != Constant::Type_Unknown) + result = cond.isTruthful() ? trueExpr : falseExpr; + } + else + { + LUAU_ASSERT(!"Unknown expression type"); + } + + if (result.type != Constant::Type_Unknown) + constants[node] = result; + + return result; + } + + bool visit(AstExpr* node) override + { + // note: we short-circuit the visitor traversal through any expression trees by returning false + // recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression + analyze(node); + + return false; + } + + bool visit(AstStatLocal* node) override + { + // all values that align wrt indexing are simple - we just match them 1-1 + for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) + { + Constant arg = analyze(node->values.data[i]); + + if (arg.type != Constant::Type_Unknown) + { + // note: we rely on trackValues to have been run before us + Variable* v = variables.find(node->vars.data[i]); + LUAU_ASSERT(v); + + if (!v->written) + { + locals[node->vars.data[i]] = arg; + v->constant = true; + } + } + } + + if (node->vars.size > node->values.size) + { + // if we have trailing variables, then depending on whether the last value is capable of returning multiple values + // (aka call or varargs), we either don't know anything about these vars, or we know they're nil + AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr; + bool multRet = last && (last->is() || last->is()); + + if (!multRet) + { + for (size_t i = node->values.size; i < node->vars.size; ++i) + { + // note: we rely on trackValues to have been run before us + Variable* v = variables.find(node->vars.data[i]); + LUAU_ASSERT(v); + + if (!v->written) + { + locals[node->vars.data[i]].type = Constant::Type_Nil; + v->constant = true; + } + } + } + } + else + { + // we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside + // them + for (size_t i = node->vars.size; i < node->values.size; ++i) + analyze(node->values.data[i]); + } + + return false; + } +}; + +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root) +{ + ConstantVisitor visitor{constants, variables}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h new file mode 100644 index 000000000..c0e63539b --- /dev/null +++ b/Compiler/src/ConstantFolding.h @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "ValueTracking.h" + +namespace Luau +{ +namespace Compile +{ + +struct Constant +{ + enum Type + { + Type_Unknown, + Type_Nil, + Type_Boolean, + Type_Number, + Type_String, + }; + + Type type = Type_Unknown; + unsigned int stringLength = 0; + + union + { + bool valueBoolean; + double valueNumber; + char* valueString = nullptr; // length stored in stringLength + }; + + bool isTruthful() const + { + LUAU_ASSERT(type != Type_Unknown); + return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); + } + + AstArray getString() const + { + LUAU_ASSERT(type == Type_String); + return {valueString, stringLength}; + } +}; + +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/TableShape.cpp b/Compiler/src/TableShape.cpp new file mode 100644 index 000000000..7d99f2228 --- /dev/null +++ b/Compiler/src/TableShape.cpp @@ -0,0 +1,129 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "TableShape.h" + +namespace Luau +{ +namespace Compile +{ + +static AstExprTable* getTableHint(AstExpr* expr) +{ + // unadorned table literal + if (AstExprTable* table = expr->as()) + return table; + + // setmetatable(table literal, ...) + if (AstExprCall* call = expr->as(); call && !call->self && call->args.size == 2) + if (AstExprGlobal* func = call->func->as(); func && func->name == "setmetatable") + if (AstExprTable* table = call->args.data[0]->as()) + return table; + + return nullptr; +} + +struct ShapeVisitor : AstVisitor +{ + struct Hasher + { + size_t operator()(const std::pair& p) const + { + return std::hash()(p.first) ^ std::hash()(p.second); + } + }; + + DenseHashMap& shapes; + + DenseHashMap tables; + DenseHashSet, Hasher> fields; + + ShapeVisitor(DenseHashMap& shapes) + : shapes(shapes) + , tables(nullptr) + , fields(std::pair()) + { + } + + void assignField(AstExpr* expr, AstName index) + { + if (AstExprLocal* lv = expr->as()) + { + if (AstExprTable** table = tables.find(lv->local)) + { + std::pair field = {*table, index}; + + if (!fields.contains(field)) + { + fields.insert(field); + shapes[*table].hashSize += 1; + } + } + } + } + + void assignField(AstExpr* expr, AstExpr* index) + { + AstExprLocal* lv = expr->as(); + AstExprConstantNumber* number = index->as(); + + if (lv && number) + { + if (AstExprTable** table = tables.find(lv->local)) + { + TableShape& shape = shapes[*table]; + + if (number->value == double(shape.arraySize + 1)) + shape.arraySize += 1; + } + } + } + + void assign(AstExpr* var) + { + if (AstExprIndexName* index = var->as()) + { + assignField(index->expr, index->index); + } + else if (AstExprIndexExpr* index = var->as()) + { + assignField(index->expr, index->index); + } + } + + bool visit(AstStatLocal* node) override + { + // track local -> table association so that we can update table size prediction in assignField + if (node->vars.size == 1 && node->values.size == 1) + if (AstExprTable* table = getTableHint(node->values.data[0]); table && table->items.size == 0) + tables[node->vars.data[0]] = table; + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatFunction* node) override + { + assign(node->name); + node->func->visit(this); + + return false; + } +}; + +void predictTableShapes(DenseHashMap& shapes, AstNode* root) +{ + ShapeVisitor visitor{shapes}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/TableShape.h b/Compiler/src/TableShape.h new file mode 100644 index 000000000..f30853a7f --- /dev/null +++ b/Compiler/src/TableShape.h @@ -0,0 +1,21 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +namespace Compile +{ + +struct TableShape +{ + unsigned int arraySize = 0; + unsigned int hashSize = 0; +}; + +void predictTableShapes(DenseHashMap& shapes, AstNode* root); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ValueTracking.cpp b/Compiler/src/ValueTracking.cpp new file mode 100644 index 000000000..0bfaf9b3c --- /dev/null +++ b/Compiler/src/ValueTracking.cpp @@ -0,0 +1,103 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ValueTracking.h" + +#include "Luau/Lexer.h" + +namespace Luau +{ +namespace Compile +{ + +struct ValueVisitor : AstVisitor +{ + DenseHashMap& globals; + DenseHashMap& variables; + + ValueVisitor(DenseHashMap& globals, DenseHashMap& variables) + : globals(globals) + , variables(variables) + { + } + + void assign(AstExpr* var) + { + if (AstExprLocal* lv = var->as()) + { + variables[lv->local].written = true; + } + else if (AstExprGlobal* gv = var->as()) + { + globals[gv->name] = Global::Written; + } + else + { + // we need to be able to track assignments in all expressions, including crazy ones like t[function() t = nil end] = 5 + var->visit(this); + } + } + + bool visit(AstStatLocal* node) override + { + for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) + variables[node->vars.data[i]].init = node->values.data[i]; + + for (size_t i = node->values.size; i < node->vars.size; ++i) + variables[node->vars.data[i]].init = nullptr; + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatCompoundAssign* node) override + { + assign(node->var); + node->value->visit(this); + + return false; + } + + bool visit(AstStatLocalFunction* node) override + { + variables[node->name].init = node->func; + + return true; + } + + bool visit(AstStatFunction* node) override + { + assign(node->name); + node->func->visit(this); + + return false; + } +}; + +void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char** mutableGlobals) +{ + if (AstName name = names.get("_G"); name.value) + globals[name] = Global::Mutable; + + if (mutableGlobals) + for (const char** ptr = mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) + globals[name] = Global::Mutable; +} + +void trackValues(DenseHashMap& globals, DenseHashMap& variables, AstNode* root) +{ + ValueVisitor visitor{globals, variables}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ValueTracking.h b/Compiler/src/ValueTracking.h new file mode 100644 index 000000000..fc74c84aa --- /dev/null +++ b/Compiler/src/ValueTracking.h @@ -0,0 +1,42 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +class AstNameTable; +} + +namespace Luau +{ +namespace Compile +{ + +enum class Global +{ + Default = 0, + Mutable, // builtin that has contents unknown at compile time, blocks GETIMPORT for chains + Written, // written in the code which means we can't reason about the value +}; + +struct Variable +{ + AstExpr* init = nullptr; // initial value of the variable; filled by trackValues + bool written = false; // is the variable ever assigned to? filled by trackValues + bool constant = false; // is the variable's value a compile-time constant? filled by constantFold +}; + +void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char** mutableGlobals); +void trackValues(DenseHashMap& globals, DenseHashMap& variables, AstNode* root); + +inline Global getGlobalState(const DenseHashMap& globals, AstName name) +{ + const Global* it = globals.find(name); + + return it ? *it : Global::Default; +} + +} // namespace Compile +} // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index 5dd486aaa..bafe75948 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -29,7 +29,15 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/BytecodeBuilder.cpp Compiler/src/Compiler.cpp + Compiler/src/Builtins.cpp + Compiler/src/ConstantFolding.cpp + Compiler/src/TableShape.cpp + Compiler/src/ValueTracking.cpp Compiler/src/lcode.cpp + Compiler/src/Builtins.h + Compiler/src/ConstantFolding.h + Compiler/src/TableShape.h + Compiler/src/ValueTracking.h ) # Luau.Analysis Sources diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index eb47971a8..3cce7665d 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,7 +17,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAG(LuauCoroutineClose) /* @@ -545,11 +544,8 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e if (!oldactive) resetbit(L->stackstate, THREAD_ACTIVEBIT); - if (FFlag::LuauCcallRestoreFix) - { - // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. - L->nCcalls = oldnCcalls; - } + // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + L->nCcalls = oldnCcalls; // an error occurred, check if we have a protected error callback if (L->global->cb.debugprotectederror) @@ -564,10 +560,6 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); /* close eventual pending closures */ seterrorobj(L, status, oldtop); - if (!FFlag::LuauCcallRestoreFix) - { - L->nCcalls = oldnCcalls; - } L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 648785697..4178eda40 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,6 +6,8 @@ #include "lmem.h" #include "lgc.h" +LUAU_FASTFLAGVARIABLE(LuauNoDirectUpvalRemoval, false) + Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_new(L, Proto, sizeof(Proto), L->activememcat); @@ -113,14 +115,16 @@ void luaF_freeupval(lua_State* L, UpVal* uv) void luaF_close(lua_State* L, StkId level) { UpVal* uv; - global_State* g = L->global; + global_State* g = L->global; // TODO: remove with FFlagLuauNoDirectUpvalRemoval while (L->openupval != NULL && (uv = gco2uv(L->openupval))->v >= level) { GCObject* o = obj2gco(uv); LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); L->openupval = uv->next; /* remove from `open' list */ - if (isdead(g, o)) + if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o)) + { luaF_freeupval(L, uv); /* free upvalue */ + } else { unlinkupval(uv); diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 0b3054ae4..74a8aa8a1 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauStrPackUBCastFix, false) - /* macro to `unsign' a character */ #define uchar(c) ((unsigned char)(c)) @@ -1406,20 +1404,10 @@ static int str_pack(lua_State* L) } case Kuint: { /* unsigned integers */ - if (FFlag::LuauStrPackUBCastFix) - { - long long n = (long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, (unsigned long long)n, h.islittle, size, 0); - } - else - { - unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, n, h.islittle, size, 0); - } + long long n = (long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, (unsigned long long)n, h.islittle, size, 0); break; } case Kfloat: diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua index 6d40406cd..5d162ab99 100644 --- a/bench/tests/sunspider/3d-cube.lua +++ b/bench/tests/sunspider/3d-cube.lua @@ -111,15 +111,10 @@ end -- multiplies two matrices function MMulti(M1, M2) local M = {{},{},{},{}}; - local i = 1; - local j = 1; - while i <= 4 do - j = 1; - while j <= 4 do - M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j]; j = j + 1 + for i = 1,4 do + for j = 1,4 do + M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j]; end - - i = i + 1 end return M; end @@ -127,28 +122,27 @@ end -- multiplies matrix with vector function VMulti(M, V) local Vect = {}; - local i = 1; - while i <= 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; i = i + 1 end + for i = 1,4 do + Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; + end return Vect; end function VMulti2(M, V) local Vect = {}; - local i = 1; - while i < 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; i = i + 1 end + for i = 1,3 do + Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; + end return Vect; end -- add to matrices function MAdd(M1, M2) local M = {{},{},{},{}}; - local i = 1; - local j = 1; - while i <= 4 do - j = 1; - while j <= 4 do M[i][j] = M1[i][j] + M2[i][j]; j = j + 1 end - - i = i + 1 + for i = 1,4 do + for j = 1,4 do + M[i][j] = M1[i][j] + M2[i][j]; + end end return M; end diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 210db7eea..211e1be1f 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1938,7 +1938,6 @@ return target(b@1 TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); check(R"( local function bar(a: number) return -a end @@ -1954,7 +1953,6 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); check(R"( local function foo() return 1 end @@ -2538,10 +2536,6 @@ TEST_CASE("autocomplete_documentation_symbols") TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { check(R"( local temp = false local even = true; @@ -2614,7 +2608,6 @@ a = if temp then even elseif true then temp else e@9 CHECK(ac.entryMap.count("then") == 0); CHECK(ac.entryMap.count("else") == 0); CHECK(ac.entryMap.count("elseif") == 0); - } } TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") @@ -2681,4 +2674,58 @@ local r4 = t:bar1(@4) CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + check(R"( +type A = () -> T + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_pack_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + check(R"( +type A = () -> T + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_oop_implicit_self") +{ + ScopedFastFlag flag("LuauMissingFollowACMetatables", true); + check(R"( +--!strict +local Class = {} +Class.__index = Class +type Class = typeof(setmetatable({} :: { x: number }, Class)) +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end +function Class.getx(self: Class) + return self.x +end +function test() + local c = Class.new(42) + local n = c:@1 + print(n) +end + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("getx")); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 95811b3f4..8eed953fd 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -603,9 +603,9 @@ RETURN R0 1 )"); } -TEST_CASE("EmptyTableHashSizePredictionOptimization") +TEST_CASE("TableSizePredictionBasic") { - const char* hashSizeSource = R"( + CHECK_EQ("\n" + compileFunction0(R"( local t = {} t.a = 1 t.b = 1 @@ -616,36 +616,8 @@ t.f = 1 t.g = 1 t.h = 1 t.i = 1 -)"; - - const char* hashSizeSource2 = R"( -local t = {} -t.x = 1 -t.x = 2 -t.x = 3 -t.x = 4 -t.x = 5 -t.x = 6 -t.x = 7 -t.x = 8 -t.x = 9 -)"; - - const char* arraySizeSource = R"( -local t = {} -t[1] = 1 -t[2] = 1 -t[3] = 1 -t[4] = 1 -t[5] = 1 -t[6] = 1 -t[7] = 1 -t[8] = 1 -t[9] = 1 -t[10] = 1 -)"; - - CHECK_EQ("\n" + compileFunction0(hashSizeSource), R"( +)"), + R"( NEWTABLE R0 16 0 LOADN R1 1 SETTABLEKS R1 R0 K0 @@ -668,7 +640,19 @@ SETTABLEKS R1 R0 K8 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction0(hashSizeSource2), R"( + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t.x = 1 +t.x = 2 +t.x = 3 +t.x = 4 +t.x = 5 +t.x = 6 +t.x = 7 +t.x = 8 +t.x = 9 +)"), + R"( NEWTABLE R0 1 0 LOADN R1 1 SETTABLEKS R1 R0 K0 @@ -691,7 +675,20 @@ SETTABLEKS R1 R0 K0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction0(arraySizeSource), R"( + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t[1] = 1 +t[2] = 1 +t[3] = 1 +t[4] = 1 +t[5] = 1 +t[6] = 1 +t[7] = 1 +t[8] = 1 +t[9] = 1 +t[10] = 1 +)"), + R"( NEWTABLE R0 0 10 LOADN R1 1 SETTABLEN R1 R0 1 @@ -717,6 +714,27 @@ RETURN R0 0 )"); } +TEST_CASE("TableSizePredictionObject") +{ + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +t.field = 1 +function t:getfield() + return self.field +end +return t +)", + 1), + R"( +NEWTABLE R0 2 0 +LOADN R1 1 +SETTABLEKS R1 R0 K0 +DUPCLOSURE R1 K1 +SETTABLEKS R1 R0 K2 +RETURN R0 1 +)"); +} + TEST_CASE("TableSizePredictionSetMetatable") { CHECK_EQ("\n" + compileFunction0(R"( @@ -1031,9 +1049,6 @@ RETURN R0 1 TEST_CASE("IfElseExpression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - // codegen for a true constant condition CHECK_EQ("\n" + compileFunction0("return if true then 10 else 20"), R"( LOADN R0 10 @@ -3058,7 +3073,7 @@ RETURN R0 0 // table variants (indexed by string, number, variable) CHECK_EQ("\n" + compileFunction0("local a = {} a.foo += 5"), R"( -NEWTABLE R0 1 0 +NEWTABLE R0 0 0 GETTABLEKS R1 R0 K0 ADDK R1 R1 K1 SETTABLEKS R1 R0 K0 @@ -3066,7 +3081,7 @@ RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = {} a[1] += 5"), R"( -NEWTABLE R0 0 1 +NEWTABLE R0 0 0 GETTABLEN R1 R0 1 ADDK R1 R1 K0 SETTABLEN R1 R0 1 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 663b329ee..5222af33e 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -366,15 +366,11 @@ TEST_CASE("PCall") TEST_CASE("Pack") { - ScopedFastFlag sff{"LuauStrPackUBCastFix", true}; - runConformance("tpack.lua"); } TEST_CASE("Vector") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - lua_CompileOptions copts = {}; copts.optimizationLevel = 1; copts.debugLevel = 1; @@ -861,15 +857,11 @@ TEST_CASE("ExceptionObject") TEST_CASE("IfElseExpression") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - runConformance("ifelseexpr.lua"); } TEST_CASE("TagMethodError") { - ScopedFastFlag sff{"LuauCcallRestoreFix", true}; - runConformance("tmerror.lua", [](lua_State* L) { auto* cb = lua_callbacks(L); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index ca4281a0b..c74bfa272 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -191,7 +191,7 @@ ParseResult Fixture::tryParse(const std::string& source, const ParseOptions& par return result; } -ParseResult Fixture::matchParseError(const std::string& source, const std::string& message) +ParseResult Fixture::matchParseError(const std::string& source, const std::string& message, std::optional location) { ParseOptions options; options.allowDeclarationSyntax = true; @@ -203,6 +203,9 @@ ParseResult Fixture::matchParseError(const std::string& source, const std::strin CHECK_EQ(result.errors.front().getMessage(), message); + if (location) + CHECK_EQ(result.errors.front().getLocation(), *location); + return result; } diff --git a/tests/Fixture.h b/tests/Fixture.h index e01632eab..ab852ef6d 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -106,7 +106,7 @@ struct Fixture /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {}); - ParseResult matchParseError(const std::string& source, const std::string& message); + ParseResult matchParseError(const std::string& source, const std::string& message, std::optional location = std::nullopt); // Verify a parse error occurs and the parse error message has the specified prefix ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 5abcb09ac..e91356515 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1255,7 +1255,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_type_group") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") { ScopedFastInt sfis{"LuauRecursionLimit", 10}; - ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; matchParseErrorPrefix( "function f() if true then if true then if true then if true then if true then if true then if true then if true then if true " @@ -1266,7 +1265,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements") { ScopedFastInt sfis{"LuauRecursionLimit", 10}; - ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; matchParseErrorPrefix( "function f() if false then elseif false then elseif false then elseif false then elseif false then elseif false then elseif " @@ -1276,7 +1274,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements" TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError("function f() return if true then 1 elseif true then 2 elseif true then 3 elseif true then 4 elseif true then 5 elseif true then " @@ -1286,7 +1283,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1 TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions2") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError( @@ -1962,6 +1958,37 @@ TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments") matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); } +TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") +{ + matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + AstStat* stat = parse(R"( +type A = {} +type B = {} +type C = {} +type D = {} +type E = {} +type F = (T...) -> U... +type G = (U...) -> T... + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + matchParseError("type Y = {}", "Expected default type after type name", Location{{0, 20}, {0, 21}}); + matchParseError("type Y = {}", "Expected default type pack after type pack name", Location{{0, 29}, {0, 30}}); + matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -2455,10 +2482,19 @@ do end CHECK_EQ(1, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +TEST_CASE_FIXTURE(Fixture, "recover_expected_type_pack") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauParseRecoverTypePackEllipsis{"LuauParseRecoverTypePackEllipsis", true}; + + ParseResult result = tryParse(R"( +type Y = (T...) -> U... + )"); + CHECK_EQ(1, result.errors.size()); +} +TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +{ { AstStat* stat = parse("return if true then 1 else 2"); @@ -2524,9 +2560,4 @@ type C = Packed<(number, X...)> REQUIRE(stat != nullptr); } -TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") -{ - matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); -} - TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 80a258f5b..445ee5329 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -338,6 +338,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") { + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -353,7 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") TypeId tType = requireType("inst"); ToStringResult r = toStringDetailed(tType); - CHECK_EQ("{ @metatable {| __index: { @metatable {| __index: base |}, child } |}, inst }", r.name); + CHECK_EQ("{ @metatable { __index: { @metatable { __index: base }, child } }, inst }", r.name); CHECK_EQ(0, r.nameMap.typeVars.size()); ToStringOptions opts; @@ -500,6 +502,24 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") +{ + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( + local function f(a: number, b: string) end + local function test(...: T...): U... + f(...) + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("test"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("test(...: T...): U...", toStringNamedFunction("test", *ftv)); +} + TEST_CASE("toStringNamedFunction_unit_f") { TypePackVar empty{TypePack{}}; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 47c3883c1..ac5be859b 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -421,8 +421,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion") TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") { - ScopedFastFlag luauIfElseExpressionBaseSupport("LuauIfElseExpressionBaseSupport", true); - std::string code = "local a = if 1 then 2 else 3"; CHECK_EQ(code, transpile(code).code); @@ -641,4 +639,16 @@ TEST_CASE_FIXTURE(Fixture, "transpile_to_string") CHECK_EQ("'hello'", toString(expr)); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_alias_default_type_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + std::string code = R"( +type Packed = (T, U, V...)->(W...) +local a: Packed + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 275782b30..86165814c 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -497,7 +497,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") CHECK(arrayTable->indexer); CHECK(isInArena(array.type, mod.interfaceTypes)); - CHECK_EQ(array.typeParams[0], arrayTable->indexer->indexResultType); + CHECK_EQ(array.typeParams[0].ty, arrayTable->indexer->indexResultType); } TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 503b613f4..d76b920bb 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1031,9 +1031,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_n TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function f(v:string?) return if v then v else tostring(v) @@ -1048,9 +1045,6 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function f(v:string?) return if not v then tostring(v) else v @@ -1065,9 +1059,6 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function returnOne(x) return 1 @@ -1119,6 +1110,25 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_ CHECK_EQ("string", toString(requireTypeAtPosition({5, 30}))); } +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined2") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + type T = { x: { y: number }? } + + local function f(t: T?) + if t and t.x then + local foo = t.x.y + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireTypeAtPosition({5, 32}))); +} + TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") { ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 68dc1b4fa..94cfb6437 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -360,6 +360,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { ScopedFastFlag sffs[] = { {"LuauParseSingletonTypes", true}, + {"LuauUnsealedTableLiteral", true}, }; CheckResult result = check(R"( @@ -369,7 +370,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + CHECK_EQ(R"(Table type '{ ["\n"]: number }' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", toString(result.errors[0])); } @@ -423,4 +424,27 @@ caused by: toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + {"LuauIfElseExpectedType2", true}, + {"LuauIfElseBranchTypeUnion", true}, + }; + + CheckResult result = check(R"( +type Cat = { tag: 'cat', catfood: string } +type Dog = { tag: 'dog', dogfood: string } +type Animal = Cat | Dog + +local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag = 'dog', dogfood = 'other' } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 80f40407e..27cda1463 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -65,7 +65,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table") TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { - CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); + CheckResult result = check("function mkt() return {prop=999} end local t = mkt() t.foo = 'bar'"); LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; @@ -77,7 +77,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") CHECK_EQ(s, "{| prop: number |}"); CHECK_EQ(error->prop, "foo"); CHECK_EQ(error->context, CannotExtendTable::Property); - CHECK_EQ(err.location, (Location{Position{0, 24}, Position{0, 29}})); + CHECK_EQ(err.location, (Location{Position{0, 59}, Position{0, 64}})); } TEST_CASE_FIXTURE(Fixture, "dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table") @@ -1155,7 +1155,8 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_mu TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail") { CheckResult result = check(R"( - local t = {x = 1} + function mkt() return {x = 1} end + local t = mkt() function t.m() end )"); @@ -1165,13 +1166,38 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must_fail") { CheckResult result = check(R"( - local t = {x = 1} + function mkt() return {x = 1} end + local t = mkt() function t:m() end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + + CheckResult result = check(R"( + local t = {x = 1} + function t.m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_unsealed_table_is_ok") +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + + CheckResult result = check(R"( + local t = {x = 1} + function t:m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + // This unit test could be flaky if the fix has regressed. TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_crashing") { @@ -1439,8 +1465,13 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") CHECK_EQ("{| |}", toString(mp->subType)); } -TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") +TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_indexer") { + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; + CheckResult result = check(R"( type StringToStringMap = { [string]: string } local rt: StringToStringMap = { ["foo"] = 1 } @@ -1448,6 +1479,25 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") LUAU_REQUIRE_ERROR_COUNT(1, result); + ToStringOptions o{/* exhaustive= */ true}; + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); + // Should t now have an indexer? + // It would if the assignment to rt was correctly typed. + CHECK_EQ("{ [string]: string, foo: number }", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_sealed_tables_with_props_into_table_with_indexer") +{ + CheckResult result = check(R"( + type StringToStringMap = { [string]: string } + function mkrt() return { ["foo"] = 1 } end + local rt: StringToStringMap = mkrt() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); @@ -1467,7 +1517,10 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end @@ -1480,7 +1533,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); - CHECK_EQ("{| a: number |}", toString(tm->givenType, o)); + CHECK_EQ("{ a: number }", toString(tm->givenType, o)); } TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") @@ -1536,8 +1589,11 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") { CheckResult result = check(R"( - local vec3 = {{x = 1, y = 2, z = 3}} - local vec1 = {{x = 1}} + function mkvec3() return {x = 1, y = 2, z = 3} end + function mkvec1() return {x = 1} end + + local vec3 = {mkvec3()} + local vec1 = {mkvec1()} vec1 = vec3 )"); @@ -1620,7 +1676,8 @@ TEST_CASE_FIXTURE(Fixture, "reasonable_error_when_adding_a_nonexistent_property_ { CheckResult result = check(R"( --!strict - local A = {"value"} + function mkA() return {"value"} end + local A = mkA() A.B = "Hello" )"); @@ -1668,7 +1725,8 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") --!strict local function f() - local t = { x = 1 } + local function mkt() return { x = 1 } end + local t = mkt() function t.a() end function t.b() end @@ -1995,7 +2053,10 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); @@ -2010,7 +2071,7 @@ local c2: typeof(a2) = b2 LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' caused by: - Type '{| x: number, y: string |}' could not be converted into '{| x: number, y: number |}' + Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); @@ -2018,7 +2079,7 @@ caused by: { CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); } @@ -2026,7 +2087,7 @@ caused by: { CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); } @@ -2059,6 +2120,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, }; CheckResult result = check(R"( @@ -2077,7 +2139,7 @@ local y: number = tmp.p.y LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' caused by: - Property 'p' is not compatible. Table type '{| x: number, y: number |}' not compatible with type 'Super' because the former has extra field 'y')"); + Property 'p' is not compatible. Table type '{ x: number, y: number }' not compatible with type 'Super' because the former has extra field 'y')"); } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") @@ -2103,7 +2165,10 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") { - ScopedFastFlag luauFixRecursiveMetatableCall{"LuauFixRecursiveMetatableCall", true}; + ScopedFastFlag sff[]{ + {"LuauFixRecursiveMetatableCall", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local b @@ -2112,7 +2177,7 @@ b() )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable {| __call: t1 |}, { } })"); + CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index f70f3b1c8..7a056af52 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4525,7 +4525,9 @@ f(function(x) print(x) end) } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") -{ +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end return sum(2, 3, function(a, b) return a + b end) @@ -4549,7 +4551,7 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{| c: number, s: number |}", toString(requireType("r"))); + REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") @@ -4689,6 +4691,18 @@ a = setmetatable(a, { __call = function(x) end }) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "infer_through_group_expr") +{ + ScopedFastFlag luauGroupExpectedType{"LuauGroupExpectedType", true}; + + CheckResult result = check(R"( +local function f(a: (number, number) -> number) return a(1, 3) end +f(((function(a, b) return a + b end))) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "refine_and_or") { CheckResult result = check(R"( @@ -4743,46 +4757,75 @@ local c: X TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { - CheckResult result = check(R"(local a = if true then "true" else "false")"); - LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); - } + CheckResult result = check(R"(local a = if true then "true" else "false")"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + // Test expression containing elseif + CheckResult result = check(R"( +local a = if false then "a" elseif false then "b" else "c" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") +{ + ScopedFastFlag sff3{"LuauIfElseBranchTypeUnion", true}; { - // Test expression containing elseif - CheckResult result = check(R"( -local a = if false then "a" elseif false then "b" else "c" - )"); + CheckResult result = check(R"(local a: number? = if true then 42 else nil)"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); + CHECK_EQ(toString(requireType("a"), {true}), "number?"); } } -TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_1") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true}; - { - CheckResult result = check(R"(local a = if true then "true" else 42)"); - // We currently require both true/false expressions to unify to the same type. However, we do intend to lift - // this restriction in the future. - LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); - } + CheckResult result = check(R"( +type X = {number | string} +local a: X = if true then {"1", 2, 3} else {4, 5, 6} +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a"), {true}), "{number | string}"); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_2") +{ + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + ScopedFastFlag luauIfElseBranchTypeUnion{ "LuauIfElseBranchTypeUnion", true }; + + CheckResult result = check(R"( +local a: number? = if true then 1 else nil +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_3") +{ + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + + CheckResult result = check(R"( +local function times(n: any, f: () -> T) + local result: {T} = {} + local res = f() + table.insert(result, if true then res else n) + return result +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "type_error_addition") @@ -5039,4 +5082,51 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_oop") +{ + CheckResult result = check(R"( + --!strict +local Class = {} +Class.__index = Class + +type Class = typeof(setmetatable({} :: { x: number }, Class)) + +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end + +function Class.getx(self: Class) + return self.x +end + +function test() + local c = Class.new(42) + local n = c:getx() + local nn = c.x + + print(string.format("%d %d", n, nn)) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") +{ + ScopedFastFlag luauMetatableAreEqualRecursion{"LuauMetatableAreEqualRecursion", true}; + + CheckResult result = check(R"( +local function getIt() + local y + y = setmetatable({}, y) + return y +end +local a = getIt() +local b = getIt() +local c = a or b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 5d37b032a..d4878d149 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -621,4 +621,328 @@ type Other = Packed CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but only 1 is specified"); } +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_explicit") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U } + +local a: Y = { a = 2, b = 3 } +local b: Y = { a = 2, b = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + + result = check(R"( +type Y = { a: T } + +local a: Y = { a = 2 } +local b: Y<> = { a = "s" } +local c: Y = { a = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + CHECK_EQ(toString(requireType("c")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_self") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U } + +local a: Y = { a = 2, b = 3 } +local b: Y = { a = "h", b = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + + result = check(R"( +type Y string> = { a: T, b: U } + +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y string>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_chained") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U, c: V } + +local a: Y +local b: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_explicit") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> () } +local a: Y<> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_ty") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: (U...) -> T } + +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_tp") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> U... } +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string)>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_chained_tp") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> U..., b: (T...) -> V... } +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string), (number, string)>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_mixed_self") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T, U, V...) -> W... } +local a: Y +local b: Y +local c: Y +local d: Y ()> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + CHECK_EQ(toString(requireType("c")), "Y"); + CHECK_EQ(toString(requireType("d")), "Y ()>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T } +local a: Y = { a = 2 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); + + result = check(R"( +type Y = { a: (T...) -> () } +local a: Y<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); + + result = check(R"( +type Y = { a: (T) -> U... } +local a: Y<...number> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Y' expects at least 1 type argument, but none are specified"); + + result = check(R"( +type Packed = (T) -> T +local a: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); + + result = check(R"( +type Y = { a: T } +local a: Y + )"); + + LUAU_REQUIRE_ERRORS(result); + + result = check(R"( +type Y = { a: T } +local a: Y<...number> + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_export") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + fileResolver.source["Module/Types"] = R"( +export type A = { a: T, b: U } +export type B = { a: T, b: U } +export type C string> = { a: T, b: U } +export type D = { a: T, b: U, c: V } +export type E = { a: (T...) -> () } +export type F = { a: T, b: (U...) -> T } +export type G = { b: (U...) -> T... } +export type H = { b: (T...) -> T... } +return {} + )"; + + CheckResult resultTypes = frontend.check("Module/Types"); + LUAU_REQUIRE_NO_ERRORS(resultTypes); + + fileResolver.source["Module/Users"] = R"( +local Types = require(script.Parent.Types) + +local a: Types.A +local b: Types.B +local c: Types.C +local d: Types.D +local e: Types.E<> +local eVoid: Types.E<()> +local f: Types.F +local g: Types.G<...number> +local h: Types.H<> + )"; + + CheckResult resultUsers = frontend.check("Module/Users"); + LUAU_REQUIRE_NO_ERRORS(resultUsers); + + CHECK_EQ(toString(requireType("Module/Users", "a")), "A"); + CHECK_EQ(toString(requireType("Module/Users", "b")), "B"); + CHECK_EQ(toString(requireType("Module/Users", "c")), "C string>"); + CHECK_EQ(toString(requireType("Module/Users", "d")), "D"); + CHECK_EQ(toString(requireType("Module/Users", "e")), "E"); + CHECK_EQ(toString(requireType("Module/Users", "eVoid")), "E<>"); + CHECK_EQ(toString(requireType("Module/Users", "f")), "F"); + CHECK_EQ(toString(requireType("Module/Users", "g")), "G<...number, ()>"); + CHECK_EQ(toString(requireType("Module/Users", "h")), "H<>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_skip_brackets") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = (T...) -> number +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "(...string) -> number"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_confusing_types") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type A = (T, V...) -> (U, W...) +type B = A +type C = A + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("B"), {true}), "(string, ...any) -> (number, ...any)"); + CHECK_EQ(toString(*lookupType("C"), {true}), "(string, boolean) -> (number, boolean)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_recursive_type") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type F ()> = (K) -> V +type R = { m: F } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("R"), {true}), "t1 where t1 = {| m: (t1) -> (t1) -> () |}"); +} + +TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") +{ + ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true}; + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +local a: () -> (number, ...string) +local b: () -> (number, ...boolean) +a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> (number, ...boolean)' could not be converted into '() -> (number, ...string)' +caused by: + Type 'boolean' could not be converted into 'string')"); +} + TEST_SUITE_END();