Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Edge Case w/ Intersection Type & Math Operation Overloads #1009

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions Analysis/src/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauNormalizeBlockedTypes)
LUAU_FASTFLAG(DebugLuauReadWriteProperties)
LUAU_FASTFLAG(LuauIntersectedBinopOverloadFix)

namespace Luau
{
Expand Down Expand Up @@ -263,6 +264,20 @@ std::optional<TypeId> getMetatable(TypeId type, NotNull<BuiltinTypes> builtinTyp
{
type = follow(type);

if (FFlag::LuauIntersectedBinopOverloadFix)
{
if (const IntersectionType* itv = get<IntersectionType>(type))
{
for (TypeId part : itv->parts)
{
auto partMT = getMetatable(part, builtinTypes);
if (partMT != std::nullopt)
return partMT;
}
Comment on lines +271 to +276
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution!

This is a bit scary because, if we are going to admit the possibility that an intersection might have a metatable, we should also consider the possibility that the intersection has many metatables.

I think it would be acceptable if this function were to return std::nullopt when given an intersection that contains multiple tables with disparate metatables. It's not perfect but it also moves us to a place that's strictly better than where we are now.

The second complication that's likely to rear its head is a cyclic intersection. We try to avoid creating these, but we need to handle them without overflowing the stack when they do crop up.

The usual way we make functions like this resilient in the face of cycles is to write a helper overload that accepts a mutable seen set as an argument. In this case, any intersection that's part of a cycle can safely be said to have no metatable.

std::optional<TypeId> getMetatable(TypeId type, NotNull<BuiltinTypes> builtinTypes, std::set<TypeId>& seen);

std::optional<TypeId> getMetatable(TypeId type, NotNull<BuiltinTypes> builtinTypes)
{
    if (FFlag::LuauIntersectedBinopOverloadFix)
    {
        std::set<TypeId> seen;
        return getMetatable(type, builtinTypes, seen);
    }
    // ...
}

std::optional<TypeId> getMetatable(TypeId type, NotNull<BuiltinTypes> builtinTypes, std::set<TypeId>& seen)
{
    if (seen.count(type))
        return std::nullopt;
    seen.insert(type);
    // ... all the logic from the original getMetatable here ...
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Appreciate the review! Thanks for the note on cyclic intersections and other edge cases. I definitely see the issue with intersections of multiple metatable-containing types and how that gets handled. That would need some refactoring outside of the scope of this PR here.

I'll look into the suggestions you mentioned in my free time, hope it can at least be an incremental step towards having fully sound intersection types and make it into production if it's all stable and doesn't make anything worse per se. There was a related issue #983 involving the interaction between metatables and intersection types.

return std::nullopt;
}
}

if (const MetatableType* mtType = get<MetatableType>(type))
return mtType->metatable;
else if (const ClassType* classType = get<ClassType>(type))
Expand Down
10 changes: 10 additions & 0 deletions Analysis/src/TypeInfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,22 @@ LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure)
LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false)
LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)
LUAU_FASTFLAG(LuauParseDeclareClassIndexer)
LUAU_FASTFLAGVARIABLE(LuauIntersectedBinopOverloadFix, false)

namespace Luau
{

static bool typeCouldHaveMetatable(TypeId ty)
{
if (FFlag::LuauIntersectedBinopOverloadFix) {
if (auto itv = get<IntersectionType>(follow(ty)))
{
for (TypeId part : itv->parts)
if (typeCouldHaveMetatable(part))
return true;
return false;
}
}
return get<TableType>(follow(ty)) || get<ClassType>(follow(ty)) || get<MetatableType>(follow(ty));
}

Expand Down
48 changes: 48 additions & 0 deletions tests/TypeInfer.operators.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,54 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_mismatch_metatable")
CHECK("Type 'number' could not be converted into 'V2'" == toString(result.errors[0]));
}

TEST_CASE_FIXTURE(BuiltinsFixture, "overloaded_op_accept_structured_subtype")
{
ScopedFastFlag sff{"LuauIntersectedBinopOverloadFix", true};
CheckResult result = check(R"(
--!strict
type BaseType = typeof(setmetatable(
{},
({} :: any) :: {__add: (BaseType, BaseType) -> BaseType})
)
type SubType = BaseType & {extraField: string}

local function add1(x: BaseType, y: BaseType): BaseType
return x + y
end
local function add2(x: SubType, y: BaseType): BaseType
return x + y
end
local function add3(x: BaseType, y: SubType): BaseType
return x + y
end
local function add4(x: SubType, y: SubType): BaseType
return x + y
end
)");

LUAU_REQUIRE_ERROR_COUNT(0, result);
}

TEST_CASE_FIXTURE(BuiltinsFixture, "overloaded_op_disallow_unrelated_type")
{
ScopedFastFlag sff{"LuauIntersectedBinopOverloadFix", true};
CheckResult result = check(R"(
--!strict
type BaseType = typeof(setmetatable(
{},
({} :: any) :: {__mul: (BaseType, BaseType) -> BaseType})
)
type Unrelated = {extraField: string}

local function add(x: BaseType, y: Unrelated)
return x * y
end
)");

LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(toString(result.errors[0]) == "Type 'Unrelated' could not be converted into 'BaseType'");
}

TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions")
{
CheckResult result = check(R"(
Expand Down