From 759123d7e1f1a0e1d9fa21760d7fcc601173b5ba Mon Sep 17 00:00:00 2001 From: dnwpark Date: Sat, 11 May 2024 19:42:21 -0400 Subject: [PATCH] Refactor type to typeref to normalize how type expressions are handled. --- edb/ir/ast.py | 11 ++-- edb/ir/typeutils.py | 98 +++++++++++++++++------------------- edb/pgsql/compiler/relctx.py | 24 +-------- edb/pgsql/compiler/relgen.py | 35 ++++++------- edb/schema/utils.py | 78 ++++++++++++++++++++++++++++ 5 files changed, 146 insertions(+), 100 deletions(-) diff --git a/edb/ir/ast.py b/edb/ir/ast.py index d1095a772a2..a7373091bda 100644 --- a/edb/ir/ast.py +++ b/edb/ir/ast.py @@ -151,15 +151,16 @@ class TypeRef(ImmutableBase): # A set of type ancestor descriptors, if necessary for # this type description. ancestors: typing.Optional[typing.FrozenSet[TypeRef]] = None - # If this is a union type, this would be a set of - # union elements. + # If this is a compound type, this is a non-overlapping set of + # constituent types. union: typing.Optional[typing.FrozenSet[TypeRef]] = None # Whether the union is specified by an exhaustive list of # types, and type inheritance should not be considered. union_is_exhaustive: bool = False - # If this is an intersection type, this would be a set of - # intersection elements. - intersection: typing.Optional[typing.FrozenSet[TypeRef]] = None + # If this is a complex type, record the expression used to generate the + # type. This is used later to get the correct rvar in `get_path_var`. + expr_intersection: typing.Optional[typing.FrozenSet[TypeRef]] = None + expr_union: typing.Optional[typing.FrozenSet[TypeRef]] = None # If this node is an element of a collection, and the # collection elements are named, this would be then # name of the element. diff --git a/edb/ir/typeutils.py b/edb/ir/typeutils.py index 2bd3dbc14a8..6a5ae22bbd2 100644 --- a/edb/ir/typeutils.py +++ b/edb/ir/typeutils.py @@ -28,7 +28,6 @@ Dict, Set, FrozenSet, - cast, overload, TYPE_CHECKING, ) @@ -173,10 +172,6 @@ def contains_predicate( if pred(typeref): return True - elif typeref.intersection: - return any( - contains_predicate(sub, pred) for sub in typeref.intersection - ) elif typeref.union: return any( contains_predicate(sub, pred) for sub in typeref.union @@ -305,39 +300,34 @@ def _typeref( ) elif not isinstance(t, s_types.Collection): assert isinstance(t, s_types.InheritingType) - union_of = t.get_union_of(schema) - union: Optional[FrozenSet[irast.TypeRef]] - if union_of: - assert isinstance(t, s_objtypes.ObjectType) - union_types = { - cast(s_objtypes.ObjectType, c).get_nearest_non_derived_parent( - schema) - for c in union_of.objects(schema) - } - non_overlapping, union_is_exhaustive = ( - s_utils.get_non_overlapping_union(schema, union_types) + + union: Optional[FrozenSet[irast.TypeRef]] = None + union_is_exhaustive: bool = False + expr_intersection: Optional[FrozenSet[irast.TypeRef]] = None + expr_union: Optional[FrozenSet[irast.TypeRef]] = None + if t.is_union_type(schema) or t.is_intersection_type(schema): + union_types, union_is_exhaustive = ( + s_utils.get_type_expr_non_overlapping_union(t, schema) ) - if union_is_exhaustive: - non_overlapping = frozenset({ - t for t in non_overlapping - if t.is_material_object_type(schema) - }) union = frozenset( - _typeref(c) for c in non_overlapping - ) - else: - union_is_exhaustive = False - union = None - - intersection_of = t.get_intersection_of(schema) - intersection: Optional[FrozenSet[irast.TypeRef]] - if intersection_of: - intersection = frozenset( - _typeref(c) for c in intersection_of.objects(schema) + _typeref(c) for c in union_types ) - else: - intersection = None + + # Keep track of type expression structure. + # This is necessary to determine the correct rvar when doing + # type intersections or polymorphic queries. + if expr_intersection_types := t.get_intersection_of(schema): + expr_intersection = frozenset( + _typeref(c) + for c in expr_intersection_types.objects(schema) + ) + + if expr_union_types := t.get_union_of(schema): + expr_union = frozenset( + _typeref(c) + for c in expr_union_types.objects(schema) + ) schema, material_type = t.material_type(schema) @@ -358,26 +348,29 @@ def _typeref( else: base_typeref = None - children: Optional[FrozenSet[irast.TypeRef]] - - if material_typeref is None and include_children: + children: Optional[FrozenSet[irast.TypeRef]] = None + if ( + material_typeref is None + and include_children + and children is None + ): children = frozenset( _typeref(child, include_children=True) for child in t.children(schema) if not child.get_is_derived(schema) and not child.is_compound_type(schema) ) - else: - children = None - ancestors: Optional[FrozenSet[irast.TypeRef]] - if material_typeref is None and include_ancestors: + ancestors: Optional[FrozenSet[irast.TypeRef]] = None + if ( + material_typeref is None + and include_ancestors + and ancestors is None + ): ancestors = frozenset( _typeref(ancestor, include_ancestors=False) for ancestor in t.get_ancestors(schema).objects(schema) ) - else: - ancestors = None sql_type = None needs_custom_json_cast = False @@ -401,7 +394,8 @@ def _typeref( ancestors=ancestors, union=union, union_is_exhaustive=union_is_exhaustive, - intersection=intersection, + expr_intersection=expr_intersection, + expr_union=expr_union, element_name=_name, is_scalar=t.is_scalar(), is_abstract=t.get_abstract(schema), @@ -912,36 +906,36 @@ def type_contains( if typeref == parent: return True - elif typeref.union: + elif typeref.expr_union: # A union is considered a subtype of a type, if # ALL its components are subtypes of that type. return all( type_contains(parent, component) - for component in typeref.union + for component in typeref.expr_union ) - elif typeref.intersection: + elif typeref.expr_intersection: # An intersection is considered a subtype of a type, if # ANY of its components are subtypes of that type. return any( type_contains(parent, component) - for component in typeref.intersection + for component in typeref.expr_intersection ) - elif parent.union: + elif parent.expr_union: # A type is considered a subtype of a union type, # if it is a subtype of ANY of the union components. return any( type_contains(component, typeref) - for component in parent.union + for component in parent.expr_union ) - elif parent.intersection: + elif parent.expr_intersection: # A type is considered a subtype of an intersection type, # if it is a subtype of ALL of the intersection components. return any( type_contains(component, typeref) - for component in parent.intersection + for component in parent.expr_intersection ) else: diff --git a/edb/pgsql/compiler/relctx.py b/edb/pgsql/compiler/relctx.py index 555f17b921a..51b4b8b7eb5 100644 --- a/edb/pgsql/compiler/relctx.py +++ b/edb/pgsql/compiler/relctx.py @@ -1726,29 +1726,6 @@ def range_for_typeref( ctx=ctx, ) - elif typeref.intersection: - wrapper = pgast.SelectStmt() - component_rvars = [] - for component in typeref.intersection: - component_rvar = range_for_typeref( - component, - lateral=True, - path_id=path_id, - for_mutation=for_mutation, - dml_source=dml_source, - ctx=ctx, - ) - pathctx.put_rvar_path_bond(component_rvar, path_id) - component_rvars.append(component_rvar) - include_rvar(wrapper, component_rvar, path_id, ctx=ctx) - - int_rvar = pgast.IntersectionRangeVar(component_rvars=component_rvars) - for aspect in ('source', 'value'): - pathctx.put_path_rvar(wrapper, path_id, int_rvar, aspect=aspect) - - pathctx.put_path_bond(wrapper, path_id) - rvar = rvar_for_rel(wrapper, lateral=lateral, typeref=typeref, ctx=ctx) - else: rvar = range_for_material_objtype( typeref, @@ -1860,6 +1837,7 @@ def range_from_queryset( # Just one class table, so return it directly from_rvar = set_ops[0][1].from_clause[0] assert isinstance(from_rvar, pgast.PathRangeVar) + from_rvar = from_rvar.replace(typeref=typeref) rvar = from_rvar return rvar diff --git a/edb/pgsql/compiler/relgen.py b/edb/pgsql/compiler/relgen.py index 6d473e3b784..7ac12229e3c 100644 --- a/edb/pgsql/compiler/relgen.py +++ b/edb/pgsql/compiler/relgen.py @@ -926,7 +926,20 @@ def process_set_as_path_type_intersection( assert not rptr.expr, 'type intersection pointer with expr??' - if (not source_is_visible + if ir_set.typeref.union is not None and len(ir_set.typeref.union) == 0: + # If the typeref was a type expression which resolves to no actual + # types, just return an empty set. + empty_ir = irast.Set( + path_id=ir_set.path_id, + typeref=ir_set.typeref, + expr=irast.EmptySet(typeref=ir_set.typeref), + ) + source_rvar = relctx.new_empty_rvar( + cast('irast.SetE[irast.EmptySet]', empty_ir), + ctx=ctx) + relctx.include_rvar(stmt, source_rvar, ir_set.path_id, ctx=ctx) + + elif (not source_is_visible and isinstance(ir_source.expr, irast.Pointer) and not ir_source.path_id.is_type_intersection_path() and not ir_source.expr.expr @@ -950,27 +963,9 @@ def process_set_as_path_type_intersection( else: source_rvar = get_set_rvar(ir_source, ctx=ctx) - intersection = ir_set.typeref.intersection - if intersection: - if ir_source.typeref.intersection: - current_intersection = { - t.id for t in ir_source.typeref.intersection - } - else: - current_intersection = { - ir_source.typeref.id - } - - intersectors = {t for t in intersection - if t.id not in current_intersection} - - assert len(intersectors) == 1 - target_typeref = next(iter(intersectors)) - else: - target_typeref = rptr.ptrref.out_target poly_rvar = relctx.range_for_typeref( - target_typeref, + rptr.ptrref.out_target, path_id=ir_set.path_id, dml_source=irutils.get_dml_sources(ir_set), lateral=True, diff --git a/edb/schema/utils.py b/edb/schema/utils.py index 64db48f226d..26a2d57b0e8 100644 --- a/edb/schema/utils.py +++ b/edb/schema/utils.py @@ -1171,6 +1171,84 @@ def get_non_overlapping_union( return frozenset(all_objects), True +def get_type_expr_non_overlapping_union( + type: s_types.Type, + schema: s_schema.Schema, +) -> Tuple[FrozenSet[s_types.Type], bool]: + """Get a non-overlapping set of the type's descendants""" + + from edb.schema import types as s_types + + expanded_types = expand_type_expr_descendants(type, schema) + + # filter out subclasses + expanded_types = { + type + for type in expanded_types + if not any( + type is not other and type.issubclass(schema, other) + for other in expanded_types + ) + } + + non_overlapping, union_is_exhaustive = get_non_overlapping_union( + schema, cast(set[so.InheritingObject], expanded_types) + ) + + return cast(FrozenSet[s_types.Type], non_overlapping), union_is_exhaustive + + +def expand_type_expr_descendants( + type: s_types.Type, + schema: s_schema.Schema, + *, + expand_opaque_union: bool = True, +) -> set[s_types.Type]: + """Expand types and type expressions to get descendants""" + + from edb.schema import types as s_types + + if sub_union := type.get_union_of(schema): + # Expanding a union + # Get the union of the component descendants + return set.union(*( + expand_type_expr_descendants( + component, schema, + ) + for component in sub_union.objects(schema) + )) + + elif sub_intersection := type.get_intersection_of(schema): + # Expanding an intersection + # Get the intersection of component descendants + return set.intersection(*( + expand_type_expr_descendants( + component, schema + ) + for component in sub_intersection.objects(schema) + )) + + elif type.is_view(schema): + # When expanding a view, simply unpeel the view. + return expand_type_expr_descendants( + type.peel_view(schema), schema + ) + + # Return simple type and all its descendants. + # Some types (eg. BaseObject) have non-simple descendants, filter them out. + return {type} | { + c for c in cast( + set[s_types.Type], + set(cast(so.InheritingObject, type).descendants(schema)) + ) + if ( + not c.is_union_type(schema) + and not c.is_intersection_type(schema) + and not c.is_view(schema) + ) + } + + def _union_error( schema: s_schema.Schema, components: Iterable[s_types.Type] ) -> errors.SchemaError: