Skip to content

Commit

Permalink
Implement method to invert type intersection.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark committed Apr 15, 2024
1 parent 0292db6 commit 12cae6b
Showing 1 changed file with 241 additions and 79 deletions.
320 changes: 241 additions & 79 deletions edb/edgeql/compiler/schemactx.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,115 +535,112 @@ def create(
) -> NormalizedType:

source_union: list[list[s_types.Type]] = (
NormalizedType._create(type, ctx=ctx)
)
simplified_union: list[list[s_types.Type]] = (
NormalizedType._simplify_intersections(source_union, ctx=ctx)
)
return NormalizedType(
union_of_intersections=simplified_union
)

@staticmethod
def intersection_of_normalized(
left_disjunction: NormalizedType,
right_disjunction: NormalizedType,
*,
ctx: context.ContextLevel,
) -> NormalizedType:

source_union: list[list[s_types.Type]] = [
NormalizedType._intersection_of_intersections(
leftt_intersection, right_intersection, ctx=ctx
)
for leftt_intersection in (
left_disjunction.union_of_intersections
NormalizedType._create(
type,
ctx=ctx,
expand_opaque_union=True
)
for right_intersection in (
right_disjunction.union_of_intersections
)
]
)
simplified_union: list[list[s_types.Type]] = (
NormalizedType._simplify_intersections(source_union, ctx=ctx)
NormalizedType._simplify_union(source_union, ctx=ctx)
)
return NormalizedType(
union_of_intersections=simplified_union
)

@staticmethod
def _simplify_intersections(
source_union: list[list[s_types.Type]],
*,
ctx: context.ContextLevel,
) -> list[list[s_types.Type]]:

# filter out any superclasses
simplified_union: list[list[s_types.Type]] = []
for source_intersection in source_union:
if any(
NormalizedType._is_sub_intersection(
simplified_intersection, source_intersection, ctx=ctx
)
for simplified_intersection in simplified_union
):
# skip subclass of previously seen
pass
else:
# filter out any previous subclasses
simplified_union = [
simplified_intersection
for simplified_intersection in simplified_union
if not NormalizedType._is_sub_intersection(
source_intersection, simplified_intersection, ctx=ctx
)
] + [source_intersection]

return simplified_union

@staticmethod
def _create(
type: s_types.Type,
*,
ctx: context.ContextLevel,
expand_opaque_union: bool,
) -> list[list[s_types.Type]]:

result_union: list[list[s_types.Type]] = []

if sub_union := type.get_union_of(ctx.env.schema):
if (
(
not type.get_is_opaque_union(ctx.env.schema)
or expand_opaque_union or True
)
and (sub_union := type.get_union_of(ctx.env.schema))
):
# simply expand all sub-intersections
result_union: list[list[s_types.Type]] = []

for component in sub_union.objects(ctx.env.schema):
result_union += (
NormalizedType._create(component, ctx=ctx)
NormalizedType._create(
component,
ctx=ctx,
expand_opaque_union=False,
)
)
return result_union

elif sub_intersection := type.get_intersection_of(ctx.env.schema):
# Intersection needs to produce the cross intersection of all
# results
component_unions: list[list[list[s_types.Type]]] = [
NormalizedType._create(component, ctx=ctx)
NormalizedType._create(
component,
ctx=ctx,
expand_opaque_union=False,
)
for component in sub_intersection.objects(ctx.env.schema)
]

result_union = component_unions[0]
for component_union in component_unions[1:]:
result_union = [
NormalizedType._intersection_of_intersections(
result_intersection, component_intersection, ctx=ctx
)
for result_intersection in result_union
for component_intersection in component_union
]
return result_union
return NormalizedType._intersection_of_unions(
component_unions, ctx=ctx
)

elif type.is_view(ctx.env.schema):
return NormalizedType._create(
type.peel_view(ctx.env.schema), ctx=ctx
type.peel_view(ctx.env.schema),
ctx=ctx,
expand_opaque_union=False,
)

# Now the left type is a simple type.
return [[type]]

@staticmethod
def intersection_of_normalized(
left_normalized: NormalizedType,
right_normalized: NormalizedType,
*,
ctx: context.ContextLevel,
) -> NormalizedType:

source_union: list[list[s_types.Type]] = (
NormalizedType._intersection_of_unions([
left_normalized.union_of_intersections,
right_normalized.union_of_intersections,
], ctx=ctx)
)
simplified_union: list[list[s_types.Type]] = (
NormalizedType._simplify_union(source_union, ctx=ctx)
)
return NormalizedType(
union_of_intersections=simplified_union
)

@staticmethod
def _intersection_of_unions(
component_unions: list[list[list[s_types.Type]]],
*,
ctx: context.ContextLevel,
) -> list[list[s_types.Type]]:

result_union = component_unions[0]
for component_union in component_unions[1:]:
result_union = [
NormalizedType._intersection_of_intersections(
result_intersection, component_intersection, ctx=ctx
)
for result_intersection in result_union
for component_intersection in component_union
]
return result_union

