Skip to content

Commit

Permalink
Fix named-tuple params that also appear in the schema (#5964)
Browse files Browse the repository at this point in the history
The tuple argument decoder code simply ignores whether a tuple is
named, figuring that named tuples and unnamed tuples have the same
representation in the generated SQL, so why bother. (It was called out
as a "HACK" in the comments.)

But they only have the same representation when they are represented
as a `record`. If the tuple type appears in the schema, it gets
represented as a named composite type and we no longer have the right
representation.

Just do it right. Fixes #5789.
  • Loading branch information
msullivan committed Sep 1, 2023
1 parent 21b5381 commit 0522bee
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 8 deletions.
27 changes: 22 additions & 5 deletions edb/edgeql/compiler/tuple_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from typing import *

from edb import errors
from edb.common.typeutils import not_none

from edb.ir import ast as irast
from edb.ir import typeutils as irtypeutils
Expand All @@ -121,7 +122,7 @@

def _lmost_is_array(typ: irast.ParamTransType) -> bool:
while isinstance(typ, irast.ParamTuple):
typ = typ.typs[0]
_, typ = typ.typs[0]
return isinstance(typ, irast.ParamArray)


Expand Down Expand Up @@ -166,13 +167,14 @@ def trans(
)

elif irtypeutils.is_tuple(typ):
# HACK: We don't bother dealing with named tuples. We just
# generate anonymous tuples and it all still works out.
return irast.ParamTuple(
typeref=typ,
idx=start,
typs=tuple(
trans(t, in_array=in_array, depth=depth + 1)
(
t.element_name,
trans(t, in_array=in_array, depth=depth + 1),
)
for t in typ.subtypes
),
)
Expand Down Expand Up @@ -222,6 +224,21 @@ def _index(expr: qlast.Expr, idx: qlast.Expr) -> qlast.Indirection:
return qlast.Indirection(arg=expr, indirection=[qlast.Index(index=idx)])


def _make_tuple(
fields: Sequence[tuple[Optional[str], qlast.Expr]]
) -> qlast.NamedTuple | qlast.Tuple:
is_named = fields and fields[0][0]
if is_named:
return qlast.NamedTuple(elements=[
qlast.TupleElement(name=qlast.ObjectRef(name=not_none(f)), val=e)
for f, e in fields
])
else:
return qlast.Tuple(
elements=[e for _, e in fields]
)


def make_decoder(
ptyp: irast.ParamTransType,
qparams: tuple[irast.Param, ...],
Expand Down Expand Up @@ -252,7 +269,7 @@ def mk(typ: irast.ParamTransType, idx: Optional[qlast.Expr]) -> qlast.Expr:
return expr

elif isinstance(typ, irast.ParamTuple):
return qlast.Tuple(elements=[mk(t, idx=idx) for t in typ.typs])
return _make_tuple([(f, mk(t, idx=idx)) for f, t in typ.typs])

elif isinstance(typ, irast.ParamArray):
inner_idx_alias, inner_idx = _get_alias('idx', ctx=ctx)
Expand Down
4 changes: 2 additions & 2 deletions edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,12 +612,12 @@ def flatten(self) -> tuple[typing.Any, ...]:

@dataclasses.dataclass(eq=False)
class ParamTuple(ParamTransType):
typs: tuple[ParamTransType, ...]
typs: tuple[tuple[typing.Optional[str], ParamTransType], ...]

def flatten(self) -> tuple[typing.Any, ...]:
return (
(int(qltypes.TypeTag.TUPLE), self.idx)
+ tuple(x.flatten() for x in self.typs)
+ tuple(x.flatten() for _, x in self.typs)
)


Expand Down
3 changes: 2 additions & 1 deletion tests/schemas/casts.esdl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type Test {
property p_float64 -> float64;
property p_bigint -> bigint;
property p_decimal -> decimal;
property p_tup -> tuple<test: str>;
}

type JSONTest {
Expand Down Expand Up @@ -79,4 +80,4 @@ type ScalarTest {

type Person {
property name -> str;
}
}
37 changes: 37 additions & 0 deletions tests/test_edgeql_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2883,6 +2883,43 @@ async def test_edgeql_casts_tuple_params_08(self):
variables=(None, 11111),
)

async def test_edgeql_casts_tuple_params_09(self):
await self.con.query('''
WITH
p := <tuple<test: str>>$0
insert Test { p_tup := p };
''', ('foo',))

await self.assert_query_result(
'''
select Test { p_tup } filter exists .p_tup
''',
[{'p_tup': {'test': 'foo'}}],
)
await self.assert_query_result(
'''
WITH
p := <tuple<test: str>>$0
select p
''',
[{'test': 'foo'}],
variables=(('foo',),),
)
await self.assert_query_result(
'''
select <tuple<test: str>>$0
''',
[{'test': 'foo'}],
variables=(('foo',),),
)
await self.assert_query_result(
'''
select <array<tuple<test: str>>>$0
''',
[[{'test': 'foo'}, {'test': 'bar'}]],
variables=([('foo',), ('bar',)],),
)

async def test_edgeql_cast_empty_set_to_array_01(self):
await self.assert_query_result(
r'''
Expand Down

0 comments on commit 0522bee

Please sign in to comment.