Skip to content

Commit

Permalink
Implement dataclass for EdgeObject (#359)
Browse files Browse the repository at this point in the history
Co-authored-by: Yury Selivanov <yury@edgedb.com>
  • Loading branch information
fantix and 1st1 committed Sep 4, 2022
1 parent 241c80d commit dfb8c8b
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 2 deletions.
2 changes: 2 additions & 0 deletions edgedb/datatypes/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ typedef struct {
EdgeRecordFieldDesc *descs;
Py_ssize_t idpos;
Py_ssize_t size;
PyObject *get_dataclass_fields_func;
} EdgeRecordDescObject;

typedef enum {
Expand All @@ -82,6 +83,7 @@ EdgeFieldCardinality EdgeRecordDesc_PointerCardinality(PyObject *, Py_ssize_t);
Py_ssize_t EdgeRecordDesc_GetSize(PyObject *);
edge_attr_lookup_t EdgeRecordDesc_Lookup(PyObject *, PyObject *, Py_ssize_t *);
PyObject * EdgeRecordDesc_List(PyObject *, uint8_t, uint8_t);
PyObject * EdgeRecordDesc_GetDataclassFields(PyObject *);



Expand Down
22 changes: 21 additions & 1 deletion edgedb/datatypes/object.c
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,18 @@ object_getattr(EdgeObject *o, PyObject *name)
case L_ERROR:
return NULL;

case L_LINKPROP:
case L_NOT_FOUND:
// Used in `dataclasses.as_dict()`
if (
PyUnicode_CompareWithASCIIString(
name, "__dataclass_fields__"
) == 0
) {
return EdgeRecordDesc_GetDataclassFields((PyObject *)o->desc);
}
return PyObject_GenericGetAttr((PyObject *)o, name);

case L_LINKPROP:
return PyObject_GenericGetAttr((PyObject *)o, name);

case L_LINK:
Expand Down Expand Up @@ -365,6 +375,16 @@ EdgeObject_InitType(void)
return NULL;
}

// Pass the `dataclasses.is_dataclass(obj)` check - which then checks
// `hasattr(type(obj), "__dataclass_fields__")`, the dict is always empty
PyObject *default_fields = PyDict_New();
if (default_fields == NULL) {
return NULL;
}
PyDict_SetItemString(
EdgeObject_Type.tp_dict, "__dataclass_fields__", default_fields
);

base_hash = _EdgeGeneric_HashString("edgedb.Object");
if (base_hash == -1) {
return NULL;
Expand Down
33 changes: 33 additions & 0 deletions edgedb/datatypes/record_desc.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ record_desc_dealloc(EdgeRecordDescObject *o)
PyObject_GC_UnTrack(o);
Py_CLEAR(o->index);
Py_CLEAR(o->names);
Py_CLEAR(o->get_dataclass_fields_func);
PyMem_RawFree(o->descs);
PyObject_GC_Del(o);
}
Expand Down Expand Up @@ -177,12 +178,24 @@ record_desc_dir(EdgeRecordDescObject *o, PyObject *args)
}


static PyObject *
record_set_dataclass_fields_func(EdgeRecordDescObject *o, PyObject *arg)
{
Py_CLEAR(o->get_dataclass_fields_func);
o->get_dataclass_fields_func = arg;
Py_INCREF(arg);
Py_RETURN_NONE;
}


static PyMethodDef record_desc_methods[] = {
{"is_linkprop", (PyCFunction)record_desc_is_linkprop, METH_O, NULL},
{"is_link", (PyCFunction)record_desc_is_link, METH_O, NULL},
{"is_implicit", (PyCFunction)record_desc_is_implicit, METH_O, NULL},
{"get_pos", (PyCFunction)record_desc_get_pos, METH_O, NULL},
{"__dir__", (PyCFunction)record_desc_dir, METH_NOARGS, NULL},
{"set_dataclass_fields_func",
(PyCFunction)record_set_dataclass_fields_func, METH_O, NULL},
{NULL, NULL}
};

Expand Down Expand Up @@ -349,6 +362,7 @@ EdgeRecordDesc_New(PyObject *names, PyObject *flags, PyObject *cards)

o->size = size;
o->idpos = idpos;
o->get_dataclass_fields_func = NULL;

PyObject_GC_Track(o);
return (PyObject *)o;
Expand Down Expand Up @@ -537,6 +551,25 @@ EdgeRecordDesc_List(PyObject *ob, uint8_t include_mask, uint8_t exclude_mask)
}


PyObject *
EdgeRecordDesc_GetDataclassFields(PyObject *ob)
{
if (!EdgeRecordDesc_Check(ob)) {
PyErr_BadInternalCall();
return NULL;
}

EdgeRecordDescObject *o = (EdgeRecordDescObject *)ob;

// bpo-37194 added PyObject_CallNoArgs() to Python 3.9.0a1
#if PY_VERSION_HEX < 0x030900A1
return PyObject_CallFunctionObjArgs(o->get_dataclass_fields_func, NULL);
#else
return PyObject_CallNoArgs(o->get_dataclass_fields_func);
#endif
}


PyObject *
EdgeRecordDesc_InitType(void)
{
Expand Down
4 changes: 3 additions & 1 deletion edgedb/protocol/codecs/object.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

@cython.final
cdef class ObjectCodec(BaseNamedRecordCodec):
cdef bint is_sparse
cdef:
bint is_sparse
object cached_dataclass_fields

cdef encode_args(self, WriteBuffer buf, dict obj)

Expand Down
19 changes: 19 additions & 0 deletions edgedb/protocol/codecs/object.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# limitations under the License.
#

import dataclasses


@cython.final
cdef class ObjectCodec(BaseNamedRecordCodec):
Expand Down Expand Up @@ -180,6 +182,22 @@ cdef class ObjectCodec(BaseNamedRecordCodec):

return result

def get_dataclass_fields(self):
cdef descriptor = (<BaseNamedRecordCodec>self).descriptor

rv = self.cached_dataclass_fields
if rv is None:
rv = {}

for i in range(len(self.fields_codecs)):
name = datatypes.record_desc_pointer_name(descriptor, i)
field = rv[name] = dataclasses.field()
field.name = name
field._field_type = dataclasses._FIELD

self.cached_dataclass_fields = rv
return rv

@staticmethod
cdef BaseCodec new(bytes tid, tuple names, tuple flags, tuple cards,
tuple codecs, bint is_sparse):
Expand All @@ -195,6 +213,7 @@ cdef class ObjectCodec(BaseNamedRecordCodec):
codec.name = 'Object'
codec.is_sparse = is_sparse
codec.descriptor = datatypes.record_desc_new(names, flags, cards)
codec.descriptor.set_dataclass_fields_func(codec.get_dataclass_fields)
codec.fields_codecs = codecs

return codec

0 comments on commit dfb8c8b

Please sign in to comment.