Skip to content

Commit

Permalink
Fix some bugs involving union and coalescing of optional values (#6590)
Browse files Browse the repository at this point in the history
There are a number of places that improperly handle a set that has
multiple NULL values in it:
 * INSERT on single values produced by ??
 * UPDATE on single values produced by ??
 * assert_single
 * limit/offset

The assert_single and limit/offset cases are long-standing bugs that can be
triggered without much trouble. The INSERT/UPDATE bugs are new, and stem from
the new ?? on DML, since I think that was the first thing in the
compiler backend where extra NULLs might be produced for something
with single cardinality.

Fix the long-standing cases by adding null checks, but fix the
coalesce cases by adding a null check at the producer side.
The rule then, is that there can be NULLs in any produced set
that can be empty, but a single set has to return at most one row
(so it can't have extra NULLs).


Fixes #6438.
  • Loading branch information
msullivan authored and aljazerzen committed Dec 21, 2023
1 parent 3233842 commit 54e33cd
Show file tree
Hide file tree
Showing 11 changed files with 212 additions and 33 deletions.
6 changes: 4 additions & 2 deletions edb/edgeql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,8 @@ def _compile_dml_coalesce(
res = dispatch.compile(full, ctx=subctx)
# Indicate that the original ?? code should determine the
# cardinality/multiplicity.
res.card_inference_override = ir
assert isinstance(res.expr, irast.SelectStmt)
res.expr.card_inference_override = ir

return res

Expand Down Expand Up @@ -481,7 +482,8 @@ def _compile_dml_ifelse(
res = dispatch.compile(full, ctx=subctx)
# Indicate that the original IF/ELSE code should determine the
# cardinality/multiplicity.
res.card_inference_override = ir
assert isinstance(res.expr, irast.SelectStmt)
res.expr.card_inference_override = ir

return res

Expand Down
11 changes: 5 additions & 6 deletions edb/edgeql/compiler/inference/cardinality.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,12 +519,6 @@ def _infer_set(
ir, is_mutation=is_mutation,
scope_tree=scope_tree, ctx=ctx)

# But actually! Check if it is overridden
if ir.card_inference_override:
result = _infer_set_inner(
ir.card_inference_override, is_mutation=is_mutation,
scope_tree=scope_tree, ctx=ctx)

# We need to cache the main result before doing the shape,
# since sometimes the shape will refer to the enclosing set.
ctx.inferred_cardinality[ir] = result
Expand Down Expand Up @@ -1273,6 +1267,11 @@ def __infer_select_stmt(
if ir.iterator_stmt:
stmt_card = cartesian_cardinality((stmt_card, iter_card))

# But actually! Check if it is overridden
if ir.card_inference_override:
stmt_card = infer_cardinality(
ir.card_inference_override, scope_tree=scope_tree, ctx=ctx)

return stmt_card


Expand Down
13 changes: 6 additions & 7 deletions edb/edgeql/compiler/inference/multiplicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,6 @@ def _infer_set(
result = _infer_set_inner(
ir, is_mutation=is_mutation, scope_tree=scope_tree, ctx=ctx
)
if ir.card_inference_override:
result = _infer_set_inner(
ir.card_inference_override, is_mutation=is_mutation,
scope_tree=scope_tree, ctx=ctx)

ctx.inferred_multiplicity[ir, scope_tree, ctx.distinct_iterator] = result

# The shape doesn't affect multiplicity, but requires validation.
Expand Down Expand Up @@ -625,7 +620,7 @@ def __infer_select_stmt(
) -> inf_ctx.MultiplicityInfo:

if ir.iterator_stmt is not None:
return _infer_for_multiplicity(ir, scope_tree=scope_tree, ctx=ctx)
stmt_mult = _infer_for_multiplicity(ir, scope_tree=scope_tree, ctx=ctx)
else:
stmt_mult = _infer_stmt_multiplicity(
ir, scope_tree=scope_tree, ctx=ctx)
Expand All @@ -639,7 +634,11 @@ def __infer_select_stmt(
new_scope = inf_utils.get_set_scope(clause, scope_tree, ctx=ctx)
infer_multiplicity(clause, scope_tree=new_scope, ctx=ctx)

return stmt_mult
if ir.card_inference_override:
stmt_mult = infer_multiplicity(
ir.card_inference_override, scope_tree=scope_tree, ctx=ctx)

return stmt_mult


@_infer_multiplicity.register
Expand Down
12 changes: 6 additions & 6 deletions edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,12 +541,6 @@ class Set(Base):
# insertions to BaseObject.
ignore_rewrites: bool = False

# An expression to use instead of this one for the purpose of
# cardinality/multiplicity inference. This is used for when something
# is desugared in a way that doesn't preserve cardinality, but we
# need to anyway.
card_inference_override: typing.Optional[Set] = None

def __repr__(self) -> str:
return f'<ir.Set \'{self.path_id}\' at 0x{id(self):x}>'

Expand Down Expand Up @@ -1077,6 +1071,12 @@ class SelectStmt(FilteredStmt):
limit: typing.Optional[Set] = None
implicit_wrapper: bool = False

# An expression to use instead of this one for the purpose of
# cardinality/multiplicity inference. This is used for when something
# is desugared in a way that doesn't preserve cardinality, but we
# need to anyway.
card_inference_override: typing.Optional[Set] = None


class GroupStmt(FilteredStmt):
subject: Set = EmptySet() # type: ignore
Expand Down
1 change: 1 addition & 0 deletions edb/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def is_trivial_select(ir_expr: irast.Base) -> TypeGuard[irast.SelectStmt]:
and ir_expr.where is None
and ir_expr.limit is None
and ir_expr.offset is None
and ir_expr.card_inference_override is None
)


Expand Down
13 changes: 13 additions & 0 deletions edb/pgsql/compiler/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,19 @@ def output_as_value(
return val


def add_null_test(expr: pgast.BaseExpr, query: pgast.SelectStmt) -> None:
    """Add a negated NULL test on *expr* to the WHERE clause of *query*.

    Does nothing when ``expr.nullable`` is false.  Otherwise a
    ``pgast.NullTest(..., negated=True)`` is merged into
    ``query.where_clause`` through ``astutils.extend_binop``; *query* is
    modified in place and nothing is returned.

    When *expr* is a non-empty ``pgast.TupleVar``, the test is applied to
    the value of its first element (descending through nested tuple vars)
    rather than to the tuple as a whole.
    """
    if expr.nullable:
        target: pgast.BaseExpr = expr
        # Drill down to the first leaf value of a (possibly nested) tuple;
        # an empty TupleVar is tested as-is.
        while isinstance(target, pgast.TupleVar) and target.elements:
            target = target.elements[0].val

        not_null = pgast.NullTest(arg=target, negated=True)
        query.where_clause = astutils.extend_binop(
            query.where_clause,
            not_null,
        )


def serialize_expr_if_needed(
expr: pgast.BaseExpr, *,
path_id: irast.PathId,
Expand Down
2 changes: 2 additions & 0 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2363,6 +2363,8 @@ def process_set_as_singleton_assertion(
],
)

output.add_null_test(arg_ref, newctx.rel)

# Force Postgres to actually evaluate the result target
# by putting it into an ORDER BY.
newctx.rel.target_list.append(
Expand Down
26 changes: 14 additions & 12 deletions edb/pgsql/compiler/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import dispatch
from . import group
from . import dml
from . import output
from . import pathctx


Expand Down Expand Up @@ -111,19 +112,20 @@ def compile_SelectStmt(
query.sort_clause = clauses.compile_orderby_clause(
stmt.orderby, ctx=octx)

if outvar.nullable and query is ctx.toplevel_stmt:
# A nullable var has bubbled up to the top,
# filter out NULLs.
valvar: pgast.BaseExpr = pathctx.get_path_value_var(
# Need to filter out NULLs in certain cases:
if outvar.nullable and (
# A nullable var has bubbled up to the top
query is ctx.toplevel_stmt
# The cardinality is being overridden, so we need to make
# sure there aren't extra NULLs in single set
or stmt.card_inference_override
# There is a LIMIT or OFFSET clause and NULLs would interfere
or stmt.limit
or stmt.offset
):
valvar = pathctx.get_path_value_var(
query, stmt.result.path_id, env=ctx.env)
if isinstance(valvar, pgast.TupleVar):
valvar = pgast.ImplicitRowExpr(
args=[e.val for e in valvar.elements])

query.where_clause = astutils.extend_binop(
query.where_clause,
pgast.NullTest(arg=valvar, negated=True)
)
output.add_null_test(valvar, query)

# The OFFSET clause
query.limit_offset = clauses.compile_limit_offset_clause(
Expand Down
9 changes: 9 additions & 0 deletions tests/test_edgeql_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9188,6 +9188,15 @@ async def test_edgeql_assert_single_01(self):
);
""")

await self.con.query("""
select {
xy := assert_single({<optional str>$0, <optional str>$1}) };
""", None, None)
await self.con.query("""
select {
xy := assert_single({<optional str>$0, <optional str>$1}) };
""", None, 'test')

async def test_edgeql_assert_single_02(self):
await self.con.execute("""
FOR name IN {"Hunter B-15", "Hunter B-22"}
Expand Down
135 changes: 135 additions & 0 deletions tests/test_edgeql_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6515,3 +6515,138 @@ async def test_edgeql_insert_coalesce_05(self):
'select count(Note)',
[3],
)

async def test_edgeql_insert_coalesce_nulls_01(self):
# Select-or-insert a Person through `??` and expose the result as a
# computed field of a free object. The same query is run twice and must
# yield exactly one result row containing `new` each time.
Q = '''
with name := 'name',
new := (
(select Person filter .name = name) ??
(insert Person { name := name})
),
select { new := new }
'''

# First run: presumably takes the insert branch (assumes no Person
# named 'name' exists beforehand).
await self.assert_query_result(
Q,
[{'new': {}}],
)

# Second run: presumably takes the select branch, since the Person now
# exists.
await self.assert_query_result(
Q,
[{'new': {}}],
)

async def test_edgeql_insert_coalesce_nulls_02(self):
# Use the result of a select-or-insert `??` expression as the `subject`
# link of a newly inserted Note. The query is run twice and each run must
# return one Note row with `subject` set.
Q = '''
with name := 'name',
new := (
(select Person filter .name = name) ??
(insert Person { name := name})
),
select (
insert Note { name := '??', subject := new }
) { subject }
'''

# First run: presumably the Person is created by the insert branch.
await self.assert_query_result(
Q,
[{'subject': {}}],
)

# Second run: presumably the existing Person is picked up by the select
# branch.
await self.assert_query_result(
Q,
[{'subject': {}}],
)

async def test_edgeql_insert_coalesce_nulls_03(self):
# Use the result of a select-or-insert `??` expression as the `subject`
# link in an UPDATE of an existing Note. The query is run twice and each
# run must return one updated Note row with `subject` set.

# Create the Note that the UPDATE below filters on.
await self.con.execute('''
insert Note { name := 'x' }
''')

Q = '''
with name := 'name',
new := (
(select Person filter .name = name) ??
(insert Person { name := name})
),
select (update Note filter .name = 'x' set { subject := new })
{ subject }
'''

# First run: presumably the Person is created by the insert branch.
await self.assert_query_result(
Q,
[{'subject': {}}],
)

# Second run: presumably the existing Person is picked up by the select
# branch.
await self.assert_query_result(
Q,
[{'subject': {}}],
)

async def test_edgeql_insert_coalesce_nulls_04(self):
# Wrap the result of a select-or-insert `??` expression (on Note) in
# assert_single() inside a free object. The query is run twice and must
# yield exactly one result row containing `new` each time.
Q = '''
with name := 'name',
new := (
(select Note filter .name = name) ??
(insert Note { name := name })
),
select { new := assert_single(new) }
'''

# First run: presumably takes the insert branch (assumes no Note named
# 'name' exists beforehand).
await self.assert_query_result(
Q,
[{'new': {}}],
)

# Second run: presumably takes the select branch, since the Note now
# exists.
await self.assert_query_result(
Q,
[{'new': {}}],
)

async def test_edgeql_insert_coalesce_nulls_05(self):
# Pass the result of a select-or-insert `??` expression through
# assert_single() and use it as the `subject` link of a newly inserted
# Note. The query is run twice and each run must return one Note row with
# `subject` set.
Q = '''
with name := 'name',
new := (
(select Note filter .name = name) ??
(insert Note { name := name})
),
select (
insert Note { name := '??', subject := assert_single(new) }
) { subject }
'''

# First run: presumably the Note named 'name' is created by the insert
# branch.
await self.assert_query_result(
Q,
[{'subject': {}}],
)

# Second run: presumably the existing Note is picked up by the select
# branch.
await self.assert_query_result(
Q,
[{'subject': {}}],
)

async def test_edgeql_insert_coalesce_nulls_06(self):
# Pass the result of a select-or-insert `??` expression through
# assert_single() and use it as the `subject` link in an UPDATE of an
# existing Note. The query is run twice and each run must return one
# updated Note row with `subject` set.

# Create the Note that the UPDATE below filters on.
await self.con.execute('''
insert Note { name := 'x' }
''')

Q = '''
with name := 'name',
new := (
(select Note filter .name = name) ??
(insert Note { name := name })
),
select (update Note filter .name = 'x' set {
subject := assert_single(new) })
{ subject }
'''

# First run: presumably the Note named 'name' is created by the insert
# branch.
await self.assert_query_result(
Q,
[{'subject': {}}],
)

# Second run: presumably the existing Note is picked up by the select
# branch.
await self.assert_query_result(
Q,
[{'subject': {}}],
)
17 changes: 17 additions & 0 deletions tests/test_edgeql_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,23 @@ async def test_edgeql_select_limit_10(self):
SELECT 1 LIMIT -1
""")

async def test_edgeql_select_limit_11(self):
# LIMIT/OFFSET over a set built from an optional parameter that is passed
# as None plus the literal 'x'. The expected results treat the set as
# containing only 'x'.

# LIMIT 1 must return 'x'.
await self.assert_query_result(
r'''
SELECT (SELECT {<optional str>$0, 'x'} LIMIT 1)
''',
['x'],
variables=(None,),
)

# OFFSET 1 must skip 'x' and return nothing.
await self.assert_query_result(
r'''
SELECT (SELECT {<optional str>$0, 'x'} OFFSET 1)
''',
[],
variables=(None,),
)

async def test_edgeql_select_offset_01(self):
with self.assertRaisesRegex(
edgedb.InvalidValueError,
Expand Down

0 comments on commit 54e33cd

Please sign in to comment.