Skip to content

Commit

Permalink
Merge pull request #157 from arunsureshkumar/feat-pagination-performance
Browse files Browse the repository at this point in the history
Feat pagination performance
  • Loading branch information
abawchen committed Nov 14, 2020
2 parents 5d08133 + 863ab88 commit 9c2f8a8
Show file tree
Hide file tree
Showing 8 changed files with 365 additions and 123 deletions.
209 changes: 162 additions & 47 deletions graphene_mongo/converter.py

Large diffs are not rendered by default.

169 changes: 117 additions & 52 deletions graphene_mongo/fields.py
Expand Up @@ -7,6 +7,7 @@
import mongoengine
from bson import DBRef
from graphene import Context
from graphene.types.utils import get_type
from graphene.utils.str_converters import to_snake_case
from graphql import ResolveInfo
from promise import Promise
Expand All @@ -15,7 +16,8 @@
from graphene.types.argument import to_arguments
from graphene.types.dynamic import Dynamic
from graphene.types.structures import Structure
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from graphql_relay.connection.arrayconnection import cursor_to_offset
from mongoengine import QuerySet

from .advanced_types import (
FileFieldType,
Expand All @@ -25,7 +27,8 @@
)
from .converter import convert_mongoengine_field, MongoEngineConversionError
from .registry import get_global_registry
from .utils import get_model_reference_fields, get_node_from_global_id, get_query_fields
from .utils import get_model_reference_fields, get_query_fields, find_skip_and_limit, \
connection_from_iterables


class MongoengineConnectionField(ConnectionField):
Expand Down Expand Up @@ -64,10 +67,8 @@ def order_by(self):
return self.node_type._meta.order_by

@property
def only_fields(self):
if isinstance(self.node_type._meta.only_fields, str):
return self.node_type._meta.only_fields.split(",")
return list()
def required_fields(self):
return tuple(set(self.node_type._meta.required_fields + self.node_type._meta.only_fields))

@property
def registry(self):
Expand Down Expand Up @@ -118,11 +119,13 @@ def is_filterable(k):
),
):
return False
if getattr(converted, "type", None) and getattr(converted.type, "_of_type", None) and issubclass(
(get_type(converted.type.of_type)), graphene.Union):
return False
if isinstance(converted, (graphene.List)) and issubclass(
getattr(converted, "_of_type", None), graphene.Union
):
return False

return True

def get_filter_type(_type):
Expand Down Expand Up @@ -177,29 +180,35 @@ def get_reference_field(r, kv):
if callable(getattr(field, "get_type", None)):
_type = field.get_type()
if _type:
node = _type._type._meta
node = _type.type._meta if hasattr(_type.type, "_meta") else _type.type._of_type._meta
if "id" in node.fields and not issubclass(
node.model, (mongoengine.EmbeddedDocument,)
):
r.update({kv[0]: node.fields["id"]._type.of_type()})

return r

return reduce(get_reference_field, self.fields.items(), {})

@property
def fields(self):
self._type = get_type(self._type)
return self._type._meta.fields

def get_queryset(self, model, info, only_fields=list(), **args):
def get_queryset(self, model, info, required_fields=list(), skip=None, limit=None, reversed=False, **args):
if args:
reference_fields = get_model_reference_fields(self.model)
hydrated_references = {}
for arg_name, arg in args.copy().items():
if arg_name in reference_fields:
reference_obj = get_node_from_global_id(
reference_fields[arg_name], info, args.pop(arg_name)
)
if arg_name in reference_fields and not isinstance(arg,
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
try:
reference_obj = reference_fields[arg_name].document_type(pk=from_global_id(arg)[1])
except TypeError:
reference_obj = reference_fields[arg_name].document_type(pk=arg)
hydrated_references[arg_name] = reference_obj
elif arg_name == "id":
hydrated_references["id"] = from_global_id(args.pop("id", None))[1]
args.update(hydrated_references)

if self._get_queryset:
Expand All @@ -208,72 +217,120 @@ def get_queryset(self, model, info, only_fields=list(), **args):
return queryset_or_filters
else:
args.update(queryset_or_filters)
if limit is not None:
if reversed:
order_by = ""
if self.order_by:
order_by = self.order_by + ",-pk"
else:
order_by = "-pk"
return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip(
skip if skip else 0).limit(limit)
else:
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip(
skip if skip else 0).limit(limit)
elif skip is not None:
if reversed:
order_by = ""
if self.order_by:
order_by = self.order_by + ",-pk"
else:
order_by = "-pk"
return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip(
skip)
else:
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip(
skip)
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)

return model.objects(**args).no_dereference().only(*only_fields).order_by(self.order_by)

def default_resolver(self, _root, info, only_fields=list(), **args):
def default_resolver(self, _root, info, required_fields=list(), **args):
args = args or {}

