Skip to content

Commit

Permalink
Merge 827fbec into 7ca0925
Browse files Browse the repository at this point in the history
  • Loading branch information
arunsureshkumar committed Oct 31, 2020
2 parents 7ca0925 + 827fbec commit 83f1b76
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 39 deletions.
2 changes: 1 addition & 1 deletion examples/flask_mongoengine/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from mongoengine import connect

from models import Department, Employee, Role, Task
from .models import Department, Employee, Role, Task

connect("graphene-mongo-example", host="mongomock://localhost", alias="default")

Expand Down
16 changes: 12 additions & 4 deletions examples/flask_mongoengine/schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import graphene
from graphene.relay import Node
from graphene_mongo.tests.nodes import PlayerNode, ReporterNode

from graphene_mongo import MongoengineConnectionField, MongoengineObjectType
from models import Department as DepartmentModel
from models import Employee as EmployeeModel
from models import Role as RoleModel
from models import Task as TaskModel
from .models import Department as DepartmentModel
from .models import Employee as EmployeeModel
from .models import Role as RoleModel
from .models import Task as TaskModel


class Department(MongoengineObjectType):
Expand All @@ -17,6 +19,9 @@ class Role(MongoengineObjectType):
class Meta:
model = RoleModel
interfaces = (Node,)
filter_fields = {
'name': ['exact', 'icontains', 'istartswith']
}


class Task(MongoengineObjectType):
Expand All @@ -29,6 +34,9 @@ class Employee(MongoengineObjectType):
class Meta:
model = EmployeeModel
interfaces = (Node,)
filter_fields = {
'name': ['exact', 'icontains', 'istartswith']
}


class Query(graphene.ObjectType):
Expand Down
88 changes: 81 additions & 7 deletions graphene_mongo/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import uuid

from graphene.types.json import JSONString
from graphene.utils.str_converters import to_snake_case
from mongoengine.base import get_document

from . import advanced_types
from .utils import import_single_dispatch, get_field_description
from .utils import import_single_dispatch, get_field_description, get_query_fields
from concurrent.futures import ThreadPoolExecutor, as_completed

singledispatch = import_single_dispatch()

Expand Down Expand Up @@ -104,6 +105,46 @@ def convert_file_to_field(field, registry=None):
def convert_field_to_list(field, registry=None):
base_type = convert_mongoengine_field(field.field, registry=registry)
if isinstance(base_type, graphene.Field):
if isinstance(field.field, mongoengine.GenericReferenceField):
def get_reference_objects(*args, **kwargs):
if args[0][1]:
document = get_document(args[0][0])
document_field = mongoengine.ReferenceField(document)
document_field = convert_mongoengine_field(document_field, registry)
document_field_type = document_field.get_type().type._meta.name
only_fields = [to_snake_case(i) for i in get_query_fields(args[0][3][0])[document_field_type].keys()]
return document.objects().no_dereference().only(*only_fields).filter(pk__in=args[0][1])
else:
return []

def reference_resolver(root, *args, **kwargs):
choice_to_resolve = dict()
to_resolve = getattr(root, field.name or field.db_name)
for each in to_resolve:
if each['_cls'] not in choice_to_resolve:
choice_to_resolve[each['_cls']] = list()
choice_to_resolve[each['_cls']].append(each["_ref"].id)

pool = ThreadPoolExecutor(5)
futures = list()
for model, object_id_list in choice_to_resolve.items():
futures.append(pool.submit(get_reference_objects, (model, object_id_list, registry, args)))
result = list()
for x in as_completed(futures):
result += x.result()
to_resolve_object_ids = [each["_ref"].id for each in to_resolve]
result_to_resolve_object_ids = [each.id for each in result]
ordered_result = list()
for each in to_resolve_object_ids:
ordered_result.append(result[result_to_resolve_object_ids.index(each)])
return ordered_result

return graphene.List(
base_type._type,
description=get_field_description(field, registry),
required=field.required,
resolver=reference_resolver
)
return graphene.List(
base_type._type,
description=get_field_description(field, registry),
Expand All @@ -121,7 +162,7 @@ def convert_field_to_list(field, registry=None):
# Non-relationship field
relations = (mongoengine.ReferenceField, mongoengine.EmbeddedDocumentField)
if not isinstance(base_type, (graphene.List, graphene.NonNull)) and not isinstance(
field.field, relations
field.field, relations
):
base_type = type(base_type)

Expand All @@ -135,7 +176,6 @@ def convert_field_to_list(field, registry=None):
@convert_mongoengine_field.register(mongoengine.GenericEmbeddedDocumentField)
@convert_mongoengine_field.register(mongoengine.GenericReferenceField)
def convert_field_to_union(field, registry=None):

_types = []
for choice in field.choices:
if isinstance(field, mongoengine.GenericReferenceField):
Expand All @@ -162,6 +202,20 @@ def convert_field_to_union(field, registry=None):
)
Meta = type("Meta", (object,), {"types": tuple(_types)})
_union = type(name, (graphene.Union,), {"Meta": Meta})

