Skip to content

Commit

Permalink
Simplify how types are complex types are narrowed.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark committed Apr 15, 2024
1 parent c1e07fe commit 0292db6
Showing 1 changed file with 170 additions and 164 deletions.
334 changes: 170 additions & 164 deletions edb/edgeql/compiler/schemactx.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,220 +503,226 @@ class TypeIntersectionResult(NamedTuple):
is_subtype: bool = False


def apply_intersection(
left: s_types.Type, right: s_types.Type, *, ctx: context.ContextLevel
) -> TypeIntersectionResult:
"""Compute an intersection of two types: *left* and *right*.
Returns:
A :class:`~TypeIntersectionResult` named tuple containing the
result intersection type, whether the type system considers
the intersection empty and whether *left* is related to *right*
(i.e either is a subtype of another).
"""
class NormalizedType(NamedTuple):
# A union of intersections of simple types
# All type expressions can ultimately be reduced to one of these

if left.issubclass(ctx.env.schema, right):
# The intersection type is a proper *superclass*
# of the argument, then this is, effectively, a NOP.
return TypeIntersectionResult(stype=left)
union_of_intersections: list[list[s_types.Type]]

def intersect_intersections(
left_intersection: list[s_types.Type],
right_intersection: list[s_types.Type],
) -> list[s_types.Type]:
current_left_intersection: list[s_types.Type] = left_intersection
result: list[s_types.Type] = []

for right_subtype in right_intersection:
superclass_found = False
subclass_found = False
result = []

for left_subtype in current_left_intersection:
if right_subtype.issubclass(ctx.env.schema, left_subtype):
# Replace all superclasses with right subtype
if not superclass_found:
result.append(right_subtype)
superclass_found = True
else:
if left_subtype.issubclass(ctx.env.schema, right_subtype):
# No need to add right subtype if a subclass was found
subclass_found = True
def as_type(self, *, ctx: context.ContextLevel) -> s_types.Type:
result_intersections = [
(
intersection[0]
if len(intersection) == 1 else
get_intersection_type(intersection, ctx=ctx)
)
for intersection in self.union_of_intersections
]

result.append(left_subtype)
result_union = (
result_intersections[0]
if len(result_intersections) == 1 else
get_union_type(result_intersections, ctx=ctx)
)

if not superclass_found and not subclass_found:
result.append(right_subtype)
return result_union

current_left_intersection = result
@staticmethod
def create(
type: s_types.Type,
*,
ctx: context.ContextLevel,
) -> NormalizedType:

return result
source_union: list[list[s_types.Type]] = (
NormalizedType._create(type, ctx=ctx)
)
simplified_union: list[list[s_types.Type]] = (
NormalizedType._simplify_intersections(source_union, ctx=ctx)
)
return NormalizedType(
union_of_intersections=simplified_union
)

