Skip to content

Commit

Permalink
Merge pull request #69 from zrlay/master
Browse files Browse the repository at this point in the history
features: ConnectionField customization, LazyReferences, better field descriptions
  • Loading branch information
abawchen committed Feb 12, 2019
2 parents a1de06c + 349d5de commit c98b33a
Show file tree
Hide file tree
Showing 12 changed files with 444 additions and 139 deletions.
46 changes: 33 additions & 13 deletions graphene_mongo/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
import mongoengine

from .advanced_types import PointFieldType, MultiPolygonFieldType
from .fields import MongoengineConnectionField
from .utils import import_single_dispatch
from .utils import import_single_dispatch, get_field_description

singledispatch = import_single_dispatch()


class MongoEngineConversionError(Exception):
pass


@singledispatch
def convert_mongoengine_field(field, registry=None):
raise Exception(
raise MongoEngineConversionError(
"Don't know how to convert the MongoEngine field %s (%s)" %
(field, field.__class__))

Expand All @@ -33,36 +36,36 @@ def convert_mongoengine_field(field, registry=None):
@convert_mongoengine_field.register(mongoengine.StringField)
@convert_mongoengine_field.register(mongoengine.URLField)
def convert_field_to_string(field, registry=None):
return String(description=field.db_field, required=field.required)
return String(description=get_field_description(field, registry), required=field.required)


@convert_mongoengine_field.register(mongoengine.UUIDField)
@convert_mongoengine_field.register(mongoengine.ObjectIdField)
def convert_field_to_id(field, registry=None):
return ID(description=field.db_field, required=field.required)
return ID(description=get_field_description(field, registry), required=field.required)


@convert_mongoengine_field.register(mongoengine.IntField)
@convert_mongoengine_field.register(mongoengine.LongField)
def convert_field_to_int(field, registry=None):
return Int(description=field.db_field, required=field.required)
return Int(description=get_field_description(field, registry), required=field.required)


@convert_mongoengine_field.register(mongoengine.BooleanField)
def convert_field_to_boolean(field, registry=None):
return Boolean(description=field.db_field, required=field.required)
return Boolean(description=get_field_description(field, registry), required=field.required)


@convert_mongoengine_field.register(mongoengine.DecimalField)
@convert_mongoengine_field.register(mongoengine.FloatField)
def convert_field_to_float(field, registry=None):
return Float(description=field.db_field, required=field.required)
return Float(description=get_field_description(field, registry), required=field.required)


@convert_mongoengine_field.register(mongoengine.DictField)
@convert_mongoengine_field.register(mongoengine.MapField)
def convert_dict_to_jsonstring(field, registry=None):
return JSONString(description=field.db_field, required=field.required)
return JSONString(description=get_field_description(field, registry), required=field.required)


@convert_mongoengine_field.register(mongoengine.PointField)
Expand All @@ -77,7 +80,7 @@ def convert_multipolygon_to_field(field, register=None):

@convert_mongoengine_field.register(mongoengine.DateTimeField)
def convert_field_to_datetime(field, registry=None):
return DateTime(description=field.db_field, required=field.required)
return DateTime(description=get_field_description(field, registry), required=field.required)


@convert_mongoengine_field.register(mongoengine.ListField)
Expand All @@ -91,15 +94,15 @@ def convert_field_to_list(field, registry=None):
base_type = base_type._type

if is_node(base_type):
return MongoengineConnectionField(base_type)
return base_type._meta.connection_field_class(base_type)

# Non-relationship field
relations = (mongoengine.ReferenceField, mongoengine.EmbeddedDocumentField)
if not isinstance(base_type, (List, NonNull)) \
and not isinstance(field.field, relations):
base_type = type(base_type)

return List(base_type, description=field.db_field, required=field.required)
return List(base_type, description=get_field_description(field, registry), required=field.required)


@convert_mongoengine_field.register(mongoengine.EmbeddedDocumentField)
Expand All @@ -111,6 +114,23 @@ def dynamic_type():
_type = registry.get_type_for_model(model)
if not _type:
return None
return Field(_type)
return Field(_type, description=get_field_description(field, registry))

return Dynamic(dynamic_type)


@convert_mongoengine_field.register(mongoengine.LazyReferenceField)
def convert_lazy_field_to_dynamic(field, registry=None):
model = field.document_type

def lazy_resolver(root, *args, **kwargs):
if getattr(root, field.name or field.db_name):
return getattr(root, field.name or field.db_name).fetch()

def dynamic_type():
_type = registry.get_type_for_model(model)
if not _type:
return None
return Field(_type, resolver=lazy_resolver, description=get_field_description(field, registry))

return Dynamic(dynamic_type)
178 changes: 87 additions & 91 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
from __future__ import absolute_import

import mongoengine
from collections import OrderedDict
from functools import partial, reduce

import mongoengine
from graphene import PageInfo
from graphene.relay import ConnectionField
from graphene.relay.connection import PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from graphql_relay.node.node import from_global_id
from graphene.types.argument import to_arguments
from graphene.types.dynamic import Dynamic
from graphene.types.structures import Structure
from graphene.types.structures import Structure, List
from graphql_relay.connection.arrayconnection import connection_from_list_slice

from .advanced_types import PointFieldType, MultiPolygonFieldType
from .utils import get_model_reference_fields
from .converter import convert_mongoengine_field, MongoEngineConversionError
from .registry import get_global_registry
from .utils import get_model_reference_fields, get_node_from_global_id


class MongoengineConnectionField(ConnectionField):

def __init__(self, type, *args, **kwargs):
get_queryset = kwargs.pop('get_queryset', None)
if get_queryset:
assert callable(get_queryset), "Attribute `get_queryset` on {} must be callable.".format(self)
self._get_queryset = get_queryset
super(MongoengineConnectionField, self).__init__(
type,
*args,
Expand All @@ -43,6 +48,10 @@ def node_type(self):
def model(self):
return self.node_type._meta.model

@property
def registry(self):
return getattr(self.node_type._meta, 'registry', get_global_registry())

@property
def args(self):
return to_arguments(
Expand All @@ -55,12 +64,19 @@ def args(self, args):
self._base_args = args

def _field_args(self, items):
def is_filterable(v):
if isinstance(v, (ConnectionField, Dynamic)):
def is_filterable(k):
if not hasattr(self.model, k):
return False
if isinstance(getattr(self.model, k), property):
return False
# FIXME: Skip PointTypeField at this moment.
if not isinstance(v.type, Structure) \
and isinstance(v.type(), (PointFieldType, MultiPolygonFieldType)):
try:
converted = convert_mongoengine_field(getattr(self.model, k), self.registry)
except MongoEngineConversionError:
return False
if isinstance(converted, (ConnectionField, Dynamic, List)):
return False
if callable(getattr(converted, 'type', None)) and isinstance(converted.type(),
(PointFieldType, MultiPolygonFieldType)):
return False
return True

Expand All @@ -69,7 +85,7 @@ def get_type(v):
return v.type.of_type()
return v.type()

return {k: get_type(v) for k, v in items if is_filterable(v)}
return {k: get_type(v) for k, v in items if is_filterable(k)}

@property
def field_args(self):
Expand All @@ -78,102 +94,82 @@ def field_args(self):
@property
def reference_args(self):
def get_reference_field(r, kv):
if callable(getattr(kv[1], 'get_type', None)):
node = kv[1].get_type()._type._meta
if not issubclass(node.model, mongoengine.EmbeddedDocument):
r.update({kv[0]: node.fields['id']._type.of_type()})
field = kv[1]
mongo_field = getattr(self.model, kv[0], None)
if isinstance(mongo_field, (mongoengine.LazyReferenceField, mongoengine.ReferenceField)):
field = convert_mongoengine_field(mongo_field, self.registry)
if callable(getattr(field, 'get_type', None)):
_type = field.get_type()
if _type:
node = _type._type._meta
if 'id' in node.fields and not issubclass(node.model, mongoengine.EmbeddedDocument):
r.update({kv[0]: node.fields['id']._type.of_type()})
return r

return reduce(get_reference_field, self.fields.items(), {})

@property
def fields(self):
return self._type._meta.fields

@classmethod
def get_query(cls, model, info, **args):
def get_queryset(self, model, info, **args):

if not callable(getattr(model, 'objects', None)):
return [], 0

objs = model.objects()
if args:
reference_fields = get_model_reference_fields(model)
reference_args = {}
reference_fields = get_model_reference_fields(self.model)
hydrated_references = {}
for arg_name, arg in args.copy().items():
if arg_name in reference_fields:
reference_model = model._fields[arg_name]
pk = from_global_id(args.pop(arg_name))[-1]
reference_obj = reference_model.document_type_obj.objects(pk=pk).get()
reference_args[arg_name] = reference_obj

args.update(reference_args)
first = args.pop('first', None)
last = args.pop('last', None)
id = args.pop('id', None)
before = args.pop('before', None)
after = args.pop('after', None)

if id is not None:
# https://github.com/graphql-python/graphene/issues/124
args['pk'] = from_global_id(id)[-1]

objs = objs.filter(**args)

# https://github.com/graphql-python/graphene-mongo/issues/21
if after is not None:
_after = int(from_global_id(after)[-1])
objs = objs[_after:]

if before is not None:
_before = int(from_global_id(before)[-1])
objs = objs[:_before]

reference_obj = get_node_from_global_id(reference_fields[arg_name], info, args.pop(arg_name))
hydrated_references[arg_name] = reference_obj
args.update(hydrated_references)
if self._get_queryset:
queryset_or_filters = self._get_queryset(model, info, **args)
if isinstance(queryset_or_filters, mongoengine.QuerySet):
return queryset_or_filters
else:
args.update(queryset_or_filters)
return model.objects(**args)

def default_resolver(self, _root, info, **args):
args = args or {}

connection_args = {
'first': args.pop('first', None),
'last': args.pop('last', None),
'before': args.pop('before', None),
'after': args.pop('after', None)
}

_id = args.pop('id', None)

if _id is not None:
objs = [get_node_from_global_id(self.node_type, info, _id)]
list_length = 1
elif callable(getattr(self.model, 'objects', None)):
objs = self.get_queryset(self.model, info, **args)
list_length = objs.count()

if first is not None:
objs = objs[:first]
if last is not None:
# https://github.com/graphql-python/graphene-mongo/issues/20
objs = objs[max(0, list_length - last):]
else:
list_length = objs.count()

return objs, list_length

# noqa
@classmethod
def merge_querysets(cls, default_queryset, queryset):
return queryset & default_queryset

"""
Notes: Not sure how does this work :(
"""
@classmethod
def connection_resolver(cls, resolver, connection, model, root, info, **args):
iterable = resolver(root, info, **args)

if not iterable:
iterable, _len = cls.get_query(model, info, **args)

if root:
# If we have a root, we must be at least 1 layer in, right?
_len = 0
else:
_len = len(iterable)
objs = []
list_length = 0

connection = connection_from_list_slice(
iterable,
args,
slice_start=0,
list_length=_len,
list_slice_length=_len,
connection_type=connection,
list_slice=objs,
args=connection_args,
list_length=list_length,
connection_type=self.type,
edge_type=self.type.Edge,
pageinfo_type=PageInfo,
edge_type=connection.Edge,
)
connection.iterable = iterable
connection.length = _len
connection.iterable = objs
return connection

def chained_resolver(self, resolver, root, info, **args):
resolved = resolver(root, info, **args)
if resolved is not None:
return resolved
return self.default_resolver(root, info, **args)

def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)
super_resolver = self.resolver or parent_resolver
resolver = partial(self.chained_resolver, super_resolver)
return partial(self.connection_resolver, resolver, self.type)
Loading

0 comments on commit c98b33a

Please sign in to comment.