Skip to content

Commit

Permalink
[3.x] Fix some bugs involving union and coalescing of optional values (
Browse files Browse the repository at this point in the history
…#6590)

The original 4.x/5.x version had a lot to do with DML coalescing.

There are a number of places that improperly handle a set that has
multiple NULL values in it:
 * assert_single
 * limit/offset

The assert_single and limit case is a long standing bug that can be triggered
without much trouble.

Fix the long-standing cases by adding null checks, but fix the
coalesce cases by adding a null check at the producer side.
The rule then, is that there can be NULLs in any produced set
that can be empty, but a single set has to return at most one row
(so it can't have extra NULLs).
  • Loading branch information
msullivan committed Dec 19, 2023
1 parent 2c37376 commit 1dffedd
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 12 deletions.
13 changes: 13 additions & 0 deletions edb/pgsql/compiler/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,19 @@ def output_as_value(
return val


def add_null_test(expr: pgast.BaseExpr, query: pgast.SelectStmt) -> None:
if not expr.nullable:
return

while isinstance(expr, pgast.TupleVar) and expr.elements:
expr = expr.elements[0].val

query.where_clause = astutils.extend_binop(
query.where_clause,
pgast.NullTest(arg=expr, negated=True)
)


def serialize_expr_if_needed(
expr: pgast.BaseExpr, *,
path_id: irast.PathId,
Expand Down
2 changes: 2 additions & 0 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2324,6 +2324,8 @@ def process_set_as_singleton_assertion(
],
)

output.add_null_test(arg_ref, newctx.rel)

# Force Postgres to actually evaluate the result target
# by putting it into an ORDER BY.
newctx.rel.target_list.append(
Expand Down
23 changes: 11 additions & 12 deletions edb/pgsql/compiler/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import dispatch
from . import group
from . import dml
from . import output
from . import pathctx


Expand Down Expand Up @@ -111,19 +112,17 @@ def compile_SelectStmt(
query.sort_clause = clauses.compile_orderby_clause(
stmt.orderby, ctx=octx)

if outvar.nullable and query is ctx.toplevel_stmt:
# A nullable var has bubbled up to the top,
# filter out NULLs.
valvar: pgast.BaseExpr = pathctx.get_path_value_var(
# Need to filter out NULLs in certain cases:
if outvar.nullable and (
# A nullable var has bubbled up to the top
query is ctx.toplevel_stmt
# There is a LIMIT or OFFSET clause and NULLs would interfere
or stmt.limit
or stmt.offset
):
valvar = pathctx.get_path_value_var(
query, stmt.result.path_id, env=ctx.env)
if isinstance(valvar, pgast.TupleVar):
valvar = pgast.ImplicitRowExpr(
args=[e.val for e in valvar.elements])

query.where_clause = astutils.extend_binop(
query.where_clause,
pgast.NullTest(arg=valvar, negated=True)
)
output.add_null_test(valvar, query)

# The OFFSET clause
query.limit_offset = clauses.compile_limit_offset_clause(
Expand Down
9 changes: 9 additions & 0 deletions tests/test_edgeql_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7836,6 +7836,15 @@ async def test_edgeql_assert_single_01(self):
);
""")

await self.con.query("""
select {
xy := assert_single({<optional str>$0, <optional str>$1}) };
""", None, None)
await self.con.query("""
select {
xy := assert_single({<optional str>$0, <optional str>$1}) };
""", None, 'test')

async def test_edgeql_assert_single_02(self):
await self.con.execute("""
FOR name IN {"Hunter B-15", "Hunter B-22"}
Expand Down
17 changes: 17 additions & 0 deletions tests/test_edgeql_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,23 @@ async def test_edgeql_select_limit_10(self):
SELECT 1 LIMIT -1
""")

async def test_edgeql_select_limit_11(self):
await self.assert_query_result(
r'''
SELECT (SELECT {<optional str>$0, 'x'} LIMIT 1)
''',
['x'],
variables=(None,),
)

await self.assert_query_result(
r'''
SELECT (SELECT {<optional str>$0, 'x'} OFFSET 1)
''',
[],
variables=(None,),
)

async def test_edgeql_select_offset_01(self):
with self.assertRaisesRegex(
edgedb.InvalidValueError,
Expand Down

0 comments on commit 1dffedd

Please sign in to comment.