Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some bugs involving union and coalescing of optional values #6590

Merged
merged 5 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 4 additions & 2 deletions edb/edgeql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@ def _compile_dml_coalesce(
res = dispatch.compile(full, ctx=subctx)
# Indicate that the original ?? code should determine the
# cardinality/multiplicity.
res.card_inference_override = ir
assert isinstance(res.expr, irast.SelectStmt)
res.expr.card_inference_override = ir

return res

Expand Down Expand Up @@ -479,7 +480,8 @@ def _compile_dml_ifelse(
res = dispatch.compile(full, ctx=subctx)
# Indicate that the original IF/ELSE code should determine the
# cardinality/multiplicity.
res.card_inference_override = ir
assert isinstance(res.expr, irast.SelectStmt)
res.expr.card_inference_override = ir

return res

Expand Down
11 changes: 5 additions & 6 deletions edb/edgeql/compiler/inference/cardinality.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,12 +519,6 @@ def _infer_set(
ir, is_mutation=is_mutation,
scope_tree=scope_tree, ctx=ctx)

# But actually! Check if it is overridden
if ir.card_inference_override:
result = _infer_set_inner(
ir.card_inference_override, is_mutation=is_mutation,
scope_tree=scope_tree, ctx=ctx)

# We need to cache the main result before doing the shape,
# since sometimes the shape will refer to the enclosing set.
ctx.inferred_cardinality[ir] = result
Expand Down Expand Up @@ -1273,6 +1267,11 @@ def __infer_select_stmt(
if ir.iterator_stmt:
stmt_card = cartesian_cardinality((stmt_card, iter_card))

# But actually! Check if it is overridden
if ir.card_inference_override:
stmt_card = infer_cardinality(
ir.card_inference_override, scope_tree=scope_tree, ctx=ctx)

return stmt_card


Expand Down
13 changes: 6 additions & 7 deletions edb/edgeql/compiler/inference/multiplicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,6 @@ def _infer_set(
result = _infer_set_inner(
ir, is_mutation=is_mutation, scope_tree=scope_tree, ctx=ctx
)
if ir.card_inference_override:
result = _infer_set_inner(
ir.card_inference_override, is_mutation=is_mutation,
scope_tree=scope_tree, ctx=ctx)

ctx.inferred_multiplicity[ir, scope_tree, ctx.distinct_iterator] = result

# The shape doesn't affect multiplicity, but requires validation.
Expand Down Expand Up @@ -625,7 +620,7 @@ def __infer_select_stmt(
) -> inf_ctx.MultiplicityInfo:

if ir.iterator_stmt is not None:
return _infer_for_multiplicity(ir, scope_tree=scope_tree, ctx=ctx)
stmt_mult = _infer_for_multiplicity(ir, scope_tree=scope_tree, ctx=ctx)
else:
stmt_mult = _infer_stmt_multiplicity(
ir, scope_tree=scope_tree, ctx=ctx)
Expand All @@ -639,7 +634,11 @@ def __infer_select_stmt(
new_scope = inf_utils.get_set_scope(clause, scope_tree, ctx=ctx)
infer_multiplicity(clause, scope_tree=new_scope, ctx=ctx)

return stmt_mult
if ir.card_inference_override:
stmt_mult = infer_multiplicity(
ir.card_inference_override, scope_tree=scope_tree, ctx=ctx)

return stmt_mult


@_infer_multiplicity.register
Expand Down
12 changes: 6 additions & 6 deletions edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,12 +540,6 @@ class Set(Base):
# insertions to BaseObject.
ignore_rewrites: bool = False

# An expression to use instead of this one for the purpose of
# cardinality/multiplicity inference. This is used for when something
# is desugared in a way that doesn't preserve cardinality, but we
# need to anyway.
card_inference_override: typing.Optional[Set] = None

def __repr__(self) -> str:
return f'<ir.Set \'{self.path_id}\' at 0x{id(self):x}>'

Expand Down Expand Up @@ -1076,6 +1070,12 @@ class SelectStmt(FilteredStmt):
limit: typing.Optional[Set] = None
implicit_wrapper: bool = False

# An expression to use instead of this one for the purpose of
# cardinality/multiplicity inference. This is used for when something
# is desugared in a way that doesn't preserve cardinality, but we
# need to anyway.
card_inference_override: typing.Optional[Set] = None


class GroupStmt(FilteredStmt):
subject: Set = EmptySet() # type: ignore
Expand Down
1 change: 1 addition & 0 deletions edb/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def is_trivial_select(ir_expr: irast.Base) -> TypeGuard[irast.SelectStmt]:
and ir_expr.where is None
and ir_expr.limit is None
and ir_expr.offset is None
and ir_expr.card_inference_override is None
)


Expand Down
13 changes: 13 additions & 0 deletions edb/pgsql/compiler/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,19 @@ def output_as_value(
return val


def add_null_test(expr: pgast.BaseExpr, query: pgast.SelectStmt) -> None:
if not expr.nullable:
return

while isinstance(expr, pgast.TupleVar) and expr.elements:
expr = expr.elements[0].val

query.where_clause = astutils.extend_binop(
query.where_clause,
pgast.NullTest(arg=expr, negated=True)
)


def serialize_expr_if_needed(
expr: pgast.BaseExpr, *,
path_id: irast.PathId,
Expand Down
2 changes: 2 additions & 0 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2287,6 +2287,8 @@ def process_set_as_singleton_assertion(
],
)

output.add_null_test(arg_ref, newctx.rel)

# Force Postgres to actually evaluate the result target
# by putting it into an ORDER BY.
newctx.rel.target_list.append(
Expand Down
26 changes: 14 additions & 12 deletions edb/pgsql/compiler/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import dispatch
from . import group
from . import dml
from . import output
from . import pathctx


Expand Down Expand Up @@ -111,19 +112,20 @@ def compile_SelectStmt(
query.sort_clause = clauses.compile_orderby_clause(
stmt.orderby, ctx=octx)

if outvar.nullable and query is ctx.toplevel_stmt:
# A nullable var has bubbled up to the top,
# filter out NULLs.
valvar: pgast.BaseExpr = pathctx.get_path_value_var(
# Need to filter out NULLs in certain cases:
if outvar.nullable and (
# A nullable var has bubbled up to the top
query is ctx.toplevel_stmt
# The cardinality is being overridden, so we need to make
# sure there aren't extra NULLs in single set
or stmt.card_inference_override
# There is a LIMIT or OFFSET clause and NULLs would interfere
or stmt.limit
or stmt.offset
):
valvar = pathctx.get_path_value_var(
query, stmt.result.path_id, env=ctx.env)
if isinstance(valvar, pgast.TupleVar):
valvar = pgast.ImplicitRowExpr(
args=[e.val for e in valvar.elements])

query.where_clause = astutils.extend_binop(
query.where_clause,
pgast.NullTest(arg=valvar, negated=True)
)
output.add_null_test(valvar, query)

# The OFFSET clause
query.limit_offset = clauses.compile_limit_offset_clause(
Expand Down
9 changes: 9 additions & 0 deletions tests/test_edgeql_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9155,6 +9155,15 @@ async def test_edgeql_assert_single_01(self):
);
""")

await self.con.query("""
select {
xy := assert_single({<optional str>$0, <optional str>$1}) };
""", None, None)
await self.con.query("""
select {
xy := assert_single({<optional str>$0, <optional str>$1}) };
""", None, 'test')

async def test_edgeql_assert_single_02(self):
await self.con.execute("""
FOR name IN {"Hunter B-15", "Hunter B-22"}
Expand Down
135 changes: 135 additions & 0 deletions tests/test_edgeql_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6515,3 +6515,138 @@ async def test_edgeql_insert_coalesce_05(self):
'select count(Note)',
[3],
)

async def test_edgeql_insert_coalesce_nulls_01(self):
Q = '''
with name := 'name',
new := (
(select Person filter .name = name) ??
(insert Person { name := name})
),
select { new := new }
'''

await self.assert_query_result(
Q,
[{'new': {}}],
)

await self.assert_query_result(
Q,
[{'new': {}}],
)

async def test_edgeql_insert_coalesce_nulls_02(self):
Q = '''
with name := 'name',
new := (
(select Person filter .name = name) ??
(insert Person { name := name})
),
select (
insert Note { name := '??', subject := new }
) { subject }
'''

await self.assert_query_result(
Q,
[{'subject': {}}],
)

await self.assert_query_result(
Q,
[{'subject': {}}],
)

async def test_edgeql_insert_coalesce_nulls_03(self):
await self.con.execute('''
insert Note { name := 'x' }
''')

Q = '''
with name := 'name',
new := (
(select Person filter .name = name) ??
(insert Person { name := name})
),
select (update Note filter .name = 'x' set { subject := new })
{ subject }
'''

await self.assert_query_result(
Q,
[{'subject': {}}],
)

await self.assert_query_result(
Q,
[{'subject': {}}],
)

async def test_edgeql_insert_coalesce_nulls_04(self):
Q = '''
with name := 'name',
new := (
(select Note filter .name = name) ??
(insert Note { name := name })
),
select { new := assert_single(new) }
'''

await self.assert_query_result(
Q,
[{'new': {}}],
)

await self.assert_query_result(
Q,
[{'new': {}}],
)

async def test_edgeql_insert_coalesce_nulls_05(self):
Q = '''
with name := 'name',
new := (
(select Note filter .name = name) ??
(insert Note { name := name})
),
select (
insert Note { name := '??', subject := assert_single(new) }
) { subject }
'''

await self.assert_query_result(
Q,
[{'subject': {}}],
)

await self.assert_query_result(
Q,
[{'subject': {}}],
)

async def test_edgeql_insert_coalesce_nulls_06(self):
await self.con.execute('''
insert Note { name := 'x' }
''')

Q = '''
with name := 'name',
new := (
(select Note filter .name = name) ??
(insert Note { name := name })
),
select (update Note filter .name = 'x' set {
subject := assert_single(new) })
{ subject }
'''

await self.assert_query_result(
Q,
[{'subject': {}}],
)

await self.assert_query_result(
Q,
[{'subject': {}}],
)
17 changes: 17 additions & 0 deletions tests/test_edgeql_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,23 @@ async def test_edgeql_select_limit_10(self):
SELECT 1 LIMIT -1
""")

async def test_edgeql_select_limit_11(self):
await self.assert_query_result(
r'''
SELECT (SELECT {<optional str>$0, 'x'} LIMIT 1)
''',
['x'],
variables=(None,),
)

await self.assert_query_result(
r'''
SELECT (SELECT {<optional str>$0, 'x'} OFFSET 1)
''',
[],
variables=(None,),
)

async def test_edgeql_select_offset_01(self):
with self.assertRaisesRegex(
edgedb.InvalidValueError,
Expand Down