diff --git a/edb/edgeql/ast.py b/edb/edgeql/ast.py index f486c2560dc..c1d866784a2 100644 --- a/edb/edgeql/ast.py +++ b/edb/edgeql/ast.py @@ -182,6 +182,10 @@ class Anchor(Expr): name: str +class IRAnchor(Anchor): + has_dml: bool = False + + class SpecialAnchor(Anchor): pass diff --git a/edb/edgeql/codegen.py b/edb/edgeql/codegen.py index 6b32720034b..e4cb5eb9833 100644 --- a/edb/edgeql/codegen.py +++ b/edb/edgeql/codegen.py @@ -781,6 +781,9 @@ def visit_ObjectRef(self, node: qlast.ObjectRef) -> None: def visit_Anchor(self, node: qlast.Anchor) -> None: self.write(node.name) + def visit_IRAnchor(self, node: qlast.IRAnchor) -> None: + self.write(node.name) + def visit_SpecialAnchor(self, node: qlast.SpecialAnchor) -> None: self.write(node.name) diff --git a/edb/edgeql/compiler/casts.py b/edb/edgeql/compiler/casts.py index 1bb39b54535..d92887b2f0f 100644 --- a/edb/edgeql/compiler/casts.py +++ b/edb/edgeql/compiler/casts.py @@ -74,7 +74,14 @@ def compile_cast( ctx=ctx, srcctx=ir_expr.context) - if irutils.is_untyped_empty_array_expr(ir_expr): + if isinstance(new_stype, s_types.Array) and ( + irutils.is_untyped_empty_array_expr(ir_expr) + or ( + isinstance(ir_expr, irast.Set) + and irutils.is_untyped_empty_array_expr( + irutils.unwrap_set(ir_expr).expr) + ) + ): # Ditto for empty arrays. new_typeref = typegen.type_to_typeref(new_stype, ctx.env) return setgen.ensure_set( @@ -97,6 +104,17 @@ def compile_cast( f'`...[IS {new_stype.get_displayname(ctx.env.schema)}]` instead', context=srcctx) + # The only valid object type cast other than is from anytype, + # and thus it must be an empty set. + if ( + orig_stype.is_any(ctx.env.schema) + and new_stype.is_object_type() + ): + return setgen.new_empty_set( + stype=new_stype, + ctx=ctx, + srcctx=ir_expr.context) + uuid_t = ctx.env.get_schema_type_and_track(sn.QualName('std', 'uuid')) if ( orig_stype.issubclass(ctx.env.schema, uuid_t) diff --git a/edb/edgeql/compiler/context.py b/edb/edgeql/compiler/context.py index 6c62b085109..dcacdb4fb39 100644 --- a/edb/edgeql/compiler/context.py +++ b/edb/edgeql/compiler/context.py @@ -37,6 +37,7 @@ from edb.edgeql import qltypes from edb.ir import ast as irast +from edb.ir import utils as irutils from edb.ir import typeutils as irtyputils from edb.schema import expraliases as s_aliases @@ -734,11 +735,16 @@ def newscope( def detached(self) -> compiler.CompilerContextManager[ContextLevel]: return self.new(ContextSwitchMode.DETACHED) - def create_anchor(self, ir: irast.Set, name: str='v') -> qlast.Path: + def create_anchor( + self, ir: irast.Set, name: str='v', *, check_dml: bool=False + ) -> qlast.Path: alias = self.aliases.get(name) + # TODO: We should probably always check for DML, but I'm + # concerned about perf, since we don't cache it at all. + has_dml = check_dml and irutils.contains_dml(ir) self.anchors[alias] = ir return qlast.Path( - steps=[qlast.ObjectRef(name=alias)], + steps=[qlast.IRAnchor(name=alias, has_dml=has_dml)], ) def maybe_create_anchor( diff --git a/edb/edgeql/compiler/expr.py b/edb/edgeql/compiler/expr.py index 7e8cf7ae22c..a0e601e4436 100644 --- a/edb/edgeql/compiler/expr.py +++ b/edb/edgeql/compiler/expr.py @@ -42,6 +42,7 @@ from edb.schema import utils as s_utils from edb.edgeql import ast as qlast +from edb.edgeql import utils from . import astutils from . import casts @@ -306,14 +307,107 @@ def compile_Array( return setgen.new_array_set(elements, ctx=ctx, srcctx=expr.context) +def _compile_dml_ifelse( + expr: qlast.IfElse, *, ctx: context.ContextLevel) -> irast.Set: + """Transform an IF/ELSE that contains DML into FOR loops + + The basic approach is to extract the pieces from the if/then/else and + rewrite them into: + for b in COND union ( + { + (for _ in (select () filter b) union (IF_BRANCH)), + (for _ in (select () filter not b) union (ELSE_BRANCH)), + } + ) + """ + + with ctx.newscope(fenced=False) as subctx: + # We have to compile it under a factoring fence to prevent + # correlation with outside things. We can't just rely on the + # factoring fences inserted when compiling the FORs, since we + # are going to need to explicitly exempt the iterator + # expression from that. + subctx.path_scope.factoring_fence = True + + ir = func.compile_operator( + expr, op_name='std::IF', + qlargs=[expr.if_expr, expr.condition, expr.else_expr], ctx=subctx) + + # Extract the IR parts from the IF/THEN/ELSE + # Note that cond_ir will be unfenced while if_ir and else_ir + # will have been compiled under fences. + match ir.expr: + case irast.OperatorCall(args=[ + irast.CallArg(expr=if_ir), + irast.CallArg(expr=cond_ir), + irast.CallArg(expr=else_ir), + ]): + pass + case _: + raise AssertionError('malformed DML IF/ELSE') + + subctx.anchors = subctx.anchors.copy() + + alias = ctx.aliases.get('b') + cond_path = qlast.Path( + steps=[qlast.ObjectRef(name=alias)], + ) + + els: list[qlast.Expr] = [] + + if not isinstance(irutils.unwrap_set(if_ir), irast.EmptySet): + if_b = qlast.ForQuery( + iterator_alias='__', + iterator=qlast.SelectQuery( + result=qlast.Tuple(elements=[]), + where=cond_path, + ), + result=subctx.create_anchor(if_ir, check_dml=True), + ) + els.append(if_b) + + if not isinstance(irutils.unwrap_set(else_ir), irast.EmptySet): + else_b = qlast.ForQuery( + iterator_alias='__', + iterator=qlast.SelectQuery( + result=qlast.Tuple(elements=[]), + where=qlast.UnaryOp(op='NOT', operand=cond_path), + ), + result=subctx.create_anchor(else_ir, check_dml=True), + ) + els.append(else_b) + + full = qlast.ForQuery( + iterator_alias=alias, + iterator=subctx.create_anchor(cond_ir, 'b'), + result=qlast.Set(elements=els) if len(els) != 1 else els[0], + ) + + subctx.iterator_path_ids |= {cond_ir.path_id} + res = dispatch.compile(full, ctx=subctx) + # Indicate that the original IF/ELSE code should determine the + # cardinality/multiplicity. + res.card_inference_override = ir + + return res + + @dispatch.compile.register(qlast.IfElse) def compile_IfElse( expr: qlast.IfElse, *, ctx: context.ContextLevel) -> irast.Set: - return func.compile_operator( + if ( + utils.contains_dml(expr.if_expr) + or utils.contains_dml(expr.else_expr) + ): + return _compile_dml_ifelse(expr, ctx=ctx) + + res = func.compile_operator( expr, op_name='std::IF', qlargs=[expr.if_expr, expr.condition, expr.else_expr], ctx=ctx) + return res + @dispatch.compile.register(qlast.UnaryOp) def compile_UnaryOp( @@ -401,7 +495,6 @@ def compile_GlobalExpr( def compile_TypeCast( expr: qlast.TypeCast, *, ctx: context.ContextLevel) -> irast.Set: target_stype = typegen.ql_typeexpr_to_type(expr.type, ctx=ctx) - target_typeref = typegen.type_to_typeref(target_stype, env=ctx.env) ir_expr: Union[irast.Set, irast.Expr] if isinstance(expr.expr, qlast.Parameter): @@ -492,14 +585,7 @@ def compile_TypeCast( subctx.implicit_tid_in_shapes = False subctx.implicit_tname_in_shapes = False - if ( - isinstance(expr.expr, qlast.Array) - and not expr.expr.elements - and irtyputils.is_array(target_typeref) - ): - ir_expr = irast.Array(elements=[], typeref=target_typeref) - else: - ir_expr = dispatch.compile(expr.expr, ctx=subctx) + ir_expr = dispatch.compile(expr.expr, ctx=subctx) res = casts.compile_cast( ir_expr, diff --git a/edb/edgeql/compiler/func.py b/edb/edgeql/compiler/func.py index a7076f070a6..d1bd238da09 100644 --- a/edb/edgeql/compiler/func.py +++ b/edb/edgeql/compiler/func.py @@ -341,7 +341,6 @@ def func(f: _SpecialCaseFunc) -> _SpecialCaseFunc: #: A dictionary of conditional callables and the indices #: of the arguments that are evaluated conditionally. CONDITIONAL_OPS = { - sn.QualName('std', 'IF'): {0, 2}, sn.QualName('std', '??'): {1}, } diff --git a/edb/edgeql/compiler/inference/cardinality.py b/edb/edgeql/compiler/inference/cardinality.py index f512e679b72..66d7d7ee3ce 100644 --- a/edb/edgeql/compiler/inference/cardinality.py +++ b/edb/edgeql/compiler/inference/cardinality.py @@ -519,6 +519,12 @@ def _infer_set( ir, is_mutation=is_mutation, scope_tree=scope_tree, ctx=ctx) + # But actually! Check if it is overridden + if ir.card_inference_override: + result = _infer_set_inner( + ir.card_inference_override, is_mutation=is_mutation, + scope_tree=scope_tree, ctx=ctx) + # We need to cache the main result before doing the shape, # since sometimes the shape will refer to the enclosing set. ctx.inferred_cardinality[ir] = result diff --git a/edb/edgeql/compiler/inference/multiplicity.py b/edb/edgeql/compiler/inference/multiplicity.py index e4b0b2d5387..8bce18acd5d 100644 --- a/edb/edgeql/compiler/inference/multiplicity.py +++ b/edb/edgeql/compiler/inference/multiplicity.py @@ -226,6 +226,11 @@ def _infer_set( result = _infer_set_inner( ir, is_mutation=is_mutation, scope_tree=scope_tree, ctx=ctx ) + if ir.card_inference_override: + result = _infer_set_inner( + ir.card_inference_override, is_mutation=is_mutation, + scope_tree=scope_tree, ctx=ctx) + ctx.inferred_multiplicity[ir, scope_tree, ctx.distinct_iterator] = result # The shape doesn't affect multiplicity, but requires validation. diff --git a/edb/edgeql/compiler/polyres.py b/edb/edgeql/compiler/polyres.py index a79520e1b9f..be7270ecb6b 100644 --- a/edb/edgeql/compiler/polyres.py +++ b/edb/edgeql/compiler/polyres.py @@ -262,6 +262,18 @@ def _get_cast_distance( # refine our resolved_poly_base_type to be that as the # more general case. resolved_poly_base_type = ct + else: + # Try resolving a polymorphic argument type against the + # resolved base type. This lets us handle cases like + # - if b then x else {} + # - if b then [1] else [] + # Though it is still unfortunately not smart enough + # to handle the reverse case. + if resolved.is_polymorphic(schema): + ct = resolved.resolve_polymorphic( + schema, resolved_poly_base_type) + + if ct is not None: return s_types.MAX_TYPE_DISTANCE if is_abstract else 0 else: return -1 diff --git a/edb/edgeql/compiler/setgen.py b/edb/edgeql/compiler/setgen.py index 3a786ef555c..332e7945674 100644 --- a/edb/edgeql/compiler/setgen.py +++ b/edb/edgeql/compiler/setgen.py @@ -286,6 +286,13 @@ def compile_path(expr: qlast.Path, *, ctx: context.ContextLevel) -> irast.Set: if isinstance(step, qlast.SpecialAnchor): path_tip = resolve_special_anchor(step, ctx=ctx) + elif isinstance(step, qlast.IRAnchor): + # Check if the starting path label is a known anchor + refnode = anchors.get(step.name) + if not refnode: + raise AssertionError(f'anchor {step.name} is missing') + path_tip = new_set_from_set(refnode, ctx=ctx) + elif isinstance(step, qlast.ObjectRef): if i > 0: # pragma: no cover raise RuntimeError( diff --git a/edb/edgeql/tracer.py b/edb/edgeql/tracer.py index cfa4dba154c..352c663dfba 100644 --- a/edb/edgeql/tracer.py +++ b/edb/edgeql/tracer.py @@ -882,8 +882,8 @@ def trace_Path( @trace.register -def trace_SpecialAnchor( - node: qlast.SpecialAnchor, *, ctx: TracerContext +def trace_Anchor( + node: qlast.Anchor, *, ctx: TracerContext ) -> Optional[ObjectLike]: if name := ctx.anchors.get(node.name): return ctx.objects[name] diff --git a/edb/edgeql/utils.py b/edb/edgeql/utils.py index f92a16867e0..47bebe006b9 100644 --- a/edb/edgeql/utils.py +++ b/edb/edgeql/utils.py @@ -204,8 +204,13 @@ def contains_dml(ql_expr: qlast.Base) -> bool: if isinstance(ql_expr, dml_types): return True - res = ast.find_children(ql_expr, qlast.Query, - lambda x: isinstance(x, dml_types), - terminate_early=True) + res = ast.find_children( + ql_expr, qlast.Base, + lambda x: ( + isinstance(x, dml_types) + or (isinstance(x, qlast.IRAnchor) and x.has_dml) + ), + terminate_early=True, + ) return bool(res) diff --git a/edb/ir/ast.py b/edb/ir/ast.py index 5bb54ce6bf1..97ebd9050f7 100644 --- a/edb/ir/ast.py +++ b/edb/ir/ast.py @@ -533,6 +533,12 @@ class Set(Base): # insertions to BaseObject. ignore_rewrites: bool = False + # An expression to use instead of this one for the purpose of + # cardinality/multiplicity inference. This is used for when something + # is desugared in a way that doesn't preserve cardinality, but we + # need to anyway. + card_inference_override: typing.Optional[Set] = None + def __repr__(self) -> str: return f'' diff --git a/edb/ir/utils.py b/edb/ir/utils.py index 1222348cc0d..bf78e7e76f1 100644 --- a/edb/ir/utils.py +++ b/edb/ir/utils.py @@ -76,7 +76,7 @@ def is_union_expr(ir: irast.Base) -> bool: ) -def is_empty_array_expr(ir: irast.Base) -> bool: +def is_empty_array_expr(ir: Optional[irast.Base]) -> TypeGuard[irast.Array]: """Return True if the given *ir* expression is an empty array expression. """ return ( @@ -85,14 +85,16 @@ def is_empty_array_expr(ir: irast.Base) -> bool: ) -def is_untyped_empty_array_expr(ir: irast.Base) -> bool: +def is_untyped_empty_array_expr( + ir: Optional[irast.Base] +) -> TypeGuard[irast.Array]: """Return True if the given *ir* expression is an empty array expression of an uknown type. """ return ( is_empty_array_expr(ir) - and (ir.typeref is None # type: ignore - or typeutils.is_generic(ir.typeref)) # type: ignore + and (ir.typeref is None + or typeutils.is_generic(ir.typeref)) ) diff --git a/tests/test_edgeql_delete.py b/tests/test_edgeql_delete.py index bd473f1c607..dc3e9ede4e9 100644 --- a/tests/test_edgeql_delete.py +++ b/tests/test_edgeql_delete.py @@ -491,21 +491,6 @@ async def test_edgeql_delete_in_conditional_bad_01(self): (DELETE DeleteTest2); ''') - async def test_edgeql_delete_in_conditional_bad_02(self): - with self.assertRaisesRegex( - edgedb.QueryError, - 'DELETE statements cannot be used'): - await self.con.execute(r''' - SELECT - (SELECT DeleteTest FILTER .name = 'foo') - IF EXISTS DeleteTest - ELSE ( - (SELECT DeleteTest) - UNION - (DELETE DeleteTest) - ); - ''') - async def test_edgeql_delete_abstract_01(self): await self.con.execute(r""" diff --git a/tests/test_edgeql_expressions.py b/tests/test_edgeql_expressions.py index f540021221b..2c7fb86e3c6 100644 --- a/tests/test_edgeql_expressions.py +++ b/tests/test_edgeql_expressions.py @@ -3500,14 +3500,12 @@ async def test_edgeql_expr_array_20(self): """) async def test_edgeql_expr_array_21(self): - # it should be technically possible to infer the type of the array - with self.assertRaisesRegex( - edgedb.QueryError, - r"operator 'UNION' cannot be applied to operands.*anytype.*"): - - await self.con.execute(""" + await self.assert_query_result( + """ SELECT [1, 2] UNION []; - """) + """, + [[1, 2], []], + ) async def test_edgeql_expr_array_22(self): await self.assert_query_result( @@ -8257,6 +8255,43 @@ async def test_edgeql_expr_if_else_09(self): [[]], ) + async def test_edgeql_expr_if_else_10(self): + await self.assert_query_result( + r""" + select if true then 10 else {} + """, + [10], + ) + + await self.assert_query_result( + r""" + select if false then 10 else {} + """, + [], + ) + + await self.assert_query_result( + r""" + select if true then [10] else [] + """, + [[10]], + ) + + await self.assert_query_result( + r""" + select if false then [10] else [] + """, + [[]], + ) + + await self.assert_query_result( + r""" + with test := [''] + select test if false else []; + """, + [[]], + ) + async def test_edgeql_expr_setop_01(self): await self.assert_query_result( r"""SELECT EXISTS {};""", diff --git a/tests/test_edgeql_insert.py b/tests/test_edgeql_insert.py index e58d2fc8f5e..a16486f2a54 100644 --- a/tests/test_edgeql_insert.py +++ b/tests/test_edgeql_insert.py @@ -2532,22 +2532,6 @@ async def test_edgeql_insert_in_conditional_bad_01(self): (INSERT Subordinate { name := 'no way' }); ''') - async def test_edgeql_insert_in_conditional_bad_02(self): - with self.assertRaisesRegex( - edgedb.QueryError, - 'INSERT statements cannot be used inside ' - 'conditional expressions'): - await self.con.execute(r''' - SELECT - (SELECT Subordinate FILTER .name = 'foo') - IF EXISTS Subordinate - ELSE ( - (SELECT Subordinate) - UNION - (INSERT Subordinate { name := 'no way' }) - ); - ''') - async def test_edgeql_insert_correlated_bad_01(self): with self.assertRaisesRegex( edgedb.QueryError, @@ -6010,3 +5994,96 @@ async def test_edgeql_insert_single_linkprop(self): ''', [{"sub": {"name": str, "@note": "!"}}] * 10, ) + + async def test_edgeql_insert_conditional_01(self): + await self.assert_query_result( + ''' + select if $0 then ( + insert InsertTest { l2 := 2 } + ) else ( + insert DerivedTest { l2 := 200 } + ) + ''', + [{}], + variables=(True,) + ) + + await self.assert_query_result( + ''' + select InsertTest { l2, tname := .__type__.name } + ''', + [ + {"l2": 2, "tname": "default::InsertTest"}, + ], + ) + + await self.assert_query_result( + ''' + select if $0 then ( + insert InsertTest { l2 := 2 } + ) else ( + insert DerivedTest { l2 := 200 } + ) + ''', + [{}], + variables=(False,) + ) + + await self.assert_query_result( + ''' + select InsertTest { l2, tname := .__type__.name } order by .l2 + ''', + [ + {"l2": 2, "tname": "default::InsertTest"}, + {"l2": 200, "tname": "default::DerivedTest"}, + ], + ) + + await self.assert_query_result( + ''' + select if array_unpack(>$0) then ( + insert InsertTest { l2 := 2 } + ) else ( + insert DerivedTest { l2 := 200 } + ) + ''', + [{}, {}], + variables=([True, False],) + ) + + await self.assert_query_result( + ''' + with go := $0 + select if go then ( + insert InsertTest { l2 := 100 } + ) else {} + ''', + [{}], + variables=(True,) + ) + + await self.assert_query_result( + ''' + select InsertTest { l2, tname := .__type__.name } order by .l2 + ''', + [ + {"l2": 2, "tname": "default::InsertTest"}, + {"l2": 2, "tname": "default::InsertTest"}, + {"l2": 100, "tname": "default::InsertTest"}, + {"l2": 200, "tname": "default::DerivedTest"}, + {"l2": 200, "tname": "default::DerivedTest"}, + ], + ) + + async def test_edgeql_insert_conditional_02(self): + async with self.assertRaisesRegexTx( + edgedb.errors.QueryError, + "cannot reference correlated set", + ): + await self.con.execute(''' + select ((if ExceptTest.deleted then ( + insert InsertTest { l2 := 2 } + ) else ( + insert DerivedTest { l2 := 200 } + )), (select ExceptTest.deleted limit 1)); + ''') diff --git a/tests/test_edgeql_ir_card_inference.py b/tests/test_edgeql_ir_card_inference.py index 33b5ab87069..f8846d24a72 100644 --- a/tests/test_edgeql_ir_card_inference.py +++ b/tests/test_edgeql_ir_card_inference.py @@ -1175,3 +1175,33 @@ def test_edgeql_ir_card_inference_138(self): % OK % AT_LEAST_ONE """ + + def test_edgeql_ir_card_inference_139(self): + """ + if $0 then + (insert User { name := "test" }) + else + (insert User { name := "???" }) +% OK % + ONE + """ + + def test_edgeql_ir_card_inference_140(self): + """ + if $0 then + (insert User { name := "test" }) + else + {(insert User { name := "???" }), (insert User { name := "!!!" })} +% OK % + AT_LEAST_ONE + """ + + def test_edgeql_ir_card_inference_141(self): + """ + if $0 then + (insert User { name := "test" }) + else + {} +% OK % + AT_MOST_ONE + """ diff --git a/tests/test_edgeql_ir_mult_inference.py b/tests/test_edgeql_ir_mult_inference.py index 1ede2a70115..6e194e3bff9 100644 --- a/tests/test_edgeql_ir_mult_inference.py +++ b/tests/test_edgeql_ir_mult_inference.py @@ -835,3 +835,33 @@ def test_edgeql_ir_mult_inference_89(self): % OK % DUPLICATE """ + + def test_edgeql_ir_mult_inference_90(self): + """ + if $0 then + (insert User { name := "test" }) + else + (insert User { name := "???" }) +% OK % + UNIQUE + """ + + def test_edgeql_ir_mult_inference_91(self): + """ + if $0 then + (insert User { name := "test" }) + else + {(insert User { name := "???" }), (insert User { name := "!!!" })} +% OK % + UNIQUE + """ + + def test_edgeql_ir_mult_inference_92(self): + """ + if $0 then + (insert User { name := "test" }) + else + {} +% OK % + UNIQUE + """ diff --git a/tests/test_edgeql_update.py b/tests/test_edgeql_update.py index 29bcb0aa1a4..b206f0b6c4a 100644 --- a/tests/test_edgeql_update.py +++ b/tests/test_edgeql_update.py @@ -2171,21 +2171,6 @@ async def test_edgeql_update_in_conditional_bad_01(self): (UPDATE UpdateTest SET { name := 'no way' }); ''') - async def test_edgeql_update_in_conditional_bad_02(self): - with self.assertRaisesRegex( - edgedb.QueryError, - 'UPDATE statements cannot be used'): - await self.con.execute(r''' - SELECT - (SELECT UpdateTest FILTER .name = 'foo') - IF EXISTS UpdateTest - ELSE ( - (SELECT UpdateTest) - UNION - (UPDATE UpdateTest SET { name := 'no way' }) - ); - ''') - async def test_edgeql_update_correlated_bad_01(self): with self.assertRaisesRegex( edgedb.QueryError,