if _root is not None:
field_name = to_snake_case(info.field_name)
if getattr(_root, field_name, []) is not None:
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]

connection_args = {
"first": args.pop("first", None),
"last": args.pop("last", None),
"before": args.pop("before", None),
"after": args.pop("after", None),
}
if field_name in _root._fields_ordered:
if getattr(_root, field_name, []) is not None:
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]

_id = args.pop('id', None)

if _id is not None:
args['pk'] = from_global_id(_id)[-1]

iterables = []
list_length = 0
skip = 0
count = 0
limit = None
reverse = False
if callable(getattr(self.model, "objects", None)):
iterables = self.get_queryset(self.model, info, only_fields, **args)
if isinstance(info, ResolveInfo):
if not info.context:
info.context = Context()
info.context.queryset = iterables
list_length = iterables.count()
else:
iterables = []
list_length = 0

connection = connection_from_list_slice(
list_slice=iterables,
args=connection_args,
list_length=list_length,
list_slice_length=list_length,
connection_type=self.type,
edge_type=self.type.Edge,
pageinfo_type=graphene.PageInfo,
)
first = args.pop("first", None)
after = cursor_to_offset(args.pop("after", None))
last = args.pop("last", None)
before = cursor_to_offset(args.pop("before", None))
if "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
count=count)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
else:
args["pk__in"] = args["pk__in"][skip:skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
list_length = len(iterables)
if isinstance(info, ResolveInfo):
if not info.context:
info.context = Context()
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
elif _root is None:
count = self.get_queryset(self.model, info, required_fields, **args).count()
if count != 0:
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
count=count)
iterables = self.get_queryset(self.model, info, required_fields, skip, limit, reverse, **args)
list_length = len(iterables)
if isinstance(info, ResolveInfo):
if not info.context:
info.context = Context()
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False
has_previous_page = True if skip else False
if reverse:
iterables = list(iterables)
iterables.reverse()
skip = limit
connection = connection_from_iterables(edges=iterables, start_offset=skip,
has_previous_page=has_previous_page,
has_next_page=has_next_page,
connection_type=self.type,
edge_type=self.type.Edge,
pageinfo_type=graphene.PageInfo)

connection.iterable = iterables
connection.list_length = list_length
return connection

def chained_resolver(self, resolver, is_partial, root, info, **args):
only_fields = list()
for field in self.only_fields:
required_fields = list()
for field in self.required_fields:
if field in self.model._fields_ordered:
only_fields.append(field)
required_fields.append(field)
for field in get_query_fields(info):
if to_snake_case(field) in self.model._fields_ordered:
only_fields.append(to_snake_case(field))
required_fields.append(to_snake_case(field))
if not bool(args) or not is_partial:
if isinstance(self.model, mongoengine.Document) or isinstance(self.model,
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
args_copy = args.copy()
for arg_name, arg in args.copy().items():
if arg_name not in self.model._fields_ordered:
if arg_name not in self.model._fields_ordered + tuple(self.filter_args.keys()):
args_copy.pop(arg_name)
if isinstance(info, ResolveInfo):
if not info.context:
info.context = Context()
info.context.queryset = self.get_queryset(self.model, info, only_fields, **args_copy)
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args_copy)
# XXX: Filter nested args
resolved = resolver(root, info, **args)
if resolved is not None:
Expand All @@ -282,9 +339,17 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
return resolved
elif not isinstance(resolved[0], DBRef):
return resolved
elif isinstance(resolved, QuerySet):
args.update(resolved._query)
args_copy = args.copy()
for arg_name, arg in args.copy().items():
if arg_name not in self.model._fields_ordered + ('first', 'last', 'before', 'after') + tuple(
self.filter_args.keys()):
args_copy.pop(arg_name)
return self.default_resolver(root, info, required_fields, **args_copy)
else:
return resolved
return self.default_resolver(root, info, only_fields, **args)
return self.default_resolver(root, info, required_fields, **args)

@classmethod
def connection_resolver(cls, resolver, connection_type, root, info, **args):
Expand Down
2 changes: 2 additions & 0 deletions graphene_mongo/registry.py
@@ -1,6 +1,7 @@
class Registry(object):
def __init__(self):
self._registry = {}
self._registry_string_map = {}

def register(self, cls):
from .types import MongoengineObjectType
Expand All @@ -12,6 +13,7 @@ def register(self, cls):
)
assert cls._meta.registry == self, "Registry for a Model have to match."
self._registry[cls._meta.model] = cls
self._registry_string_map[cls.__name__] = cls._meta.model.__name__

# Rescan all fields
for model, cls in self._registry.items():
Expand Down

0 comments on commit 9c2f8a8

Please sign in to comment.