def reference_resolver(root, *args, **kwargs):
dereferenced = getattr(root, field.name or field.db_name)
document = get_document(dereferenced["_cls"])
document_field = mongoengine.ReferenceField(document)
document_field = convert_mongoengine_field(document_field, registry)
document_field_type = document_field.get_type().type._meta.name
only_fields = [to_snake_case(i) for i in get_query_fields(args[0])[document_field_type].keys()]
return document.objects().no_dereference().only(*only_fields).get(pk=dereferenced["_ref"].id)

if isinstance(field, mongoengine.GenericReferenceField):
return graphene.Field(_union, resolver=reference_resolver,
description=get_field_description(field, registry))

return graphene.Field(_union)


Expand All @@ -171,11 +225,30 @@ def convert_field_to_union(field, registry=None):
def convert_field_to_dynamic(field, registry=None):
model = field.document_type

def reference_resolver(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
only_fields = [to_snake_case(i) for i in get_query_fields(args[0]).keys()]
return field.document_type.objects().no_dereference().only(*only_fields).get(pk=document.id)
return None

def cached_reference_resolver(root, *args, **kwargs):
only_fields = [to_snake_case(i) for i in get_query_fields(args[0]).keys()]
return field.document_type.objects().no_dereference().only(*only_fields).get(
pk=getattr(root, field.name or field.db_name))

def dynamic_type():
_type = registry.get_type_for_model(model)
if not _type:
return None
return graphene.Field(_type, description=get_field_description(field, registry))
elif isinstance(field, mongoengine.ReferenceField):
return graphene.Field(_type, resolver=reference_resolver,
description=get_field_description(field, registry))
elif isinstance(field, mongoengine.CachedReferenceField):
return graphene.Field(_type, resolver=cached_reference_resolver,
description=get_field_description(field, registry))
return graphene.Field(_type,
description=get_field_description(field, registry))

return graphene.Dynamic(dynamic_type)

Expand All @@ -185,8 +258,9 @@ def convert_lazy_field_to_dynamic(field, registry=None):
model = field.document_type

def lazy_resolver(root, *args, **kwargs):
if getattr(root, field.name or field.db_name):
return getattr(root, field.name or field.db_name).fetch()
document = getattr(root, field.name or field.db_name)
only_fields = [to_snake_case(i) for i in get_query_fields(args[0]).keys()]
return document.document_type.objects().no_dereference().only(*only_fields).get(pk=document.pk)

def dynamic_type():
_type = registry.get_type_for_model(model)
Expand Down
76 changes: 54 additions & 22 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import graphene
import mongoengine
from bson import DBRef
from graphene import Context
from graphene.utils.str_converters import to_snake_case
from promise import Promise
from graphql_relay import from_global_id
from graphene.relay import ConnectionField
Expand All @@ -21,7 +24,7 @@
)
from .converter import convert_mongoengine_field, MongoEngineConversionError
from .registry import get_global_registry
from .utils import get_model_reference_fields, get_node_from_global_id
from .utils import get_model_reference_fields, get_node_from_global_id, get_query_fields


class MongoengineConnectionField(ConnectionField):
Expand Down Expand Up @@ -59,6 +62,12 @@ def model(self):
def order_by(self):
return self.node_type._meta.order_by

@property
def only_fields(self):
if isinstance(self.node_type._meta.only_fields, str):
return self.node_type._meta.only_fields.split(",")
return list()

@property
def registry(self):
return getattr(self.node_type._meta, "registry", get_global_registry())
Expand Down Expand Up @@ -98,18 +107,18 @@ def is_filterable(k):
if isinstance(converted, (ConnectionField, Dynamic)):
return False
if callable(getattr(converted, "type", None)) and isinstance(
converted.type(),
(
FileFieldType,
PointFieldType,
MultiPolygonFieldType,
graphene.Union,
PolygonFieldType,
),
converted.type(),
(
FileFieldType,
PointFieldType,
MultiPolygonFieldType,
graphene.Union,
PolygonFieldType,
),
):
return False
if isinstance(converted, (graphene.List)) and issubclass(
getattr(converted, "_of_type", None), graphene.Union
getattr(converted, "_of_type", None), graphene.Union
):
return False

