diff --git a/edgedb/codegen/generator.py b/edgedb/codegen/generator.py index 1605636a..65720066 100644 --- a/edgedb/codegen/generator.py +++ b/edgedb/codegen/generator.py @@ -442,7 +442,7 @@ def _generate_code( print(f"{INDENT}@typing.overload", file=buf) print( f'{INDENT}def __getitem__' - f'(self, key: {typing_literal}["@{el_name}"]) ' + f'(self, key: {typing_literal}["{el_name}"]) ' f'-> {el_code}:', file=buf, ) diff --git a/edgedb/datatypes/datatypes.pyx b/edgedb/datatypes/datatypes.pyx index 4cc26cf4..f287b6c4 100644 --- a/edgedb/datatypes/datatypes.pyx +++ b/edgedb/datatypes/datatypes.pyx @@ -54,12 +54,11 @@ def create_object_factory(**pointers): names = () fields = {} for pname, ptype in pointers.items(): - names += (pname,) - if not isinstance(ptype, set): ptype = {ptype} flag = 0 + is_linkprop = False for pt in ptype: if pt == 'link': flag |= EDGE_POINTER_IS_LINK @@ -67,16 +66,21 @@ def create_object_factory(**pointers): pass elif pt == 'link-property': flag |= EDGE_POINTER_IS_LINKPROP + is_linkprop = True elif pt == 'implicit': flag |= EDGE_POINTER_IS_IMPLICIT else: raise ValueError(f'unknown pointer type {pt}') + if is_linkprop: + names += ("@" + pname,) + else: + names += (pname,) + field = dataclasses.field() + field.name = pname + field._field_type = dataclasses._FIELD + fields[pname] = field flags += (flag,) - field = dataclasses.field() - field.name = pname - field._field_type = dataclasses._FIELD - fields[pname] = field desc = EdgeRecordDesc_New(names, flags, NULL) size = len(pointers) diff --git a/edgedb/datatypes/link.c b/edgedb/datatypes/link.c index 1a49d6b4..02e70a6a 100644 --- a/edgedb/datatypes/link.c +++ b/edgedb/datatypes/link.c @@ -24,6 +24,7 @@ static int init_type_called = 0; static Py_hash_t base_hash = -1; +extern PyObject* at_sign_ptr; PyObject * @@ -276,7 +277,12 @@ link_getattr(EdgeLinkObject *o, PyObject *name) assert(EdgeRecordDesc_Check(desc)); Py_ssize_t pos; - edge_attr_lookup_t ret = EdgeRecordDesc_Lookup(desc, name, &pos); + PyObject *prefixed_name = PyUnicode_Concat(at_sign_ptr, name); + if (prefixed_name == NULL) { + return NULL; + } + edge_attr_lookup_t ret = EdgeRecordDesc_Lookup(desc, prefixed_name, &pos); + Py_DECREF(prefixed_name); switch (ret) { case L_ERROR: return NULL; @@ -313,6 +319,18 @@ link_dir(EdgeLinkObject *o, PyObject *args) return NULL; } + PyObject *name, *stripped; + for (Py_ssize_t i = 0; i < PyList_GET_SIZE(ret); i++) { + name = PyList_GET_ITEM(ret, i); + stripped = PyUnicode_Substring(name, 1, PyUnicode_GET_LENGTH(name)); + if (stripped == NULL) { + Py_DECREF(ret); + return NULL; + } + PyList_SET_ITEM(ret, i, stripped); + Py_DECREF(name); + } + PyObject *str = PyUnicode_FromString("source"); if (str == NULL) { Py_DECREF(ret); diff --git a/edgedb/datatypes/object.c b/edgedb/datatypes/object.c index 03268efe..3e0050e2 100644 --- a/edgedb/datatypes/object.c +++ b/edgedb/datatypes/object.c @@ -196,50 +196,10 @@ object_getattr(EdgeObject *o, PyObject *name) ) { return EdgeRecordDesc_GetDataclassFields((PyObject *)o->desc); } - - // getattr(obj, "@...") for link property - int prefixed = PyUnicode_Tailmatch( - name, at_sign_ptr, 0, PY_SSIZE_T_MAX, -1 - ); - if (prefixed == -1) { - return NULL; - } - if (prefixed) { - PyObject *stripped = PyUnicode_Substring( - name, 1, PyUnicode_GET_LENGTH(name) - ); - if (stripped == NULL) { - return NULL; - } - ret = EdgeRecordDesc_Lookup( - (PyObject *)o->desc, stripped, &pos); - Py_DECREF(stripped); - switch (ret) { - case L_ERROR: - return NULL; - - case L_NOT_FOUND: - case L_LINK: - case L_PROPERTY: - return PyObject_GenericGetAttr((PyObject *)o, name); - - case L_LINKPROP: { - PyObject *val = EdgeObject_GET_ITEM(o, pos); - Py_INCREF(val); - return val; - } - - default: - abort(); - } - } - return PyObject_GenericGetAttr((PyObject *)o, name); } case L_LINKPROP: - return PyObject_GenericGetAttr((PyObject *)o, name); - case L_LINK: case L_PROPERTY: { PyObject *val = EdgeObject_GET_ITEM(o, pos); @@ -256,77 +216,51 @@ static PyObject * object_getitem(EdgeObject *o, PyObject *name) { Py_ssize_t pos; - int prefixed = 0; - PyObject *stripped = name; - if (PyUnicode_Check(name)) { - prefixed = PyUnicode_Tailmatch( - name, at_sign_ptr, 0, PY_SSIZE_T_MAX, -1 - ); - if (prefixed == -1) { - return NULL; - } - if (prefixed) { - stripped = PyUnicode_Substring( - name, 1, PyUnicode_GET_LENGTH(name) - ); - if (stripped == NULL) { - return NULL; - } - } - } - edge_attr_lookup_t ret = EdgeRecordDesc_Lookup( - (PyObject *)o->desc, stripped, &pos + (PyObject *)o->desc, name, &pos ); - if (prefixed) { - Py_DECREF(stripped); - } switch (ret) { case L_ERROR: return NULL; case L_PROPERTY: + PyErr_Format( + PyExc_TypeError, + "property %R should be accessed via dot notation", + name); + return NULL; + + case L_LINKPROP: { + PyObject *val = EdgeObject_GET_ITEM(o, pos); + Py_INCREF(val); + return val; + } + + case L_NOT_FOUND: { + int prefixed = 0; + if (PyUnicode_Check(name)) { + prefixed = PyUnicode_Tailmatch( + name, at_sign_ptr, 0, PY_SSIZE_T_MAX, -1 + ); + if (prefixed == -1) { + return NULL; + } + } if (prefixed) { PyErr_Format( PyExc_KeyError, "link property %R does not exist", name); - } else { - PyErr_Format( - PyExc_TypeError, - "property %R should be accessed via dot notation", - name); - } - return NULL; - - case L_LINKPROP: - if (prefixed) { - PyObject *val = EdgeObject_GET_ITEM(o, pos); - Py_INCREF(val); - return val; } else { PyErr_Format( PyExc_TypeError, "link property %R should be accessed with '@' prefix", name); - return NULL; } - - case L_NOT_FOUND: - PyErr_Format( - PyExc_KeyError, - "link property %R does not exist", - name); return NULL; + } case L_LINK: { - if (prefixed) { - PyErr_Format( - PyExc_KeyError, - "link property %R does not exist", - name); - return NULL; - } int res = PyErr_WarnEx( PyExc_DeprecationWarning, "getting link on object is deprecated since 1.0, " diff --git a/edgedb/datatypes/repr.c b/edgedb/datatypes/repr.c index bfe0da4b..752a5fa0 100644 --- a/edgedb/datatypes/repr.c +++ b/edgedb/datatypes/repr.c @@ -118,12 +118,7 @@ _EdgeGeneric_RenderItems(_PyUnicodeWriter *writer, } if (is_linkprop) { - if (include_link_props) { - if (_PyUnicodeWriter_WriteChar(writer, '@') < 0) { - goto error; - } - } - else { + if (!include_link_props) { continue; } } diff --git a/edgedb/protocol/codecs/codecs.pyx b/edgedb/protocol/codecs/codecs.pyx index 3925f4ff..315d7abe 100644 --- a/edgedb/protocol/codecs/codecs.pyx +++ b/edgedb/protocol/codecs/codecs.pyx @@ -194,6 +194,8 @@ cdef class CodecsRegistry: frb_read(spec, str_len), str_len) pos = hton.unpack_int16(frb_read(spec, 2)) + if flag & datatypes._EDGE_POINTER_IS_LINKPROP: + name = "@" + name cpython.Py_INCREF(name) cpython.PyTuple_SetItem(names, i, name) diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert b/tests/codegen/test-project2/generated_async_edgeql.py.assert index 2dfbabba..b337aa8b 100644 --- a/tests/codegen/test-project2/generated_async_edgeql.py.assert +++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert @@ -34,6 +34,7 @@ class LinkPropResult: class LinkPropResultFriendsItem: id: uuid.UUID name: str + created_at: datetime.datetime | None @typing.overload def __getitem__(self, key: typing.Literal["@created_at"]) -> datetime.datetime | None: diff --git a/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert b/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert index 2b4acf25..8937f0d0 100644 --- a/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert +++ b/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert @@ -31,6 +31,7 @@ class LinkPropResult(NoPydanticValidation): class LinkPropResultFriendsItem(NoPydanticValidation): id: uuid.UUID name: str + created_at: typing.Optional[datetime.datetime] @typing.overload def __getitem__(self, key: typing_extensions.Literal["@created_at"]) -> typing.Optional[datetime.datetime]: diff --git a/tests/codegen/test-project2/object/link_prop_edgeql.py.assert b/tests/codegen/test-project2/object/link_prop_edgeql.py.assert index 0d7de44e..ccbc8dec 100644 --- a/tests/codegen/test-project2/object/link_prop_edgeql.py.assert +++ b/tests/codegen/test-project2/object/link_prop_edgeql.py.assert @@ -21,6 +21,7 @@ class LinkPropResult: class LinkPropResultFriendsItem: id: uuid.UUID name: str + created_at: typing.Optional[datetime.datetime] @typing.overload def __getitem__(self, key: typing.Literal["@created_at"]) -> typing.Optional[datetime.datetime]: diff --git a/tests/datatypes/test_datatypes.py b/tests/datatypes/test_datatypes.py index a6be0ac2..eaff8aff 100644 --- a/tests/datatypes/test_datatypes.py +++ b/tests/datatypes/test_datatypes.py @@ -99,24 +99,24 @@ def test_recorddesc_3(self): o = f(1, 2, 3, 4) desc = private.get_object_descriptor(o) - self.assertEqual(set(dir(desc)), set(('id', 'lb', 'c', 'd'))) + self.assertEqual(set(dir(desc)), set(('id', '@lb', 'c', 'd'))) - self.assertTrue(desc.is_linkprop('lb')) + self.assertTrue(desc.is_linkprop('@lb')) self.assertFalse(desc.is_linkprop('id')) self.assertFalse(desc.is_linkprop('c')) self.assertFalse(desc.is_linkprop('d')) - self.assertFalse(desc.is_link('lb')) + self.assertFalse(desc.is_link('@lb')) self.assertFalse(desc.is_link('id')) self.assertFalse(desc.is_link('c')) self.assertTrue(desc.is_link('d')) - self.assertFalse(desc.is_implicit('lb')) + self.assertFalse(desc.is_implicit('@lb')) self.assertTrue(desc.is_implicit('id')) self.assertFalse(desc.is_implicit('c')) self.assertFalse(desc.is_implicit('d')) - self.assertEqual(desc.get_pos('lb'), 1) + self.assertEqual(desc.get_pos('@lb'), 1) self.assertEqual(desc.get_pos('id'), 0) self.assertEqual(desc.get_pos('c'), 2) self.assertEqual(desc.get_pos('d'), 3) @@ -509,7 +509,7 @@ def test_object_1(self): with self.assertRaises(TypeError): len(o) - with self.assertRaises(KeyError): + with self.assertRaises(TypeError): o[0] with self.assertRaises(TypeError): @@ -681,9 +681,9 @@ def test_object_links_4(self): u = User(1, None) with self.assertRaisesRegex( - KeyError, "link property 'error_key' does not exist" + KeyError, "link property '@error_key' does not exist" ): - u['error_key'] + u['@error_key'] def test_object_link_property_1(self): O2 = private.create_object_factory( @@ -743,6 +743,7 @@ def test_object_dataclass_1(self): name='property', tuple='property', namedtuple='property', + linkprop="link-property", ) u = User( @@ -750,6 +751,7 @@ def test_object_dataclass_1(self): 'Bob', edgedb.Tuple((1, 2.0, '3')), edgedb.NamedTuple(a=1, b="Y"), + 123, ) self.assertTrue(dataclasses.is_dataclass(u)) self.assertEqual( diff --git a/tests/test_async_query.py b/tests/test_async_query.py index cb382ab8..af334216 100644 --- a/tests/test_async_query.py +++ b/tests/test_async_query.py @@ -992,3 +992,35 @@ async def test_async_banned_transaction(self): edgedb.CapabilityError, r'cannot execute transaction control commands'): await self.client.execute('start transaction') + + async def test_dup_link_prop_name(self): + obj = await self.client.query_single(''' + CREATE TYPE test::dup_link_prop_name { + CREATE PROPERTY val -> str; + }; + CREATE TYPE test::dup_link_prop_name_p { + CREATE LINK l -> test::dup_link_prop_name { + CREATE PROPERTY val -> int32; + } + }; + INSERT test::dup_link_prop_name_p { + l := (INSERT test::dup_link_prop_name { + val := "hello", + @val := 42, + }) + }; + SELECT test::dup_link_prop_name_p { + l: { + val, + @val + } + } LIMIT 1; + ''') + + self.assertEqual(obj.l.val, "hello") + self.assertEqual(obj.l["@val"], 42) + + await self.client.execute(''' + DROP TYPE test::dup_link_prop_name_p; + DROP TYPE test::dup_link_prop_name; + ''')