Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor converting type expressions to typerefs and ranges #7274

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 6 additions & 5 deletions edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,16 @@ class TypeRef(ImmutableBase):
# A set of type ancestor descriptors, if necessary for
# this type description.
ancestors: typing.Optional[typing.FrozenSet[TypeRef]] = None
# If this is a union type, this would be a set of
# union elements.
# If this is a compound type, this is a non-overlapping set of
# constituent types.
union: typing.Optional[typing.FrozenSet[TypeRef]] = None
# Whether the union is specified by an exhaustive list of
# types, and type inheritance should not be considered.
union_is_exhaustive: bool = False
# If this is an intersection type, this would be a set of
# intersection elements.
intersection: typing.Optional[typing.FrozenSet[TypeRef]] = None
Comment on lines -160 to -162
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It weirds me out a little if we aren't going to track that a type is an intersection in the TypeRef, even when it is one in the schema.

This might not actually matter though

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit odd, but the typeref doesn't exactly map onto types anyways

# If this is a complex type, record the expression used to generate the
# type. This is used later to get the correct rvar in `get_path_var`.
expr_intersection: typing.Optional[typing.FrozenSet[TypeRef]] = None
expr_union: typing.Optional[typing.FrozenSet[TypeRef]] = None
# If this node is an element of a collection, and the
# collection elements are named, this would be the
# name of the element.
Expand Down
98 changes: 46 additions & 52 deletions edb/ir/typeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
Dict,
Set,
FrozenSet,
cast,
overload,
TYPE_CHECKING,
)
Expand Down Expand Up @@ -173,10 +172,6 @@ def contains_predicate(
if pred(typeref):
return True

elif typeref.intersection:
return any(
contains_predicate(sub, pred) for sub in typeref.intersection
)
elif typeref.union:
return any(
contains_predicate(sub, pred) for sub in typeref.union
Expand Down Expand Up @@ -305,39 +300,34 @@ def _typeref(
)
elif not isinstance(t, s_types.Collection):
assert isinstance(t, s_types.InheritingType)
union_of = t.get_union_of(schema)
union: Optional[FrozenSet[irast.TypeRef]]
if union_of:
assert isinstance(t, s_objtypes.ObjectType)
union_types = {
cast(s_objtypes.ObjectType, c).get_nearest_non_derived_parent(
schema)
for c in union_of.objects(schema)
}
non_overlapping, union_is_exhaustive = (
s_utils.get_non_overlapping_union(schema, union_types)

union: Optional[FrozenSet[irast.TypeRef]] = None
union_is_exhaustive: bool = False
expr_intersection: Optional[FrozenSet[irast.TypeRef]] = None
expr_union: Optional[FrozenSet[irast.TypeRef]] = None
if t.is_union_type(schema) or t.is_intersection_type(schema):
union_types, union_is_exhaustive = (
s_utils.get_type_expr_non_overlapping_union(t, schema)
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does normalizing help here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All normalization stuff removed.

if union_is_exhaustive:
non_overlapping = frozenset({
t for t in non_overlapping
if t.is_material_object_type(schema)
})

union = frozenset(
_typeref(c) for c in non_overlapping
)
else:
union_is_exhaustive = False
union = None

intersection_of = t.get_intersection_of(schema)
intersection: Optional[FrozenSet[irast.TypeRef]]
if intersection_of:
intersection = frozenset(
_typeref(c) for c in intersection_of.objects(schema)
_typeref(c) for c in union_types
)
else:
intersection = None

# Keep track of type expression structure.
# This is necessary to determine the correct rvar when doing
# type intersections or polymorphic queries.
if expr_intersection_types := t.get_intersection_of(schema):
expr_intersection = frozenset(
_typeref(c)
for c in expr_intersection_types.objects(schema)
)

if expr_union_types := t.get_union_of(schema):
expr_union = frozenset(
_typeref(c)
for c in expr_union_types.objects(schema)
)

schema, material_type = t.material_type(schema)

Expand All @@ -358,26 +348,29 @@ def _typeref(
else:
base_typeref = None

children: Optional[FrozenSet[irast.TypeRef]]

if material_typeref is None and include_children:
children: Optional[FrozenSet[irast.TypeRef]] = None
if (
material_typeref is None
and include_children
and children is None
):
children = frozenset(
_typeref(child, include_children=True)
for child in t.children(schema)
if not child.get_is_derived(schema)
and not child.is_compound_type(schema)
)
else:
children = None

ancestors: Optional[FrozenSet[irast.TypeRef]]
if material_typeref is None and include_ancestors:
ancestors: Optional[FrozenSet[irast.TypeRef]] = None
if (
material_typeref is None
and include_ancestors
and ancestors is None
):
ancestors = frozenset(
_typeref(ancestor, include_ancestors=False)
for ancestor in t.get_ancestors(schema).objects(schema)
)
else:
ancestors = None

sql_type = None
needs_custom_json_cast = False
Expand All @@ -401,7 +394,8 @@ def _typeref(
ancestors=ancestors,
union=union,
union_is_exhaustive=union_is_exhaustive,
intersection=intersection,
expr_intersection=expr_intersection,
expr_union=expr_union,
element_name=_name,
is_scalar=t.is_scalar(),
is_abstract=t.get_abstract(schema),
Expand Down Expand Up @@ -912,36 +906,36 @@ def type_contains(
if typeref == parent:
return True

elif typeref.union:
elif typeref.expr_union:
# A union is considered a subtype of a type, if
# ALL its components are subtypes of that type.
return all(
type_contains(parent, component)
for component in typeref.union
for component in typeref.expr_union
)

elif typeref.intersection:
elif typeref.expr_intersection:
# An intersection is considered a subtype of a type, if
# ANY of its components are subtypes of that type.
return any(
type_contains(parent, component)
for component in typeref.intersection
for component in typeref.expr_intersection
)

elif parent.union:
elif parent.expr_union:
# A type is considered a subtype of a union type,
# if it is a subtype of ANY of the union components.
return any(
type_contains(component, typeref)
for component in parent.union
for component in parent.expr_union
)

elif parent.intersection:
elif parent.expr_intersection:
# A type is considered a subtype of an intersection type,
# if it is a subtype of ALL of the intersection components.
return any(
type_contains(component, typeref)
for component in parent.intersection
for component in parent.expr_intersection
)

else:
Expand Down
24 changes: 1 addition & 23 deletions edb/pgsql/compiler/relctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,29 +1726,6 @@ def range_for_typeref(
ctx=ctx,
)

elif typeref.intersection:
wrapper = pgast.SelectStmt()
component_rvars = []
for component in typeref.intersection:
component_rvar = range_for_typeref(
component,
lateral=True,
path_id=path_id,
for_mutation=for_mutation,
dml_source=dml_source,
ctx=ctx,
)
pathctx.put_rvar_path_bond(component_rvar, path_id)
component_rvars.append(component_rvar)
include_rvar(wrapper, component_rvar, path_id, ctx=ctx)

int_rvar = pgast.IntersectionRangeVar(component_rvars=component_rvars)
for aspect in ('source', 'value'):
pathctx.put_path_rvar(wrapper, path_id, int_rvar, aspect=aspect)

pathctx.put_path_bond(wrapper, path_id)
rvar = rvar_for_rel(wrapper, lateral=lateral, typeref=typeref, ctx=ctx)

else:
rvar = range_for_material_objtype(
typeref,
Expand Down Expand Up @@ -1860,6 +1837,7 @@ def range_from_queryset(
# Just one class table, so return it directly
from_rvar = set_ops[0][1].from_clause[0]
assert isinstance(from_rvar, pgast.PathRangeVar)
from_rvar = from_rvar.replace(typeref=typeref)
rvar = from_rvar

return rvar
Expand Down
35 changes: 15 additions & 20 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,20 @@ def process_set_as_path_type_intersection(

assert not rptr.expr, 'type intersection pointer with expr??'

if (not source_is_visible
if ir_set.typeref.union is not None and len(ir_set.typeref.union) == 0:
# If the typeref was a type expression which resolves to no actual
# types, just return an empty set.
empty_ir = irast.Set(
path_id=ir_set.path_id,
typeref=ir_set.typeref,
expr=irast.EmptySet(typeref=ir_set.typeref),
)
source_rvar = relctx.new_empty_rvar(
cast('irast.SetE[irast.EmptySet]', empty_ir),
ctx=ctx)
relctx.include_rvar(stmt, source_rvar, ir_set.path_id, ctx=ctx)

elif (not source_is_visible
and isinstance(ir_source.expr, irast.Pointer)
and not ir_source.path_id.is_type_intersection_path()
and not ir_source.expr.expr
Expand All @@ -950,27 +963,9 @@ def process_set_as_path_type_intersection(

else:
source_rvar = get_set_rvar(ir_source, ctx=ctx)
intersection = ir_set.typeref.intersection
if intersection:
if ir_source.typeref.intersection:
current_intersection = {
t.id for t in ir_source.typeref.intersection
}
else:
current_intersection = {
ir_source.typeref.id
}

intersectors = {t for t in intersection
if t.id not in current_intersection}

assert len(intersectors) == 1
target_typeref = next(iter(intersectors))
else:
target_typeref = rptr.ptrref.out_target

poly_rvar = relctx.range_for_typeref(
target_typeref,
rptr.ptrref.out_target,
path_id=ir_set.path_id,
dml_source=irutils.get_dml_sources(ir_set),
lateral=True,
Expand Down
78 changes: 78 additions & 0 deletions edb/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,84 @@ def get_non_overlapping_union(
return frozenset(all_objects), True


def get_type_expr_non_overlapping_union(
type: s_types.Type,
schema: s_schema.Schema,
) -> Tuple[FrozenSet[s_types.Type], bool]:
"""Get a non-overlapping set of the type's descendants"""

from edb.schema import types as s_types

expanded_types = expand_type_expr_descendants(type, schema)

# filter out subclasses
expanded_types = {
type
for type in expanded_types
if not any(
type is not other and type.issubclass(schema, other)
for other in expanded_types
)
}
Comment on lines +1184 to +1192
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed? I guess it helps us not always require union_is_exhaustive?

Copy link
Contributor Author

@dnwpark dnwpark May 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right. It also makes the generated SQL easier to read :)


non_overlapping, union_is_exhaustive = get_non_overlapping_union(
schema, cast(set[so.InheritingObject], expanded_types)
)

return cast(FrozenSet[s_types.Type], non_overlapping), union_is_exhaustive


def expand_type_expr_descendants(
    type: s_types.Type,
    schema: s_schema.Schema,
    *,
    expand_opaque_union: bool = True,
) -> set[s_types.Type]:
    """Resolve a type or type expression into a flat set of types.

    The result depends on what ``type`` is:

    * union expression: the set union of each component's expansion;
    * intersection expression: the set intersection of each component's
      expansion;
    * view: the expansion of the peeled (underlying) type;
    * anything else: the type itself plus those of its descendants that
      are not union types, intersection types or views.

    ``expand_opaque_union`` is accepted for interface compatibility; this
    function body neither reads it nor forwards it to recursive calls.
    """

    from edb.schema import types as s_types

    union_of = type.get_union_of(schema)
    if union_of:
        # A union covers everything that any of its components covers.
        component_sets = [
            expand_type_expr_descendants(member, schema)
            for member in union_of.objects(schema)
        ]
        return set.union(*component_sets)

    intersection_of = type.get_intersection_of(schema)
    if intersection_of:
        # An intersection covers only what all of its components cover.
        component_sets = [
            expand_type_expr_descendants(member, schema)
            for member in intersection_of.objects(schema)
        ]
        return set.intersection(*component_sets)

    if type.is_view(schema):
        # Views are transparent here: expand whatever they wrap.
        peeled = type.peel_view(schema)
        return expand_type_expr_descendants(peeled, schema)

    # Plain type: start with the type itself, then add its descendants.
    expanded = {type}
    descendants = cast(
        set[s_types.Type],
        set(cast(so.InheritingObject, type).descendants(schema)),
    )
    for descendant in descendants:
        # Some types (e.g. BaseObject) have non-simple descendants;
        # leave those out of the result.
        if (
            descendant.is_union_type(schema)
            or descendant.is_intersection_type(schema)
            or descendant.is_view(schema)
        ):
            continue
        expanded.add(descendant)

    return expanded


def _union_error(
schema: s_schema.Schema, components: Iterable[s_types.Type]
) -> errors.SchemaError:
Expand Down