@staticmethod
def _intersection_of_intersections(
left_intersection: list[s_types.Type],
Expand Down Expand Up @@ -684,6 +681,133 @@ def _intersection_of_intersections(

return result

@staticmethod
def compute_other_intersector(
target: NormalizedType,
intersector: NormalizedType,
*,
ctx: context.ContextLevel,
) -> NormalizedType:

# Compute the intersector required such that its intersection with
# the source intersector is a subclass of the target
# eg. A * X = A*B -> X = B

source_union: list[list[s_types.Type]] = (
NormalizedType._compute_other_intersector(
target.union_of_intersections,
intersector.union_of_intersections,
ctx=ctx,
)
)
simplified_union: list[list[s_types.Type]] = (
NormalizedType._simplify_union(source_union, ctx=ctx)
)
return NormalizedType(
union_of_intersections=simplified_union
)

@staticmethod
def _compute_other_intersector(
target_union: list[list[s_types.Type]],
intersector_union: list[list[s_types.Type]],
*,
ctx: context.ContextLevel,
) -> list[list[s_types.Type]]:

# Expand source union
if len(intersector_union) > 1:
# (A) / (B+C) = A/B * A/C
component_unions = [
NormalizedType._compute_other_intersector(
target_union, [known_intersection], ctx=ctx
)
for known_intersection in intersector_union
]

return NormalizedType._intersection_of_unions(
component_unions, ctx=ctx
)

# Expand target union
if len(target_union) > 1:
# (A+B) / (C) = A/C + B/C
result_union = []
for target_intersection in target_union:
result_union += (
NormalizedType._compute_other_intersector(
[target_intersection], intersector_union, ctx=ctx
)
)
return result_union

# Expand target intersection
if len(target_union[0]) > 1:
# (A*B) / (C) = A/C * B/C
component_unions = [
NormalizedType._compute_other_intersector(
[[target_type]], [intersector_intersection], ctx=ctx
)
for intersector_intersection in intersector_union
for target_type in target_union[0]
]

return NormalizedType._intersection_of_unions(
component_unions, ctx=ctx
)

# Expand source intersection
if len(intersector_union[0]) > 1:
# (A) / (B*C) = A/B + A/C
result_union = []
for intersector_type in intersector_union[0]:
result_union += (
NormalizedType._compute_other_intersector(
target_union, [[intersector_type]], ctx=ctx
)
)
return result_union

# Compute intersector of simple types
target_type = target_union[0][0]
intersector_type = intersector_union[0][0]

if intersector_type.issubclass(ctx.env.schema, target_type):
# a * X = A -> X = anything
return [[]]

return target_union

@staticmethod
def _simplify_union(
source_union: list[list[s_types.Type]],
*,
ctx: context.ContextLevel,
) -> list[list[s_types.Type]]:

# filter out any superclasses
simplified_union: list[list[s_types.Type]] = []
for source_intersection in source_union:
if any(
NormalizedType._is_sub_intersection(
simplified_intersection, source_intersection, ctx=ctx
)
for simplified_intersection in simplified_union
):
# skip subclass of previously seen
pass
else:
# filter out any previous subclasses
simplified_union = [
simplified_intersection
for simplified_intersection in simplified_union
if not NormalizedType._is_sub_intersection(
source_intersection, simplified_intersection, ctx=ctx
)
] + [source_intersection]

return simplified_union

@staticmethod
def _is_sub_intersection(
left_intersection: list[s_types.Type],
Expand All @@ -692,6 +816,9 @@ def _is_sub_intersection(
ctx: context.ContextLevel,
) -> bool:

# Returns whether the right intersection is a 'sub-class' of
# the right intersection.

return all(
any(
right_type.issubclass(ctx.env.schema, left_type)
Expand Down Expand Up @@ -720,14 +847,49 @@ def apply_intersection(

left_normalized = NormalizedType.create(left, ctx=ctx)
right_normalized = NormalizedType.create(right, ctx=ctx)
int_type = NormalizedType.intersection_of_normalized(
int_normalized = NormalizedType.intersection_of_normalized(
left_normalized, right_normalized, ctx=ctx
).as_type(ctx=ctx)
is_subtype = right.issubclass(ctx.env.schema, left)
)

int_type = int_normalized.as_type(ctx=ctx)

# Compute the "minimal" intersector required
difference_normalized = NormalizedType.compute_other_intersector(
int_normalized,
left_normalized,
ctx=ctx,
)

left_name = left.get_displayname(ctx.env.schema)
right_name = right.get_displayname(ctx.env.schema)
left_norm_name = left_normalized.as_type(ctx=ctx).get_displayname(ctx.env.schema)
right_norm_name = right_normalized.as_type(ctx=ctx).get_displayname(ctx.env.schema)
int_name = int_type.get_displayname(ctx.env.schema)

if [] in difference_normalized.union_of_intersections:
# The "minimal" intersector is effectively the set of everything.
# No further narrowing is necessary. This can happen if the left type
# is an opaque type and the right type is a direct subclass.
return TypeIntersectionResult(
stype=int_type,
is_subtype=True,
)

difference_type = difference_normalized.as_type(ctx=ctx)
difference_name = difference_type.get_displayname(ctx.env.schema)

if difference_type.issubclass(ctx.env.schema, left):
return TypeIntersectionResult(
stype=difference_type,
is_subtype=True,
)

result_type = get_intersection_type([left, difference_type], ctx=ctx)
result_name = result_type.get_displayname(ctx.env.schema)

return TypeIntersectionResult(
stype=int_type,
is_subtype=is_subtype,
stype=result_type,
is_subtype=False,
)


Expand Down

0 comments on commit 12cae6b

Please sign in to comment.