Expand Down Expand Up @@ -160,16 +169,16 @@ def get_reference_field(r, kv):
field = kv[1]
mongo_field = getattr(self.model, kv[0], None)
if isinstance(
mongo_field,
(mongoengine.LazyReferenceField, mongoengine.ReferenceField),
mongo_field,
(mongoengine.LazyReferenceField, mongoengine.ReferenceField),
):
field = convert_mongoengine_field(mongo_field, self.registry)
if callable(getattr(field, "get_type", None)):
_type = field.get_type()
if _type:
node = _type._type._meta
if "id" in node.fields and not issubclass(
node.model, (mongoengine.EmbeddedDocument,)
node.model, (mongoengine.EmbeddedDocument,)
):
r.update({kv[0]: node.fields["id"]._type.of_type()})
return r
Expand All @@ -180,7 +189,7 @@ def get_reference_field(r, kv):
def fields(self):
return self._type._meta.fields

def get_queryset(self, model, info, **args):
def get_queryset(self, model, info, only_fields=list(), **args):
if args:
reference_fields = get_model_reference_fields(self.model)
hydrated_references = {}
Expand All @@ -198,13 +207,16 @@ def get_queryset(self, model, info, **args):
return queryset_or_filters
else:
args.update(queryset_or_filters)
return model.objects(**args).order_by(self.order_by)

def default_resolver(self, _root, info, **args):
return model.objects(**args).no_dereference().only(*only_fields).order_by(self.order_by)

def default_resolver(self, _root, info, only_fields=list(), **args):
args = args or {}

if _root is not None:
args["pk__in"] = [r.pk for r in getattr(_root, info.field_name, [])]
field_name = to_snake_case(info.field_name)
if getattr(_root, field_name, []) is not None:
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]

connection_args = {
"first": args.pop("first", None),
Expand All @@ -219,7 +231,7 @@ def default_resolver(self, _root, info, **args):
args['pk'] = from_global_id(_id)[-1]

if callable(getattr(self.model, "objects", None)):
iterables = self.get_queryset(self.model, info, **args)
iterables = self.get_queryset(self.model, info, only_fields, **args)
list_length = iterables.count()
else:
iterables = []
Expand All @@ -239,23 +251,43 @@ def default_resolver(self, _root, info, **args):
return connection

def chained_resolver(self, resolver, is_partial, root, info, **args):
only_fields = list()
for field in self.only_fields:
if field in self.model._fields_ordered:
only_fields.append(field)
for field in get_query_fields(info):
if to_snake_case(field) in self.model._fields_ordered:
only_fields.append(to_snake_case(field))
if not bool(args) or not is_partial:
if isinstance(self.model, mongoengine.Document) or isinstance(self.model,
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
args_copy = args.copy()
for arg_name, arg in args.copy().items():
if arg_name not in self.model._fields_ordered:
args_copy.pop(arg_name)
if not info.context:
info.context = Context()
info.context.queryset = self.get_queryset(self.model, info, only_fields, **args_copy)
# XXX: Filter nested args
resolved = resolver(root, info, **args)
if resolved is not None:
return resolved
return self.default_resolver(root, info, **args)
if isinstance(resolved, list):
if resolved == list():
return resolved
elif not isinstance(resolved[0], DBRef):
return resolved
else:
return resolved
return self.default_resolver(root, info, only_fields, **args)

@classmethod
def connection_resolver(cls, resolver, connection_type, root, info, **args):
iterable = resolver(root, info, **args)
if isinstance(connection_type, graphene.NonNull):
connection_type = connection_type.of_type

on_resolve = partial(cls.resolve_connection, connection_type, args)
if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)

return on_resolve(iterable)

def get_resolver(self, parent_resolver):
Expand Down
2 changes: 1 addition & 1 deletion graphene_mongo/tests/test_relay_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Query(graphene.ObjectType):
reporter = graphene.Field(nodes.ReporterNode)

def resolve_reporter(self, *args, **kwargs):
return models.Reporter.objects.first()
return models.Reporter.objects.no_dereference().first()

query = """
query ReporterQuery {
Expand Down
Loading

0 comments on commit 83f1b76

Please sign in to comment.