diff --git a/edb/edgeql/compiler/schemactx.py b/edb/edgeql/compiler/schemactx.py index 385c81a0a88..6bbea5abce8 100644 --- a/edb/edgeql/compiler/schemactx.py +++ b/edb/edgeql/compiler/schemactx.py @@ -535,85 +535,44 @@ def create( ) -> NormalizedType: source_union: list[list[s_types.Type]] = ( - NormalizedType._create(type, ctx=ctx) - ) - simplified_union: list[list[s_types.Type]] = ( - NormalizedType._simplify_intersections(source_union, ctx=ctx) - ) - return NormalizedType( - union_of_intersections=simplified_union - ) - - @staticmethod - def intersection_of_normalized( - left_disjunction: NormalizedType, - right_disjunction: NormalizedType, - *, - ctx: context.ContextLevel, - ) -> NormalizedType: - - source_union: list[list[s_types.Type]] = [ - NormalizedType._intersection_of_intersections( - leftt_intersection, right_intersection, ctx=ctx - ) - for leftt_intersection in ( - left_disjunction.union_of_intersections + NormalizedType._create( + type, + ctx=ctx, + expand_opaque_union=True ) - for right_intersection in ( - right_disjunction.union_of_intersections - ) - ] + ) simplified_union: list[list[s_types.Type]] = ( - NormalizedType._simplify_intersections(source_union, ctx=ctx) + NormalizedType._simplify_union(source_union, ctx=ctx) ) return NormalizedType( union_of_intersections=simplified_union ) - @staticmethod - def _simplify_intersections( - source_union: list[list[s_types.Type]], - *, - ctx: context.ContextLevel, - ) -> list[list[s_types.Type]]: - - # filter out any superclasses - simplified_union: list[list[s_types.Type]] = [] - for source_intersection in source_union: - if any( - NormalizedType._is_sub_intersection( - simplified_intersection, source_intersection, ctx=ctx - ) - for simplified_intersection in simplified_union - ): - # skip subclass of previously seen - pass - else: - # filter out any previous subclasses - simplified_union = [ - simplified_intersection - for simplified_intersection in simplified_union - if not NormalizedType._is_sub_intersection( - source_intersection, simplified_intersection, ctx=ctx - ) - ] + [source_intersection] - - return simplified_union - @staticmethod def _create( type: s_types.Type, *, ctx: context.ContextLevel, + expand_opaque_union: bool, ) -> list[list[s_types.Type]]: - result_union: list[list[s_types.Type]] = [] - - if sub_union := type.get_union_of(ctx.env.schema): + if ( + ( + not type.get_is_opaque_union(ctx.env.schema) + or expand_opaque_union or True + ) + and (sub_union := type.get_union_of(ctx.env.schema)) + ): # simply expand all sub-intersections + result_union: list[list[s_types.Type]] = [] + for component in sub_union.objects(ctx.env.schema): result_union += ( - NormalizedType._create(component, ctx=ctx) + NormalizedType._create( + component, + ctx=ctx, + expand_opaque_union=False, + ) ) return result_union @@ -621,29 +580,67 @@ def _create( # Intersection needs to produce the cross intersection of all # results component_unions: list[list[list[s_types.Type]]] = [ - NormalizedType._create(component, ctx=ctx) + NormalizedType._create( + component, + ctx=ctx, + expand_opaque_union=False, + ) for component in sub_intersection.objects(ctx.env.schema) ] - result_union = component_unions[0] - for component_union in component_unions[1:]: - result_union = [ - NormalizedType._intersection_of_intersections( - result_intersection, component_intersection, ctx=ctx - ) - for result_intersection in result_union - for component_intersection in component_union - ] - return result_union + return NormalizedType._intersection_of_unions( + component_unions, ctx=ctx + ) elif type.is_view(ctx.env.schema): return NormalizedType._create( - type.peel_view(ctx.env.schema), ctx=ctx + type.peel_view(ctx.env.schema), + ctx=ctx, + expand_opaque_union=False, ) # Now the left type is a simple type. return [[type]] + @staticmethod + def intersection_of_normalized( + left_normalized: NormalizedType, + right_normalized: NormalizedType, + *, + ctx: context.ContextLevel, + ) -> NormalizedType: + + source_union: list[list[s_types.Type]] = ( + NormalizedType._intersection_of_unions([ + left_normalized.union_of_intersections, + right_normalized.union_of_intersections, + ], ctx=ctx) + ) + simplified_union: list[list[s_types.Type]] = ( + NormalizedType._simplify_union(source_union, ctx=ctx) + ) + return NormalizedType( + union_of_intersections=simplified_union + ) + + @staticmethod + def _intersection_of_unions( + component_unions: list[list[list[s_types.Type]]], + *, + ctx: context.ContextLevel, + ) -> list[list[s_types.Type]]: + + result_union = component_unions[0] + for component_union in component_unions[1:]: + result_union = [ + NormalizedType._intersection_of_intersections( + result_intersection, component_intersection, ctx=ctx + ) + for result_intersection in result_union + for component_intersection in component_union + ] + return result_union + @staticmethod def _intersection_of_intersections( left_intersection: list[s_types.Type], @@ -684,6 +681,133 @@ def _intersection_of_intersections( return result + @staticmethod + def compute_other_intersector( + target: NormalizedType, + intersector: NormalizedType, + *, + ctx: context.ContextLevel, + ) -> NormalizedType: + + # Compute the intersector required such that its intersection with + # the source intersector is a subclass of the target + # eg. A * X = A*B -> X = B + + source_union: list[list[s_types.Type]] = ( + NormalizedType._compute_other_intersector( + target.union_of_intersections, + intersector.union_of_intersections, + ctx=ctx, + ) + ) + simplified_union: list[list[s_types.Type]] = ( + NormalizedType._simplify_union(source_union, ctx=ctx) + ) + return NormalizedType( + union_of_intersections=simplified_union + ) + + @staticmethod + def _compute_other_intersector( + target_union: list[list[s_types.Type]], + intersector_union: list[list[s_types.Type]], + *, + ctx: context.ContextLevel, + ) -> list[list[s_types.Type]]: + + # Expand source union + if len(intersector_union) > 1: + # (A) / (B+C) = A/B * A/C + component_unions = [ + NormalizedType._compute_other_intersector( + target_union, [known_intersection], ctx=ctx + ) + for known_intersection in intersector_union + ] + + return NormalizedType._intersection_of_unions( + component_unions, ctx=ctx + ) + + # Expand target union + if len(target_union) > 1: + # (A+B) / (C) = A/C + B/C + result_union = [] + for target_intersection in target_union: + result_union += ( + NormalizedType._compute_other_intersector( + [target_intersection], intersector_union, ctx=ctx + ) + ) + return result_union + + # Expand target intersection + if len(target_union[0]) > 1: + # (A*B) / (C) = A/C * B/C + component_unions = [ + NormalizedType._compute_other_intersector( + [[target_type]], [intersector_intersection], ctx=ctx + ) + for intersector_intersection in intersector_union + for target_type in target_union[0] + ] + + return NormalizedType._intersection_of_unions( + component_unions, ctx=ctx + ) + + # Expand source intersection + if len(intersector_union[0]) > 1: + # (A) / (B*C) = A/B + A/C + result_union = [] + for intersector_type in intersector_union[0]: + result_union += ( + NormalizedType._compute_other_intersector( + target_union, [[intersector_type]], ctx=ctx + ) + ) + return result_union + + # Compute intersector of simple types + target_type = target_union[0][0] + intersector_type = intersector_union[0][0] + + if intersector_type.issubclass(ctx.env.schema, target_type): + # a * X = A -> X = anything + return [[]] + + return target_union + + @staticmethod + def _simplify_union( + source_union: list[list[s_types.Type]], + *, + ctx: context.ContextLevel, + ) -> list[list[s_types.Type]]: + + # filter out any superclasses + simplified_union: list[list[s_types.Type]] = [] + for source_intersection in source_union: + if any( + NormalizedType._is_sub_intersection( + simplified_intersection, source_intersection, ctx=ctx + ) + for simplified_intersection in simplified_union + ): + # skip subclass of previously seen + pass + else: + # filter out any previous subclasses + simplified_union = [ + simplified_intersection + for simplified_intersection in simplified_union + if not NormalizedType._is_sub_intersection( + source_intersection, simplified_intersection, ctx=ctx + ) + ] + [source_intersection] + + return simplified_union + @staticmethod def _is_sub_intersection( left_intersection: list[s_types.Type], @@ -692,6 +816,9 @@ def _is_sub_intersection( ctx: context.ContextLevel, ) -> bool: + # Returns whether the right intersection is a 'sub-class' of + # the right intersection. + return all( any( right_type.issubclass(ctx.env.schema, left_type) @@ -720,14 +847,49 @@ def apply_intersection( left_normalized = NormalizedType.create(left, ctx=ctx) right_normalized = NormalizedType.create(right, ctx=ctx) - int_type = NormalizedType.intersection_of_normalized( + int_normalized = NormalizedType.intersection_of_normalized( left_normalized, right_normalized, ctx=ctx - ).as_type(ctx=ctx) - is_subtype = right.issubclass(ctx.env.schema, left) + ) + + int_type = int_normalized.as_type(ctx=ctx) + + # Compute the "minimal" intersector required + difference_normalized = NormalizedType.compute_other_intersector( + int_normalized, + left_normalized, + ctx=ctx, + ) + + left_name = left.get_displayname(ctx.env.schema) + right_name = right.get_displayname(ctx.env.schema) + left_norm_name = left_normalized.as_type(ctx=ctx).get_displayname(ctx.env.schema) + right_norm_name = right_normalized.as_type(ctx=ctx).get_displayname(ctx.env.schema) + int_name = int_type.get_displayname(ctx.env.schema) + + if [] in difference_normalized.union_of_intersections: + # The "minimal" intersector is effectively the set of everything. + # No further narrowing is necessary. This can happen if the left type + # is an opaque type and the right type is a direct subclass. + return TypeIntersectionResult( + stype=int_type, + is_subtype=True, + ) + + difference_type = difference_normalized.as_type(ctx=ctx) + difference_name = difference_type.get_displayname(ctx.env.schema) + + if difference_type.issubclass(ctx.env.schema, left): + return TypeIntersectionResult( + stype=difference_type, + is_subtype=True, + ) + + result_type = get_intersection_type([left, difference_type], ctx=ctx) + result_name = result_type.get_displayname(ctx.env.schema) return TypeIntersectionResult( - stype=int_type, - is_subtype=is_subtype, + stype=result_type, + is_subtype=False, )