Skip to content

Commit

Permalink
Support DML in IF/THEN/ELSE expressions (#6181)
Browse files Browse the repository at this point in the history
As we've discussed (and, tragically, recommended to users), we can
rewrite the conditionals into for loops:
    for b in COND union (
      {
        (for _ in (select () filter b) union (IF_BRANCH)),
        (for _ in (select () filter not b) union (ELSE_BRANCH)),
      }
    )

The main fiddly part is preserving the correct
cardinality/multiplicity inference. I did this by adding a
card_inference_override field to Set that specifies another set that
should determine the cardinality.

I don't love this, but I haven't thought of a cleaner approach that
doesn't give up the benefits of the desugaring approach.

We need more testing, but I wanted to get something up for people to
look at; we can catch up on testing after the feature freeze if
needed.

Fixes #4437.

* Support writing a bare {} in the else branch

Because of how casts are inserted in the func code, this required some
tweaking to casts:
 - The empty array cast needs to be able to look through a Set in
   order to be efficient
 - The empty-set-to-object cast needs to be able to look through a
   Set so that it does not generate an IR cast. Such a cast messes
   things up because it doesn't provide a source, which causes the
   #3030 overlay issue to pop up.
  • Loading branch information
msullivan committed Sep 28, 2023
1 parent b6745fd commit ea947dd
Show file tree
Hide file tree
Showing 20 changed files with 377 additions and 76 deletions.
4 changes: 4 additions & 0 deletions edb/edgeql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ class Anchor(Expr):
name: str


class IRAnchor(Anchor):
    # Whether the IR set this anchor stands for contains DML.  Defaults to
    # False; create_anchor() only computes it (via irutils.contains_dml)
    # when called with check_dml=True.  The QL-level utils.contains_dml()
    # treats an IRAnchor with has_dml=True as DML.
    has_dml: bool = False


class SpecialAnchor(Anchor):
pass

Expand Down
3 changes: 3 additions & 0 deletions edb/edgeql/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,9 @@ def visit_ObjectRef(self, node: qlast.ObjectRef) -> None:
def visit_Anchor(self, node: qlast.Anchor) -> None:
self.write(node.name)

def visit_IRAnchor(self, node: qlast.IRAnchor) -> None:
    # An IR anchor is emitted as just its name; the has_dml flag is not
    # rendered.
    self.write(node.name)

def visit_SpecialAnchor(self, node: qlast.SpecialAnchor) -> None:
self.write(node.name)

Expand Down
20 changes: 19 additions & 1 deletion edb/edgeql/compiler/casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,14 @@ def compile_cast(
ctx=ctx,
srcctx=ir_expr.context)

if irutils.is_untyped_empty_array_expr(ir_expr):
if isinstance(new_stype, s_types.Array) and (
irutils.is_untyped_empty_array_expr(ir_expr)
or (
isinstance(ir_expr, irast.Set)
and irutils.is_untyped_empty_array_expr(
irutils.unwrap_set(ir_expr).expr)
)
):
# Ditto for empty arrays.
new_typeref = typegen.type_to_typeref(new_stype, ctx.env)
return setgen.ensure_set(
Expand All @@ -97,6 +104,17 @@ def compile_cast(
f'`...[IS {new_stype.get_displayname(ctx.env.schema)}]` instead',
context=srcctx)

# The only valid object type cast other than <uuid> is from anytype,
# and thus it must be an empty set.
if (
orig_stype.is_any(ctx.env.schema)
and new_stype.is_object_type()
):
return setgen.new_empty_set(
stype=new_stype,
ctx=ctx,
srcctx=ir_expr.context)

uuid_t = ctx.env.get_schema_type_and_track(sn.QualName('std', 'uuid'))
if (
orig_stype.issubclass(ctx.env.schema, uuid_t)
Expand Down
10 changes: 8 additions & 2 deletions edb/edgeql/compiler/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from edb.edgeql import qltypes

from edb.ir import ast as irast
from edb.ir import utils as irutils
from edb.ir import typeutils as irtyputils

from edb.schema import expraliases as s_aliases
Expand Down Expand Up @@ -734,11 +735,16 @@ def newscope(
def detached(self) -> compiler.CompilerContextManager[ContextLevel]:
return self.new(ContextSwitchMode.DETACHED)

def create_anchor(self, ir: irast.Set, name: str='v') -> qlast.Path:
def create_anchor(
self, ir: irast.Set, name: str='v', *, check_dml: bool=False
) -> qlast.Path:
alias = self.aliases.get(name)
# TODO: We should probably always check for DML, but I'm
# concerned about perf, since we don't cache it at all.
has_dml = check_dml and irutils.contains_dml(ir)
self.anchors[alias] = ir
return qlast.Path(
steps=[qlast.ObjectRef(name=alias)],
steps=[qlast.IRAnchor(name=alias, has_dml=has_dml)],
)

def maybe_create_anchor(
Expand Down
106 changes: 96 additions & 10 deletions edb/edgeql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from edb.schema import utils as s_utils

from edb.edgeql import ast as qlast
from edb.edgeql import utils

from . import astutils
from . import casts
Expand Down Expand Up @@ -306,14 +307,107 @@ def compile_Array(
return setgen.new_array_set(elements, ctx=ctx, srcctx=expr.context)


def _compile_dml_ifelse(
        expr: qlast.IfElse, *, ctx: context.ContextLevel) -> irast.Set:
    """Transform an IF/ELSE that contains DML into FOR loops.

    The basic approach is to extract the pieces from the if/then/else and
    rewrite them into:
        for b in COND union (
          {
            (for _ in (select () filter b) union (IF_BRANCH)),
            (for _ in (select () filter not b) union (ELSE_BRANCH)),
          }
        )

    The sketch above is schematic: the outer iterator alias is freshly
    generated from the 'b' hint, and the inner iterators use the alias
    '__'.  A branch whose compiled IR unwraps to an EmptySet is omitted;
    when exactly one branch remains it is used directly as the outer
    FOR's result instead of being wrapped in a set literal.

    Args:
        expr: The IF/ELSE QL node to compile.
        ctx: The compiler context to compile under.

    Returns:
        The set produced by compiling the rewritten FOR query, with its
        card_inference_override pointing at the IR of the original
        std::IF operator call.

    Raises:
        AssertionError: If the compiled std::IF is not an operator call
            with exactly three arguments.
    """

    with ctx.newscope(fenced=False) as subctx:
        # We have to compile it under a factoring fence to prevent
        # correlation with outside things. We can't just rely on the
        # factoring fences inserted when compiling the FORs, since we
        # are going to need to explicitly exempt the iterator
        # expression from that.
        subctx.path_scope.factoring_fence = True

        # First compile the conditional as an ordinary std::IF operator
        # call; its arguments are reused below and the call itself is kept
        # for cardinality/multiplicity inference.
        ir = func.compile_operator(
            expr, op_name='std::IF',
            qlargs=[expr.if_expr, expr.condition, expr.else_expr], ctx=subctx)

        # Extract the IR parts from the IF/THEN/ELSE
        # Note that cond_ir will be unfenced while if_ir and else_ir
        # will have been compiled under fences.
        # The argument order mirrors qlargs above: (if, condition, else).
        match ir.expr:
            case irast.OperatorCall(args=[
                irast.CallArg(expr=if_ir),
                irast.CallArg(expr=cond_ir),
                irast.CallArg(expr=else_ir),
            ]):
                pass
            case _:
                raise AssertionError('malformed DML IF/ELSE')

        # create_anchor() below adds entries to subctx.anchors, so work on
        # a copy -- presumably to keep them out of a mapping shared with
        # the enclosing context (TODO confirm).
        subctx.anchors = subctx.anchors.copy()

        # Alias of the outer FOR iterator; the inner filters refer to the
        # condition value through this name.
        alias = ctx.aliases.get('b')
        cond_path = qlast.Path(
            steps=[qlast.ObjectRef(name=alias)],
        )

        # The branch loops that make up the outer FOR's result.
        els: list[qlast.Expr] = []

        # Only emit a loop for a branch that isn't statically empty.
        if not isinstance(irutils.unwrap_set(if_ir), irast.EmptySet):
            if_b = qlast.ForQuery(
                iterator_alias='__',
                iterator=qlast.SelectQuery(
                    result=qlast.Tuple(elements=[]),
                    where=cond_path,
                ),
                # check_dml=True records on the anchor whether the branch
                # IR contains DML.
                result=subctx.create_anchor(if_ir, check_dml=True),
            )
            els.append(if_b)

        if not isinstance(irutils.unwrap_set(else_ir), irast.EmptySet):
            else_b = qlast.ForQuery(
                iterator_alias='__',
                iterator=qlast.SelectQuery(
                    result=qlast.Tuple(elements=[]),
                    where=qlast.UnaryOp(op='NOT', operand=cond_path),
                ),
                result=subctx.create_anchor(else_ir, check_dml=True),
            )
            els.append(else_b)

        # Iterate over the already-compiled condition (via an anchor) and
        # union the branch loops.  With zero or two branches the result is
        # a set literal; with exactly one it is that branch itself.
        full = qlast.ForQuery(
            iterator_alias=alias,
            iterator=subctx.create_anchor(cond_ir, 'b'),
            result=qlast.Set(elements=els) if len(els) != 1 else els[0],
        )

        # Register the condition's path id as an iterator path before
        # compiling.  NOTE(review): this looks like the explicit exemption
        # of the iterator expression mentioned at the top -- confirm.
        subctx.iterator_path_ids |= {cond_ir.path_id}
        res = dispatch.compile(full, ctx=subctx)
        # Indicate that the original IF/ELSE code should determine the
        # cardinality/multiplicity.
        res.card_inference_override = ir

        return res


@dispatch.compile.register(qlast.IfElse)
def compile_IfElse(
expr: qlast.IfElse, *, ctx: context.ContextLevel) -> irast.Set:

return func.compile_operator(
if (
utils.contains_dml(expr.if_expr)
or utils.contains_dml(expr.else_expr)
):
return _compile_dml_ifelse(expr, ctx=ctx)

res = func.compile_operator(
expr, op_name='std::IF',
qlargs=[expr.if_expr, expr.condition, expr.else_expr], ctx=ctx)

return res


@dispatch.compile.register(qlast.UnaryOp)
def compile_UnaryOp(
Expand Down Expand Up @@ -401,7 +495,6 @@ def compile_GlobalExpr(
def compile_TypeCast(
expr: qlast.TypeCast, *, ctx: context.ContextLevel) -> irast.Set:
target_stype = typegen.ql_typeexpr_to_type(expr.type, ctx=ctx)
target_typeref = typegen.type_to_typeref(target_stype, env=ctx.env)
ir_expr: Union[irast.Set, irast.Expr]

if isinstance(expr.expr, qlast.Parameter):
Expand Down Expand Up @@ -492,14 +585,7 @@ def compile_TypeCast(
subctx.implicit_tid_in_shapes = False
subctx.implicit_tname_in_shapes = False

if (
isinstance(expr.expr, qlast.Array)
and not expr.expr.elements
and irtyputils.is_array(target_typeref)
):
ir_expr = irast.Array(elements=[], typeref=target_typeref)
else:
ir_expr = dispatch.compile(expr.expr, ctx=subctx)
ir_expr = dispatch.compile(expr.expr, ctx=subctx)

res = casts.compile_cast(
ir_expr,
Expand Down
1 change: 0 additions & 1 deletion edb/edgeql/compiler/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,6 @@ def func(f: _SpecialCaseFunc) -> _SpecialCaseFunc:
#: A dictionary of conditional callables and the indices
#: of the arguments that are evaluated conditionally.
CONDITIONAL_OPS = {
sn.QualName('std', 'IF'): {0, 2},
sn.QualName('std', '??'): {1},
}

Expand Down
6 changes: 6 additions & 0 deletions edb/edgeql/compiler/inference/cardinality.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,12 @@ def _infer_set(
ir, is_mutation=is_mutation,
scope_tree=scope_tree, ctx=ctx)

# But actually! Check if it is overridden
if ir.card_inference_override:
result = _infer_set_inner(
ir.card_inference_override, is_mutation=is_mutation,
scope_tree=scope_tree, ctx=ctx)

# We need to cache the main result before doing the shape,
# since sometimes the shape will refer to the enclosing set.
ctx.inferred_cardinality[ir] = result
Expand Down
5 changes: 5 additions & 0 deletions edb/edgeql/compiler/inference/multiplicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,11 @@ def _infer_set(
result = _infer_set_inner(
ir, is_mutation=is_mutation, scope_tree=scope_tree, ctx=ctx
)
if ir.card_inference_override:
result = _infer_set_inner(
ir.card_inference_override, is_mutation=is_mutation,
scope_tree=scope_tree, ctx=ctx)

ctx.inferred_multiplicity[ir, scope_tree, ctx.distinct_iterator] = result

# The shape doesn't affect multiplicity, but requires validation.
Expand Down
12 changes: 12 additions & 0 deletions edb/edgeql/compiler/polyres.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,18 @@ def _get_cast_distance(
# refine our resolved_poly_base_type to be that as the
# more general case.
resolved_poly_base_type = ct
else:
# Try resolving a polymorphic argument type against the
# resolved base type. This lets us handle cases like
# - if b then x else {}
# - if b then [1] else []
# Though it is still unfortunately not smart enough
# to handle the reverse case.
if resolved.is_polymorphic(schema):
ct = resolved.resolve_polymorphic(
schema, resolved_poly_base_type)

if ct is not None:
return s_types.MAX_TYPE_DISTANCE if is_abstract else 0
else:
return -1
Expand Down
7 changes: 7 additions & 0 deletions edb/edgeql/compiler/setgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,13 @@ def compile_path(expr: qlast.Path, *, ctx: context.ContextLevel) -> irast.Set:
if isinstance(step, qlast.SpecialAnchor):
path_tip = resolve_special_anchor(step, ctx=ctx)

elif isinstance(step, qlast.IRAnchor):
# Check if the starting path label is a known anchor
refnode = anchors.get(step.name)
if not refnode:
raise AssertionError(f'anchor {step.name} is missing')
path_tip = new_set_from_set(refnode, ctx=ctx)

elif isinstance(step, qlast.ObjectRef):
if i > 0: # pragma: no cover
raise RuntimeError(
Expand Down
4 changes: 2 additions & 2 deletions edb/edgeql/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,8 +882,8 @@ def trace_Path(


@trace.register
def trace_SpecialAnchor(
node: qlast.SpecialAnchor, *, ctx: TracerContext
def trace_Anchor(
node: qlast.Anchor, *, ctx: TracerContext
) -> Optional[ObjectLike]:
if name := ctx.anchors.get(node.name):
return ctx.objects[name]
Expand Down
11 changes: 8 additions & 3 deletions edb/edgeql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,13 @@ def contains_dml(ql_expr: qlast.Base) -> bool:
if isinstance(ql_expr, dml_types):
return True

res = ast.find_children(ql_expr, qlast.Query,
lambda x: isinstance(x, dml_types),
terminate_early=True)
res = ast.find_children(
ql_expr, qlast.Base,
lambda x: (
isinstance(x, dml_types)
or (isinstance(x, qlast.IRAnchor) and x.has_dml)
),
terminate_early=True,
)

return bool(res)
6 changes: 6 additions & 0 deletions edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,12 @@ class Set(Base):
# insertions to BaseObject.
ignore_rewrites: bool = False

# An expression to use instead of this one for the purpose of
# cardinality/multiplicity inference. This is used for when something
# is desugared in a way that doesn't preserve cardinality, but we
# need to anyway.
card_inference_override: typing.Optional[Set] = None

def __repr__(self) -> str:
return f'<ir.Set \'{self.path_id}\' at 0x{id(self):x}>'

Expand Down
10 changes: 6 additions & 4 deletions edb/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def is_union_expr(ir: irast.Base) -> bool:
)


def is_empty_array_expr(ir: irast.Base) -> bool:
def is_empty_array_expr(ir: Optional[irast.Base]) -> TypeGuard[irast.Array]:
"""Return True if the given *ir* expression is an empty array expression.
"""
return (
Expand All @@ -85,14 +85,16 @@ def is_empty_array_expr(ir: irast.Base) -> bool:
)


def is_untyped_empty_array_expr(ir: irast.Base) -> bool:
def is_untyped_empty_array_expr(
ir: Optional[irast.Base]
) -> TypeGuard[irast.Array]:
"""Return True if the given *ir* expression is an empty
array expression of an unknown type.
"""
return (
is_empty_array_expr(ir)
and (ir.typeref is None # type: ignore
or typeutils.is_generic(ir.typeref)) # type: ignore
and (ir.typeref is None
or typeutils.is_generic(ir.typeref))
)


Expand Down
15 changes: 0 additions & 15 deletions tests/test_edgeql_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,21 +491,6 @@ async def test_edgeql_delete_in_conditional_bad_01(self):
(DELETE DeleteTest2);
''')

async def test_edgeql_delete_in_conditional_bad_02(self):
with self.assertRaisesRegex(
edgedb.QueryError,
'DELETE statements cannot be used'):
await self.con.execute(r'''
SELECT
(SELECT DeleteTest FILTER .name = 'foo')
IF EXISTS DeleteTest
ELSE (
(SELECT DeleteTest)
UNION
(DELETE DeleteTest)
);
''')

async def test_edgeql_delete_abstract_01(self):
await self.con.execute(r"""
Expand Down
Loading

0 comments on commit ea947dd

Please sign in to comment.