Skip to content

Commit

Permalink
Refactor type to typeref to normalize how type expressions are handle…
Browse files Browse the repository at this point in the history
…d. (#7274)

Type expressions converted to typerefs are used to:
- get actual types from the database, and
- get the correct rvar for type intersections.

This PR splits these functionalities into non-overlapping fields.
- For getting actual types, the `union` and `union_is_exhaustive` fields are used,
- For getting the correct rvar for type intersection, the new `expr_intersection` and `expr_union` fields are used.

This new structure handles both union and intersection type expressions.
  • Loading branch information
dnwpark committed May 14, 2024
1 parent b91c549 commit c2c6d60
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 100 deletions.
11 changes: 6 additions & 5 deletions edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,16 @@ class TypeRef(ImmutableBase):
# A set of type ancestor descriptors, if necessary for
# this type description.
ancestors: typing.Optional[typing.FrozenSet[TypeRef]] = None
# If this is a union type, this would be a set of
# union elements.
# If this is a compound type, this is a non-overlapping set of
# constituent types.
union: typing.Optional[typing.FrozenSet[TypeRef]] = None
# Whether the union is specified by an exhaustive list of
# types, and type inheritance should not be considered.
union_is_exhaustive: bool = False
# If this is an intersection type, this would be a set of
# intersection elements.
intersection: typing.Optional[typing.FrozenSet[TypeRef]] = None
# If this is a complex type, record the expression used to generate the
# type. This is used later to get the correct rvar in `get_path_var`.
expr_intersection: typing.Optional[typing.FrozenSet[TypeRef]] = None
expr_union: typing.Optional[typing.FrozenSet[TypeRef]] = None
# If this node is an element of a collection, and the
# collection elements are named, this would be the
# name of the element.
Expand Down
98 changes: 46 additions & 52 deletions edb/ir/typeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
Dict,
Set,
FrozenSet,
cast,
overload,
TYPE_CHECKING,
)
Expand Down Expand Up @@ -173,10 +172,6 @@ def contains_predicate(
if pred(typeref):
return True

elif typeref.intersection:
return any(
contains_predicate(sub, pred) for sub in typeref.intersection
)
elif typeref.union:
return any(
contains_predicate(sub, pred) for sub in typeref.union
Expand Down Expand Up @@ -305,39 +300,34 @@ def _typeref(
)
elif not isinstance(t, s_types.Collection):
assert isinstance(t, s_types.InheritingType)
union_of = t.get_union_of(schema)
union: Optional[FrozenSet[irast.TypeRef]]
if union_of:
assert isinstance(t, s_objtypes.ObjectType)
union_types = {
cast(s_objtypes.ObjectType, c).get_nearest_non_derived_parent(
schema)
for c in union_of.objects(schema)
}
non_overlapping, union_is_exhaustive = (
s_utils.get_non_overlapping_union(schema, union_types)

union: Optional[FrozenSet[irast.TypeRef]] = None
union_is_exhaustive: bool = False
expr_intersection: Optional[FrozenSet[irast.TypeRef]] = None
expr_union: Optional[FrozenSet[irast.TypeRef]] = None
if t.is_union_type(schema) or t.is_intersection_type(schema):
union_types, union_is_exhaustive = (
s_utils.get_type_expr_non_overlapping_union(t, schema)
)
if union_is_exhaustive:
non_overlapping = frozenset({
t for t in non_overlapping
if t.is_material_object_type(schema)
})

union = frozenset(
_typeref(c) for c in non_overlapping
)
else:
union_is_exhaustive = False
union = None

intersection_of = t.get_intersection_of(schema)
intersection: Optional[FrozenSet[irast.TypeRef]]
if intersection_of:
intersection = frozenset(
_typeref(c) for c in intersection_of.objects(schema)
_typeref(c) for c in union_types
)
else:
intersection = None

# Keep track of type expression structure.
# This is necessary to determine the correct rvar when doing
# type intersections or polymorphic queries.
if expr_intersection_types := t.get_intersection_of(schema):
expr_intersection = frozenset(
_typeref(c)
for c in expr_intersection_types.objects(schema)
)

if expr_union_types := t.get_union_of(schema):
expr_union = frozenset(
_typeref(c)
for c in expr_union_types.objects(schema)
)

schema, material_type = t.material_type(schema)

Expand All @@ -358,26 +348,29 @@ def _typeref(
else:
base_typeref = None

children: Optional[FrozenSet[irast.TypeRef]]

if material_typeref is None and include_children:
children: Optional[FrozenSet[irast.TypeRef]] = None
if (
material_typeref is None
and include_children
and children is None
):
children = frozenset(
_typeref(child, include_children=True)
for child in t.children(schema)
if not child.get_is_derived(schema)
and not child.is_compound_type(schema)
)
else:
children = None

ancestors: Optional[FrozenSet[irast.TypeRef]]
if material_typeref is None and include_ancestors:
ancestors: Optional[FrozenSet[irast.TypeRef]] = None
if (
material_typeref is None
and include_ancestors
and ancestors is None
):
ancestors = frozenset(
_typeref(ancestor, include_ancestors=False)
for ancestor in t.get_ancestors(schema).objects(schema)
)
else:
ancestors = None

sql_type = None
needs_custom_json_cast = False
Expand All @@ -401,7 +394,8 @@ def _typeref(
ancestors=ancestors,
union=union,
union_is_exhaustive=union_is_exhaustive,
intersection=intersection,
expr_intersection=expr_intersection,
expr_union=expr_union,
element_name=_name,
is_scalar=t.is_scalar(),
is_abstract=t.get_abstract(schema),
Expand Down Expand Up @@ -912,36 +906,36 @@ def type_contains(
if typeref == parent:
return True

elif typeref.union:
elif typeref.expr_union:
# A union is considered a subtype of a type, if
# ALL its components are subtypes of that type.
return all(
type_contains(parent, component)
for component in typeref.union
for component in typeref.expr_union
)

elif typeref.intersection:
elif typeref.expr_intersection:
# An intersection is considered a subtype of a type, if
# ANY of its components are subtypes of that type.
return any(
type_contains(parent, component)
for component in typeref.intersection
for component in typeref.expr_intersection
)

elif parent.union:
elif parent.expr_union:
# A type is considered a subtype of a union type,
# if it is a subtype of ANY of the union components.
return any(
type_contains(component, typeref)
for component in parent.union
for component in parent.expr_union
)

elif parent.intersection:
elif parent.expr_intersection:
# A type is considered a subtype of an intersection type,
# if it is a subtype of ALL of the intersection components.
return any(
type_contains(component, typeref)
for component in parent.intersection
for component in parent.expr_intersection
)

else:
Expand Down
24 changes: 1 addition & 23 deletions edb/pgsql/compiler/relctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,29 +1726,6 @@ def range_for_typeref(
ctx=ctx,
)

elif typeref.intersection:
wrapper = pgast.SelectStmt()
component_rvars = []
for component in typeref.intersection:
component_rvar = range_for_typeref(
component,
lateral=True,
path_id=path_id,
for_mutation=for_mutation,
dml_source=dml_source,
ctx=ctx,
)
pathctx.put_rvar_path_bond(component_rvar, path_id)
component_rvars.append(component_rvar)
include_rvar(wrapper, component_rvar, path_id, ctx=ctx)

int_rvar = pgast.IntersectionRangeVar(component_rvars=component_rvars)
for aspect in ('source', 'value'):
pathctx.put_path_rvar(wrapper, path_id, int_rvar, aspect=aspect)

pathctx.put_path_bond(wrapper, path_id)
rvar = rvar_for_rel(wrapper, lateral=lateral, typeref=typeref, ctx=ctx)

else:
rvar = range_for_material_objtype(
typeref,
Expand Down Expand Up @@ -1860,6 +1837,7 @@ def range_from_queryset(
# Just one class table, so return it directly
from_rvar = set_ops[0][1].from_clause[0]
assert isinstance(from_rvar, pgast.PathRangeVar)
from_rvar = from_rvar.replace(typeref=typeref)
rvar = from_rvar

return rvar
Expand Down
35 changes: 15 additions & 20 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,20 @@ def process_set_as_path_type_intersection(

assert not rptr.expr, 'type intersection pointer with expr??'

if (not source_is_visible
if ir_set.typeref.union is not None and len(ir_set.typeref.union) == 0:
# If the typeref was a type expression which resolves to no actual
# types, just return an empty set.
empty_ir = irast.Set(
path_id=ir_set.path_id,
typeref=ir_set.typeref,
expr=irast.EmptySet(typeref=ir_set.typeref),
)
source_rvar = relctx.new_empty_rvar(
cast('irast.SetE[irast.EmptySet]', empty_ir),
ctx=ctx)
relctx.include_rvar(stmt, source_rvar, ir_set.path_id, ctx=ctx)

elif (not source_is_visible
and isinstance(ir_source.expr, irast.Pointer)
and not ir_source.path_id.is_type_intersection_path()
and not ir_source.expr.expr
Expand All @@ -950,27 +963,9 @@ def process_set_as_path_type_intersection(

else:
source_rvar = get_set_rvar(ir_source, ctx=ctx)
intersection = ir_set.typeref.intersection
if intersection:
if ir_source.typeref.intersection:
current_intersection = {
t.id for t in ir_source.typeref.intersection
}
else:
current_intersection = {
ir_source.typeref.id
}

intersectors = {t for t in intersection
if t.id not in current_intersection}

assert len(intersectors) == 1
target_typeref = next(iter(intersectors))
else:
target_typeref = rptr.ptrref.out_target

poly_rvar = relctx.range_for_typeref(
target_typeref,
rptr.ptrref.out_target,
path_id=ir_set.path_id,
dml_source=irutils.get_dml_sources(ir_set),
lateral=True,
Expand Down
78 changes: 78 additions & 0 deletions edb/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,84 @@ def get_non_overlapping_union(
return frozenset(all_objects), True


def get_type_expr_non_overlapping_union(
    type: s_types.Type,
    schema: s_schema.Schema,
) -> Tuple[FrozenSet[s_types.Type], bool]:
    """Resolve a type or type expression to a non-overlapping set of types.

    The type is first expanded with ``expand_type_expr_descendants``.  Any
    expanded type that is a subclass of another expanded type is then
    dropped, and the remainder is handed to ``get_non_overlapping_union``.

    Returns a pair: the frozenset produced by ``get_non_overlapping_union``
    and the "union is exhaustive" flag reported by that same helper.
    """

    # Imported at call time; the name is needed at runtime for the
    # ``cast`` target below.
    from edb.schema import types as s_types

    candidates = expand_type_expr_descendants(type, schema)

    # Keep only candidates that are not a subclass of some *other*
    # candidate; the redundant subclasses are discarded.
    topmost: set[s_types.Type] = set()
    for candidate in candidates:
        for other in candidates:
            if (
                candidate is not other
                and candidate.issubclass(schema, other)
            ):
                break
        else:
            # No other candidate is an ancestor of this one.
            topmost.add(candidate)

    result, is_exhaustive = get_non_overlapping_union(
        schema, cast(set[so.InheritingObject], topmost)
    )

    return cast(FrozenSet[s_types.Type], result), is_exhaustive


def expand_type_expr_descendants(
    type: s_types.Type,
    schema: s_schema.Schema,
    *,
    expand_opaque_union: bool = True,
) -> set[s_types.Type]:
    """Expand a type or type expression into a set of types.

    The expansion is recursive:

    - a union expands to the set-union of its components' expansions;
    - an intersection expands to the set-intersection of its components'
      expansions;
    - a view is peeled and the underlying type is expanded instead;
    - any other type expands to itself plus those of its descendants that
      are not union types, intersection types, or views.

    ``expand_opaque_union`` is accepted for interface compatibility; it is
    not read anywhere in this function body.
    """

    # Imported at call time; the name is needed at runtime for the
    # ``cast`` target below.
    from edb.schema import types as s_types

    union_of = type.get_union_of(schema)
    if union_of:
        # Union: combine everything reachable from any component.
        expansions = [
            expand_type_expr_descendants(component, schema)
            for component in union_of.objects(schema)
        ]
        return set.union(*expansions)

    intersection_of = type.get_intersection_of(schema)
    if intersection_of:
        # Intersection: keep only what is reachable from every component.
        expansions = [
            expand_type_expr_descendants(component, schema)
            for component in intersection_of.objects(schema)
        ]
        return set.intersection(*expansions)

    if type.is_view(schema):
        # Views are transparent here: unpeel and expand what is underneath.
        return expand_type_expr_descendants(type.peel_view(schema), schema)

    # Simple type: itself plus all of its simple descendants.
    # Some types (eg. BaseObject) have non-simple descendants, so those
    # are skipped.
    descendants = cast(
        set[s_types.Type],
        set(cast(so.InheritingObject, type).descendants(schema)),
    )
    expanded = {type}
    for descendant in descendants:
        if (
            descendant.is_union_type(schema)
            or descendant.is_intersection_type(schema)
            or descendant.is_view(schema)
        ):
            continue
        expanded.add(descendant)
    return expanded


def _union_error(
schema: s_schema.Schema, components: Iterable[s_types.Type]
) -> errors.SchemaError:
Expand Down

0 comments on commit c2c6d60

Please sign in to comment.