diff --git a/edb/edgeql/compiler/func.py b/edb/edgeql/compiler/func.py index 8b9a1d91675..caf4aa88863 100644 --- a/edb/edgeql/compiler/func.py +++ b/edb/edgeql/compiler/func.py @@ -1086,6 +1086,12 @@ def _validate_object_search_call( idx = _validate_has_object_index( variant, schema, span, context, index_name) indexes[typegen.type_to_typeref(variant, ctx.env)] = idx + elif intersection_variants := stype.get_intersection_of(schema): + for variant in intersection_variants.objects(schema): + schema, variant = variant.material_type(schema) + idx = _validate_has_object_index( + variant, schema, span, context, index_name) + indexes[typegen.type_to_typeref(variant, ctx.env)] = idx else: idx = _validate_has_object_index( stype, schema, span, context, index_name) diff --git a/edb/edgeql/compiler/schemactx.py b/edb/edgeql/compiler/schemactx.py index a60cbaa3c91..7305937c4e4 100644 --- a/edb/edgeql/compiler/schemactx.py +++ b/edb/edgeql/compiler/schemactx.py @@ -518,6 +518,12 @@ def get_all_concrete( for t in union.objects(ctx.env.schema) for x in get_all_concrete(t, ctx=ctx) } + elif intersection := stype.get_intersection_of(ctx.env.schema): + return { + x + for t in intersection.objects(ctx.env.schema) + for x in get_all_concrete(t, ctx=ctx) + } return {stype} | { x for x in stype.descendants(ctx.env.schema) if x.is_material_object_type(ctx.env.schema) diff --git a/edb/edgeql/declarative.py b/edb/edgeql/declarative.py index fe691bfd6a9..9ac7b1ed5b5 100644 --- a/edb/edgeql/declarative.py +++ b/edb/edgeql/declarative.py @@ -1484,6 +1484,12 @@ def _resolve_type_expr( _resolve_type_expr(texpr.right, ctx=ctx), ]) + elif texpr.op == '&': + return qltracer.IntersectionType([ + _resolve_type_expr(texpr.left, ctx=ctx), + _resolve_type_expr(texpr.right, ctx=ctx), + ]) + else: raise NotImplementedError( f'unsupported type operation: {texpr.op}') diff --git a/edb/edgeql/tracer.py b/edb/edgeql/tracer.py index 5a6e99e6a75..49520b2c832 100644 --- a/edb/edgeql/tracer.py +++ b/edb/edgeql/tracer.py @@ -199,6 +199,22 @@ def is_object_type(self) -> bool: return True +class IntersectionType(Type): + + def __init__( + self, types: List[Union[Type, IntersectionType, so.Object]] + ) -> None: + self.types = types + + def get_name(self, schema: s_schema.Schema) -> sn.QualName: + component_ids = sorted(str(t.get_name(schema)) for t in self.types) + nqname = f"({' & '.join(component_ids)})" + return sn.QualName(name=nqname, module='__derived__') + + def is_object_type(self) -> bool: + return True + + class Pointer(Source): def __init__( @@ -941,6 +957,12 @@ def _resolve_type_expr( _resolve_type_expr(texpr.right, ctx=ctx), ]) + elif texpr.op == '&': + return IntersectionType([ + _resolve_type_expr(texpr.left, ctx=ctx), + _resolve_type_expr(texpr.right, ctx=ctx), + ]) + else: raise NotImplementedError( f'unsupported type operation: {texpr.op}') diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index bfb0ece9971..09cef483e6f 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -6206,6 +6206,8 @@ def get_target_objs(self, link, schema): tgt = link.get_target(schema) if union := tgt.get_union_of(schema).objects(schema): objs = set(union) + elif intersection := tgt.get_intersection_of(schema).objects(schema): + objs = set(intersection) else: objs = {tgt} objs |= { @@ -6803,9 +6805,10 @@ def apply( # triggers updated, so track them down. all_affected_targets = set() for target in affected_targets: - union_of = target.get_union_of(schema) - if union_of: + if union_of := target.get_union_of(schema): objtypes = tuple(union_of.objects(schema)) + elif intersection_of := target.get_intersection_of(schema): + objtypes = tuple(intersection_of.objects(schema)) else: objtypes = (target,) @@ -6836,6 +6839,10 @@ def apply( target, scls_type=s_objtypes.ObjectType, field_name='union_of' ), + schema.get_referrers( + target, scls_type=s_objtypes.ObjectType, + field_name='intersection_of' + ), ): inbound_links |= schema.get_referrers( ancestor, scls_type=s_links.Link, field_name='target') diff --git a/edb/schema/ddl.py b/edb/schema/ddl.py index efde6f94825..d735cf79d4c 100644 --- a/edb/schema/ddl.py +++ b/edb/schema/ddl.py @@ -402,7 +402,7 @@ def _only_generic( obj = cast(s_objtypes.ObjectType, relevant_schema.get(cmd.classname)) - if obj.is_union_type(relevant_schema): + if obj.is_compound_type(relevant_schema): continue result.add(cmd) diff --git a/edb/schema/objtypes.py b/edb/schema/objtypes.py index 6799a206e57..41529635b02 100644 --- a/edb/schema/objtypes.py +++ b/edb/schema/objtypes.py @@ -222,6 +222,13 @@ def getrptrs( for union in unions: ptrs.update(union.getrptrs(schema, name, sources=sources)) + intersections = schema.get_referrers( + self, scls_type=ObjectType, field_name='intersection_of' + ) + + for intersection in intersections: + ptrs.update(intersection.getrptrs(schema, name, sources=sources)) + return ptrs def get_relevant_triggers( @@ -683,6 +690,49 @@ def _alter_finalize( schema = diff.apply(schema, context) self.add(diff) + # Do the same for intersections + intersections = schema.get_referrers( + self.scls, scls_type=ObjectType, field_name='intersection_of') + + orig_disable = context.disable_dep_verification + + for intersection in intersections: + delete = ( + intersection.init_delta_command(schema, sd.DeleteObject) + ) + + context.disable_dep_verification = True + delete.apply(schema, context) + context.disable_dep_verification = orig_disable + # We run the delete to populate the tree, but then instead + # of actually deleting the object, we just remove the names. + # This is because the pointers in the types we are looking + # at might themselves reference the intersection, so we need + # them in the schema to produce the correct as_alter_delta. + nschema = _delete_to_delist(delete, schema) + + nschema, nintersection, _ = utils.ensure_intersection_type( + nschema, + types=( + intersection + .get_intersection_of(schema) + .objects(schema) + ), + module=intersection.get_name(schema).module, + ) + assert isinstance(nintersection, ObjectType) + + diff = intersection.as_alter_delta( + other=nintersection, + self_schema=schema, + other_schema=nschema, + confidence=1.0, + context=so.ComparisonContext(), + ) + + schema = diff.apply(schema, context) + self.add(diff) + return super()._alter_finalize(schema, context) diff --git a/edb/schema/pointers.py b/edb/schema/pointers.py index 97d74cdf04d..fc3e8bf0b60 100644 --- a/edb/schema/pointers.py +++ b/edb/schema/pointers.py @@ -33,6 +33,7 @@ import collections.abc import enum +import itertools import json import operator @@ -2289,12 +2290,18 @@ def _canonicalize( # Any union type that references this field needs to have it # deleted. - unions = schema.get_referrers( - self.scls, scls_type=Pointer, field_name='union_of') - for union in unions: - group, op, _ = union.init_delta_branch( + referrers = itertools.chain( + schema.get_referrers( + self.scls, scls_type=Pointer, field_name='union_of' + ), + schema.get_referrers( + self.scls, scls_type=Pointer, field_name='intersection_of' + ), + ) + for referrer in referrers: + group, op, _ = referrer.init_delta_branch( schema, context, sd.DeleteObject) - op.update(op._canonicalize(schema, context, union)) + op.update(op._canonicalize(schema, context, referrer)) commands.append(group) return commands diff --git a/edb/schema/utils.py b/edb/schema/utils.py index b7c0304bc93..7d2a9b237de 100644 --- a/edb/schema/utils.py +++ b/edb/schema/utils.py @@ -570,17 +570,25 @@ def typeref_to_ast( for st in t.get_subtypes(schema) ] ) - elif isinstance(t, s_types.Type) and t.is_union_type(schema): + elif isinstance(t, s_types.Type) and (t.is_compound_type(schema)): object_set = t.get_union_of(schema) assert object_set is not None component_objects = tuple(object_set.objects(schema)) result = typeref_to_ast(schema, component_objects[0], disambiguate_std=disambiguate_std) + + if t.is_union_type(schema): + op = '|' + elif t.is_intersection_type(schema): + op = '&' + else: + raise NotImplementedError + for component_object in component_objects[1:]: result = qlast.TypeOp( left=result, - op='|', + op=op, right=typeref_to_ast(schema, component_object, disambiguate_std=disambiguate_std), ) @@ -689,6 +697,15 @@ def shell_to_ast( op='|', right=typeref_to_ast(schema, component), ) + elif isinstance(t, s_types.IntersectionTypeShell): + components = t.get_components(schema) + result = typeref_to_ast(schema, components[0]) + for component in components[1:]: + result = qlast.TypeOp( + left=result, + op='&', + right=typeref_to_ast(schema, component), + ) elif isinstance(t, s_scalars.AnonymousEnumTypeShell): result = qlast.TypeName( name=_name, diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index f84603e15ec..96d3fccfa33 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -1199,7 +1199,7 @@ def describe_database_dump( cfg_object = schema.get('cfg::ConfigObject', type=s_objtypes.ObjectType) for objtype in objtypes: - if objtype.is_union_type(schema) or objtype.is_view(schema): + if objtype.is_compound_type(schema) or objtype.is_view(schema): continue if objtype.issubclass(schema, cfg_object): continue diff --git a/edb/tools/experimental_interpreter/elaboration.py b/edb/tools/experimental_interpreter/elaboration.py index f217beba68e..f0684324633 100644 --- a/edb/tools/experimental_interpreter/elaboration.py +++ b/edb/tools/experimental_interpreter/elaboration.py @@ -27,6 +27,7 @@ FunAppExpr, IndirectionIndexOp, InsertExpr, + IntersectTp, IntVal, Label, LinkPropLabel, @@ -611,6 +612,11 @@ def elab_single_type_expr(typedef: qlast.TypeExpr) -> Tp: left=elab_single_type_expr(left_type), right=elab_single_type_expr(right_type), ) + elif op_name == "&": + return IntersectTp( + left=elab_single_type_expr(left_type), + right=elab_single_type_expr(right_type), + ) else: raise ValueError("Unknown Type Op") raise ValueError("MATCH")