def is_intersection_subtype(
left_intersection: list[s_types.Type],
right_intersection: list[s_types.Type],
) -> bool:
return all(
any(
right_subtype.issubclass(ctx.env.schema, left_subtype)
for right_subtype in right_intersection
@staticmethod
def intersection_of_normalized(
left_disjunction: NormalizedType,
right_disjunction: NormalizedType,
*,
ctx: context.ContextLevel,
) -> NormalizedType:

source_union: list[list[s_types.Type]] = [
NormalizedType._intersection_of_intersections(
leftt_intersection, right_intersection, ctx=ctx
)
for leftt_intersection in (
left_disjunction.union_of_intersections
)
for right_intersection in (
right_disjunction.union_of_intersections
)
for left_subtype in left_intersection
]
simplified_union: list[list[s_types.Type]] = (
NormalizedType._simplify_intersections(source_union, ctx=ctx)
)

def narrow_complex_type(
left_subtype: s_types.Type,
right_subtype: s_types.Type,
) -> s_types.Type:
union_of_intersections: list[list[s_types.Type]] = narrow_complex_type_(
left_subtype, right_subtype
return NormalizedType(
union_of_intersections=simplified_union
)

@staticmethod
def _simplify_intersections(
source_union: list[list[s_types.Type]],
*,
ctx: context.ContextLevel,
) -> list[list[s_types.Type]]:

# filter out any superclasses
simplified: list[list[s_types.Type]] = []
for source_intersection in union_of_intersections:
simplified_union: list[list[s_types.Type]] = []
for source_intersection in source_union:
if any(
is_intersection_subtype(
simplified_intersection, source_intersection
NormalizedType._is_sub_intersection(
simplified_intersection, source_intersection, ctx=ctx
)
for simplified_intersection in simplified
for simplified_intersection in simplified_union
):
# skip subclass of previously seen
pass
else:
# filter out any previous subclasses
simplified = [
simplified_union = [
simplified_intersection
for simplified_intersection in simplified
if not is_intersection_subtype(
source_intersection, simplified_intersection
for simplified_intersection in simplified_union
if not NormalizedType._is_sub_intersection(
source_intersection, simplified_intersection, ctx=ctx
)
] + [source_intersection]

result_intersections = [
(
intersection[0]
if len(intersection) == 1 else
get_intersection_type(intersection, ctx=ctx)
)
for intersection in simplified
]
return simplified_union

result_union = (
result_intersections[0]
if len(result_intersections) == 1 else
get_union_type(result_intersections, ctx=ctx)
)
@staticmethod
def _create(
type: s_types.Type,
*,
ctx: context.ContextLevel,
) -> list[list[s_types.Type]]:

return result_union
result_union: list[list[s_types.Type]] = []

def narrow_complex_type_(
left_subtype: s_types.Type,
right_subtype: s_types.Type,
) -> list[list[s_types.Type]]:
# Expand the left subtype into underlying simple types
# The simple left types will be narrowed then re-combined
# Returns a union of intersections

result: list[list[s_types.Type]] = []
if left_sub_union := left_subtype.get_union_of(ctx.env.schema):
# Union simply collects all results
for left_component in left_sub_union.objects(ctx.env.schema):
result += narrow_complex_type_(
left_component, right_subtype
if sub_union := type.get_union_of(ctx.env.schema):
# simply expand all sub-intersections
for component in sub_union.objects(ctx.env.schema):
result_union += (
NormalizedType._create(component, ctx=ctx)
)
return result
return result_union

elif left_sub_intersection := (
left_subtype.get_intersection_of(ctx.env.schema)
):
elif sub_intersection := type.get_intersection_of(ctx.env.schema):
# Intersection needs to produce the cross intersection of all
# results
component_results: list[list[list[s_types.Type]]] = [
narrow_complex_type_(left_component, right_subtype)
for left_component in left_sub_intersection.objects(
ctx.env.schema
)
component_unions: list[list[list[s_types.Type]]] = [
NormalizedType._create(component, ctx=ctx)
for component in sub_intersection.objects(ctx.env.schema)
]

result = component_results[0]
for component_union in component_results[1:]:
result = [
intersect_intersections(
result_intersection, component_intersection
result_union = component_unions[0]
for component_union in component_unions[1:]:
result_union = [
NormalizedType._intersection_of_intersections(
result_intersection, component_intersection, ctx=ctx
)
for result_intersection in result
for result_intersection in result_union
for component_intersection in component_union
]
return result
return result_union

elif left_subtype.is_view(ctx.env.schema):
return narrow_complex_type_(
left_subtype.peel_view(ctx.env.schema), right_subtype
elif type.is_view(ctx.env.schema):
return NormalizedType._create(
type.peel_view(ctx.env.schema), ctx=ctx
)

# Now the left subtype is a fully unwrapped type.
return narrow_flattened_intersection(
[left_subtype], right_subtype
)
# Now the left type is a simple type.
return [[type]]

def narrow_flattened_intersection(
@staticmethod
def _intersection_of_intersections(
left_intersection: list[s_types.Type],
right_subtype: s_types.Type,
) -> list[list[s_types.Type]]:
# Narrows an intersection of types.
# Returns a union of intersections
right_intersection: list[s_types.Type],
*,
ctx: context.ContextLevel,
) -> list[s_types.Type]:

# If any left type is a subclass of right, no narrowing needed.
if any(
left_subtype.issubclass(ctx.env.schema, right_subtype)
for left_subtype in left_intersection
):
return [left_intersection]

result: list[list[s_types.Type]]
if right_union := right_subtype.get_union_of(ctx.env.schema):
# When narrowing on a union, combine the results of narrowing
# on each component of the union

result = [
narrowed
for right_component in right_union.objects(ctx.env.schema)
for narrowed in narrow_flattened_intersection(
left_intersection, right_component
)
]
current_left_intersection: list[s_types.Type] = left_intersection
result: list[s_types.Type] = []

return result
for right_type in right_intersection:
superclass_found = False
subclass_found = False
result = []

elif right_intersection := (
right_subtype.get_intersection_of(ctx.env.schema)
):
# When narrowing on an intersection, repeatedly narrow the left
# intersection (and results) on each component of the intersection
for left_type in current_left_intersection:
if right_type.issubclass(
ctx.env.schema, left_type
):
# Replace all superclasses with right type
if not superclass_found:
result.append(right_type)
superclass_found = True
else:
if left_type.issubclass(
ctx.env.schema, right_type
):
# No need to add right type if a subclass was found
subclass_found = True

result = [left_intersection]
result.append(left_type)

for right_component in right_intersection.objects(ctx.env.schema):
result = [
narrowed
for t in result
for narrowed in narrow_flattened_intersection(
t, right_component
)
]
if not superclass_found and not subclass_found:
result.append(right_type)

current_left_intersection = result

return result
return result

elif right_subtype.is_view(ctx.env.schema):
return narrow_flattened_intersection(
left_intersection, right_subtype.peel_view(ctx.env.schema)
@staticmethod
def _is_sub_intersection(
left_intersection: list[s_types.Type],
right_intersection: list[s_types.Type],
*,
ctx: context.ContextLevel,
) -> bool:

return all(
any(
right_type.issubclass(ctx.env.schema, left_type)
for right_type in right_intersection
)
for left_type in left_intersection
)


def apply_intersection(
left: s_types.Type, right: s_types.Type, *, ctx: context.ContextLevel
) -> TypeIntersectionResult:
"""Compute an intersection of two types: *left* and *right*.
# Now the right subtype is a fully unwrapped type.
return [intersect_intersections(left_intersection, [right_subtype])]
Returns:
A :class:`~TypeIntersectionResult` named tuple containing the
result intersection type, whether the type system considers
the intersection empty and whether *left* is related to *right*
(i.e either is a subtype of another).
"""

if left.issubclass(ctx.env.schema, right):
# The intersection type is a proper *superclass*
# of the argument, then this is, effectively, a NOP.
return TypeIntersectionResult(stype=left)

int_type = narrow_complex_type(left, right)
left_normalized = NormalizedType.create(left, ctx=ctx)
right_normalized = NormalizedType.create(right, ctx=ctx)
int_type = NormalizedType.intersection_of_normalized(
left_normalized, right_normalized, ctx=ctx
).as_type(ctx=ctx)
is_subtype = right.issubclass(ctx.env.schema, left)

return TypeIntersectionResult(
Expand Down

0 comments on commit 0292db6

Please sign in to comment.