From 3e0832cb1f631fade5e3efebc904d61576c4f1ca Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 30 Oct 2018 22:05:16 -0400 Subject: [PATCH] Allow indexing named tuples by position; fix named tuples serialization Named tuples can now be indexed by position. The below two queries are equivalent: db> SELECT (a:=1).a; db> SELECT (a:=1).0; Named tuples will now be correctly serialized to JSON in more cases, for example: db> SELECT [(a:=1)][0]; {"a": 1} Serialization of nested named tuples is still broken in some cases though: db> SELECT [(a:=(x:=1))][0]; {"a": {"f1": 1}} --- edb/lang/edgeql/compiler/setgen.py | 19 ++-- edb/lang/ir/inference/types.py | 5 +- edb/lang/schema/_std.eql | 2 +- edb/lang/schema/nodes.py | 5 ++ edb/lang/schema/objtypes.py | 3 + edb/lang/schema/types.py | 67 +++++++++++++- edb/server/pgsql/compiler/output.py | 41 ++++++++- edb/server/pgsql/compiler/relgen.py | 10 +-- edb/server/pgsql/compiler/typecomp.py | 120 +++++++++++++------------- tests/test_edgeql_calls.py | 58 +++++++++++++ tests/test_edgeql_expressions.py | 54 ++++++++++++ tests/test_edgeql_functions.py | 47 +++++++++- 12 files changed, 343 insertions(+), 88 deletions(-) diff --git a/edb/lang/edgeql/compiler/setgen.py b/edb/lang/edgeql/compiler/setgen.py index e3b4372e81..7a99df691a 100644 --- a/edb/lang/edgeql/compiler/setgen.py +++ b/edb/lang/edgeql/compiler/setgen.py @@ -461,17 +461,14 @@ def tuple_indirection_set( else: el_name = ptr_name[1] - if el_name in source.element_types: - path_id = irutils.tuple_indirection_path_id( - path_tip.path_id, el_name, - source.element_types[el_name]) - expr = irast.TupleIndirection( - expr=path_tip, name=el_name, path_id=path_id, - context=source_context) - else: - raise errors.EdgeQLReferenceError( - f'{el_name} is not a member of a tuple', - context=source_context) + el_norm_name = source.normalize_index(el_name) + el_type = source.get_subtype(el_name) + + path_id = irutils.tuple_indirection_path_id( + path_tip.path_id, el_norm_name, el_type) + expr = irast.TupleIndirection( + expr=path_tip, name=el_norm_name, path_id=path_id, + context=source_context) return generated_set(expr, ctx=ctx) diff --git a/edb/lang/ir/inference/types.py b/edb/lang/ir/inference/types.py index 4d690a7028..8d8856a6c8 100644 --- a/edb/lang/ir/inference/types.py +++ b/edb/lang/ir/inference/types.py @@ -39,8 +39,7 @@ def amend_empty_set_type(es: irast.EmptySet, t: s_obj.Object, schema) -> None: alias = es.path_id.target.name.name scls_name = s_name.Name(module='__expr__', name=alias) - scls = t.__class__(name=scls_name, bases=[t]) - scls.acquire_ancestor_inheritance(schema) + scls = t.derive_subtype(schema, name=scls_name) es.path_id = irast.PathId(scls) es.scls = t @@ -441,7 +440,7 @@ def __infer_struct(ir, schema): @_infer_type.register(irast.TupleIndirection) def __infer_struct_indirection(ir, schema): struct_type = infer_type(ir.expr, schema) - result = struct_type.element_types.get(ir.name) + result = struct_type.get_subtype(ir.name) if result is None: raise ql_errors.EdgeQLError('could not determine struct element type', context=ir.context) diff --git a/edb/lang/schema/_std.eql b/edb/lang/schema/_std.eql index f8f2edff5c..621bb9a4c5 100644 --- a/edb/lang/schema/_std.eql +++ b/edb/lang/schema/_std.eql @@ -115,7 +115,7 @@ CREATE FUNCTION std::array_contains(array: array, $$; CREATE FUNCTION std::array_enumerate(array: array) -> - SET OF tuple + SET OF tuple FROM SQL FUNCTION '--system--'; CREATE FUNCTION std::array_get(array: array, diff --git a/edb/lang/schema/nodes.py b/edb/lang/schema/nodes.py index d6beeab884..6aa1479a79 100644 --- a/edb/lang/schema/nodes.py +++ b/edb/lang/schema/nodes.py @@ -35,6 +35,11 @@ def material_type(self): t = t.bases[0] return t + def derive_subtype(self, schema, *, name: str) -> s_types.Type: + st = type(self)(name=name, bases=[self]) + st.acquire_ancestor_inheritance(schema) + return st + def peel_view(self): if self.is_view(): return self.bases[0] diff --git a/edb/lang/schema/objtypes.py b/edb/lang/schema/objtypes.py index 9b626e2329..2107b14d13 100644 --- a/edb/lang/schema/objtypes.py +++ b/edb/lang/schema/objtypes.py @@ -39,6 +39,9 @@ class SourceNode(sources.Source, nodes.Node): class ObjectType(SourceNode, constraints.ConsistencySubject): _type = 'ObjectType' + def is_object_type(self): + return True + def materialize_policies(self, schema): bases = self.bases diff --git a/edb/lang/schema/types.py b/edb/lang/schema/types.py index 8bbe76d5e2..45fb28a72e 100644 --- a/edb/lang/schema/types.py +++ b/edb/lang/schema/types.py @@ -53,9 +53,15 @@ class Type(so.NamedObject, derivable.DerivableObjectBase): # rptr will contain the inbound pointer class. rptr = so.Field(so.Object, default=None, compcoef=0.909) + def derive_subtype(self, schema, *, name: str) -> 'Type': + raise NotImplementedError + def is_type(self): return True + def is_object_type(self): + return False + def is_polymorphic(self): return False @@ -298,6 +304,12 @@ def is_array(self): def get_container(self): return tuple + def derive_subtype(self, schema, *, name: str) -> Type: + return Array.from_subtypes( + self.element_type, + self.get_typemods(), + name=name) + def get_subtypes(self): return (self.element_type,) @@ -365,7 +377,7 @@ def _test_polymorphic(self, other: 'Type'): return self.element_type.test_polymorphic(other.element_type) @classmethod - def from_subtypes(cls, subtypes, typemods=None): + def from_subtypes(cls, subtypes, typemods=None, *, name=None): if len(subtypes) != 1: raise s_err.SchemaError( f'unexpected number of subtypes, expecting 1: {subtypes!r}') @@ -389,7 +401,7 @@ def from_subtypes(cls, subtypes, typemods=None): element_type = stype dimensions = [] - return cls(element_type=element_type, dimensions=dimensions) + return cls(element_type=element_type, dimensions=dimensions, name=name) def __hash__(self): return hash(( @@ -422,14 +434,61 @@ def is_tuple(self): def get_container(self): return dict + def iter_subtypes(self): + yield from self.element_types.items() + + def normalize_index(self, field: str) -> str: + if self.named and field.isdecimal(): + idx = int(field) + if idx >= 0 and idx < len(self.element_types): + return list(self.element_types.keys())[idx] + else: + raise s_err.ItemNotFoundError( + f'{field} is not a member of {self.displayname}') + + return field + + def index_of(self, field: str) -> int: + if field.isdecimal(): + idx = int(field) + if idx >= 0 and idx < len(self.element_types): + if self.named: + return list(self.element_types.keys()).index(field) + else: + return idx + elif self.named and field in self.element_types: + return list(self.element_types.keys()).index(field) + + raise s_err.ItemNotFoundError( + f'{field} is not a member of {self.displayname}') + + def get_subtype(self, field: str) -> Type: + # index can be a name or a position + if field.isdecimal(): + idx = int(field) + if idx >= 0 and idx < len(self.element_types): + return list(self.element_types.values())[idx] + + elif self.named and field in self.element_types: + return self.element_types[field] + + raise s_err.ItemNotFoundError( + f'{field} is not a member of {self.displayname}') + def get_subtypes(self): if self.element_types: return list(self.element_types.values()) else: return [] + def derive_subtype(self, schema, *, name: str) -> Type: + return Tuple.from_subtypes( + self.element_types, + self.get_typemods(), + name=name) + @classmethod - def from_subtypes(cls, subtypes, typemods=None): + def from_subtypes(cls, subtypes, typemods=None, *, name: str=None): named = False if typemods is not None: named = typemods.get('named', False) @@ -441,7 +500,7 @@ def from_subtypes(cls, subtypes, typemods=None): else: types = subtypes - return cls(element_types=types, named=named) + return cls(element_types=types, named=named, name=name) def implicitly_castable_to(self, other: Type, schema) -> bool: if not other.is_tuple(): diff --git a/edb/server/pgsql/compiler/output.py b/edb/server/pgsql/compiler/output.py index 25cb0b0d77..a16b0147c6 100644 --- a/edb/server/pgsql/compiler/output.py +++ b/edb/server/pgsql/compiler/output.py @@ -22,6 +22,7 @@ import typing from edb.lang.ir import ast as irast + from edb.lang.schema import objtypes as s_objtypes from edb.lang.schema import types as s_types @@ -29,6 +30,34 @@ from edb.server.pgsql import types as pgtypes from . import context +from . import typecomp + + +def named_tuple_as_json_object(expr, *, stype, env): + assert stype.is_tuple() and stype.named + + keyvals = [] + for el_idx, (el_name, el_type) in enumerate(stype.iter_subtypes()): + keyvals.append(pgast.StringConstant(val=el_name)) + + type_sentinel = typecomp.cast( + pgast.NullConstant(), + source_type=el_type, target_type=el_type, force=True, + env=env) + + val = pgast.FuncCall( + name=('edgedb', 'row_getattr_by_num'), + args=[ + expr, + pgast.NumericConstant(val=str(el_idx + 1)), + type_sentinel + ]) + + keyvals.append(val) + + return pgast.FuncCall( + name=('jsonb_build_object',), + args=keyvals, null_safe=True, nullable=expr.nullable) def tuple_var_as_json_object(tvar, *, path_id, env): @@ -113,10 +142,14 @@ def serialize_expr_to_json( name=('jsonb_build_array',), args=expr.args, null_safe=True) - elif isinstance(path_id.target, s_types.Tuple): - val = pgast.FuncCall( - name=('edgedb', 'row_to_jsonb_array',), args=[expr], - null_safe=True) + elif path_id.target.is_tuple(): + if path_id.target.named: + val = named_tuple_as_json_object( + expr, stype=path_id.target, env=env) + else: + val = pgast.FuncCall( + name=('edgedb', 'row_to_jsonb_array',), args=[expr], + null_safe=True) elif not nested: val = pgast.FuncCall( diff --git a/edb/server/pgsql/compiler/relgen.py b/edb/server/pgsql/compiler/relgen.py index d39e4ce91e..a2bb85c7a7 100644 --- a/edb/server/pgsql/compiler/relgen.py +++ b/edb/server/pgsql/compiler/relgen.py @@ -1163,7 +1163,7 @@ def process_set_as_tuple( for element in expr.elements: path_id = irutils.tuple_indirection_path_id( ir_set.path_id, element.name, - ir_set.scls.element_types[element.name] + ir_set.scls.get_subtype(element.name) ) stmt.view_path_id_map[path_id] = element.val.path_id @@ -1218,10 +1218,8 @@ def process_set_as_tuple_indirection( source_type=ir_set.scls, target_type=ir_set.scls, force=True, env=subctx.env) - tuple_atts = list(tuple_set.scls.element_types.keys()) - att_idx = pgast.NumericConstant( - val=str(tuple_atts.index(ir_set.expr.name) + 1) - ) + index = tuple_set.scls.index_of(ir_set.expr.name) + att_idx = pgast.NumericConstant(val=str(index + 1)) set_expr = pgast.FuncCall( name=('edgedb', 'row_getattr_by_num'), @@ -1373,7 +1371,7 @@ def process_set_as_func_expr( elements=[ pgast.TupleElement( path_id=irutils.tuple_indirection_path_id( - ir_set.path_id, n, rtype.element_types[n], + ir_set.path_id, n, rtype.get_subtype(n), ), name=n, val=dbobj.get_column( diff --git a/edb/server/pgsql/compiler/typecomp.py b/edb/server/pgsql/compiler/typecomp.py index ae5ee5d168..6deb05770c 100644 --- a/edb/server/pgsql/compiler/typecomp.py +++ b/edb/server/pgsql/compiler/typecomp.py @@ -26,7 +26,6 @@ from edb.lang.common import ast from edb.lang.schema import objects as s_obj -from edb.lang.schema import types as s_types from edb.server.pgsql import ast as pgast from edb.server.pgsql import types as pg_types @@ -54,69 +53,74 @@ def cast( real_t = schema.get('std::anyreal') bytes_t = schema.get('std::bytes') - if isinstance(target_type, s_types.Collection): - if isinstance(target_type, s_types.Array): - - if source_type.issubclass(json_t): - # If we are casting a jsonb array to array, we do the - # following transformation: - # EdgeQL: >MAP_VALUE - # SQL: - # SELECT array_agg(j::T) - # FROM jsonb_array_elements(MAP_VALUE) AS j - - inner_cast = cast( - pgast.ColumnRef(name=['j']), - source_type=source_type, - target_type=target_type.element_type, - env=env - ) + if target_type.is_array(): + if source_type.issubclass(json_t): + # If we are casting a jsonb array to array, we do the + # following transformation: + # EdgeQL: >MAP_VALUE + # SQL: + # SELECT array_agg(j::T) + # FROM jsonb_array_elements(MAP_VALUE) AS j + + inner_cast = cast( + pgast.ColumnRef(name=['j']), + source_type=source_type, + target_type=target_type.element_type, + env=env + ) - return pgast.SelectStmt( - target_list=[ - pgast.ResTarget( - val=pgast.FuncCall( - name=('array_agg',), - args=[ - inner_cast - ]) - ) - ], - from_clause=[ - pgast.RangeFunction( - functions=[pgast.FuncCall( - name=('jsonb_array_elements',), - args=[ - node - ] - )], - alias=pgast.Alias( - aliasname='j' - ) + return pgast.SelectStmt( + target_list=[ + pgast.ResTarget( + val=pgast.FuncCall( + name=('array_agg',), + args=[ + inner_cast + ]) + ) + ], + from_clause=[ + pgast.RangeFunction( + functions=[pgast.FuncCall( + name=('jsonb_array_elements',), + args=[ + node + ] + )], + alias=pgast.Alias( + aliasname='j' ) - ]) + ) + ]) - else: - # EdgeQL: >['1', '2'] - # to SQL: ARRAY['1', '2']::int[] + else: + # EdgeQL: >['1', '2'] + # to SQL: ARRAY['1', '2']::int[] - elem_pgtype = pg_types.pg_type_from_object( - schema, target_type.element_type, topbase=True) + elem_pgtype = pg_types.pg_type_from_object( + schema, target_type.element_type, topbase=True) - if elem_pgtype in {('anyelement',), ('anynonarray',)}: - # We don't want to append '[]' suffix to - # `anyelement` and `anynonarray`. - return pgast.TypeCast( - arg=node, - type_name=pgast.TypeName( - name=('anyarray',))) + if elem_pgtype in {('anyelement',), ('anynonarray',)}: + # We don't want to append '[]' suffix to + # `anyelement` and `anynonarray`. + return pgast.TypeCast( + arg=node, + type_name=pgast.TypeName( + name=('anyarray',))) - else: - return pgast.TypeCast( - arg=node, - type_name=pgast.TypeName( - name=elem_pgtype, - array_bounds=[-1])) + else: + return pgast.TypeCast( + arg=node, + type_name=pgast.TypeName( + name=elem_pgtype, + array_bounds=[-1])) + + elif target_type.is_tuple(): + if target_type.implicitly_castable_to(source_type, env.schema): + return pgast.TypeCast( + arg=node, + type_name=pgast.TypeName( + name=('record',))) else: # `target_type` is not a collection. diff --git a/tests/test_edgeql_calls.py b/tests/test_edgeql_calls.py index 81b892c074..dddcbbe8be 100644 --- a/tests/test_edgeql_calls.py +++ b/tests/test_edgeql_calls.py @@ -1076,3 +1076,61 @@ async def test_edgeql_calls_30(self): a: anyint ); ''') + + async def test_edgeql_calls_31(self): + await self.con.execute(''' + CREATE FUNCTION test::call31( + a: any + ) -> any + FROM EdgeQL $$ + SELECT a + $$; + ''') + + try: + await self.assert_query_result(r''' + SELECT test::call31(10); + SELECT test::call31('aa'); + + SELECT test::call31([1, 2]); + SELECT test::call31([1, 2])[0]; + + SELECT test::call31((a:=1001, b:=1002)).a; + SELECT test::call31((a:=1001, b:=1002)).1; + + SELECT test::call31((a:=['a', 'b'], b:=['x', 'y'])).1; + SELECT test::call31((a:=['a', 'b'], b:=['x', 'y'])).a[1]; + + SELECT test::call31((a:=1001, b:=1002)); + + SELECT test::call31((a:=[(x:=1)])).a[0].x; + SELECT test::call31((a:=[(x:=1)])).0[0].x; + SELECT test::call31((a:=[(x:=1)])).0[0].0; + SELECT test::call31((a:=[(x:=1)])).a[0]; + ''', [ + [10], + ['aa'], + + [[1, 2]], + [1], + + [1001], + [1002], + + [['x', 'y']], + ['b'], + + [{"a": 1001, "b": 1002}], + + [1], + [1], + [1], + [{"x": 1}], + ]) + + finally: + await self.con.execute(''' + DROP FUNCTION test::call31( + a: any + ); + ''') diff --git a/tests/test_edgeql_expressions.py b/tests/test_edgeql_expressions.py index e8f7e89b04..2ba880ffc6 100644 --- a/tests/test_edgeql_expressions.py +++ b/tests/test_edgeql_expressions.py @@ -1436,6 +1436,60 @@ async def test_edgeql_expr_tuple_indirection_11(self): 2, ]]) + async def test_edgeql_expr_tuple_indirection_12(self): + await self.assert_query_result(r""" + SELECT (name := 'foo', val := 42).0; + SELECT (name := 'foo', val := 42).1; + SELECT [(name := 'foo', val := 42)][0].name; + SELECT [(name := 'foo', val := 42)][0].1; + """, [ + ['foo'], + [42], + ['foo'], + [42], + ]) + + async def test_edgeql_expr_tuple_indirection_13(self): + await self.assert_query_result(r""" + SELECT (a:=(b:=(c:=(e:=1)))); + + SELECT (a:=(b:=(c:=(e:=1)))).a; + SELECT (a:=(b:=(c:=(e:=1)))).0; + + SELECT (a:=(b:=(c:=(e:=1)))).a.b; + SELECT (a:=(b:=(c:=(e:=1)))).0.0; + + SELECT (a:=(b:=(c:=(e:=1)))).a.b.c; + SELECT (a:=(b:=(c:=(e:=1)))).0.0.0; + + SELECT (a:=(b:=(c:=(e:=1)))).a.b.c.e; + SELECT (a:=(b:=(c:=(e:=1)))).0.b.c.0; + """, [ + [{"a": {"b": {"c": {"e": 1}}}}], + + [{"b": {"c": {"e": 1}}}], + [{"b": {"c": {"e": 1}}}], + + [{"c": {"e": 1}}], + [{"c": {"e": 1}}], + + [{"e": 1}], + [{"e": 1}], + + [1], + [1], + ]) + + @unittest.expectedFailure + async def test_edgeql_expr_tuple_indirection_14(self): + await self.assert_query_result(r""" + SELECT [(a:=(b:=(c:=(e:=1))))][0].a; + SELECT [(a:=(b:=(c:=(e:=1))))][0].0; + """, [ + [{"b": {"c": {"e": 1}}}], + [{"b": {"c": {"e": 1}}}], + ]) + async def test_edgeql_expr_cannot_assign_dunder_type_01(self): with self.assertRaisesRegex( exc.EdgeQLError, r'cannot assign to __type__'): diff --git a/tests/test_edgeql_functions.py b/tests/test_edgeql_functions.py index 5409a839a1..dfa38f5400 100644 --- a/tests/test_edgeql_functions.py +++ b/tests/test_edgeql_functions.py @@ -393,17 +393,41 @@ async def test_edgeql_functions_array_enumerate_01(self): SELECT [10, 20]; SELECT array_enumerate([10,20]); SELECT array_enumerate([10,20]).1 + 100; + SELECT array_enumerate([10,20]).index + 100; ''', [ [[10, 20]], - [[10, 0], [20, 1]], + [{"element": 10, "index": 0}, {"element": 20, "index": 1}], + [100, 101], [100, 101], ]) async def test_edgeql_functions_array_enumerate_02(self): await self.assert_query_result(r''' SELECT array_enumerate([10,20]).0 + 100; + SELECT array_enumerate([10,20]).element + 1000; ''', [ [110, 120], + [1010, 1020], + ]) + + @unittest.expectedFailure + async def test_edgeql_functions_array_enumerate_03(self): + await self.assert_query_result(r''' + SELECT array_enumerate([(x:=1)]).0; + SELECT array_enumerate([(x:=1)]).0.x; + + SELECT array_enumerate([(x:=(a:=2))]).0; + SELECT array_enumerate([(x:=(a:=2))]).0.x; + + SELECT array_enumerate([(x:=(a:=2))]).0.x.a; + ''', [ + [{"x": 1}], + [1], + + [{"x": {"a": 2}}], + [{"a": 2}], + + [2], ]) async def test_edgeql_functions_array_get_01(self): @@ -473,6 +497,27 @@ async def test_edgeql_functions_array_get_05(self): [4200], ]) + async def test_edgeql_functions_array_get_06(self): + await self.assert_query_result(r''' + SELECT array_get([(20,), (30,)], 0); + SELECT array_get([(a:=20), (a:=30)], 1); + + SELECT array_get([(20,), (30,)], 0).0; + SELECT array_get([(a:=20), (a:=30)], 1).0; + + SELECT array_get([(a:=20, b:=1), (a:=30, b:=2)], 0).a; + SELECT array_get([(a:=20, b:=1), (a:=30, b:=2)], 1).b; + ''', [ + [[20]], + [{'a': 30}], # XXX + + [20], + [30], + + [20], + [2], + ]) + async def test_edgeql_functions_re_match_01(self): await self.assert_query_result(r''' SELECT re_match('ab', 'AbabaB');