Skip to content

Commit

Permalink
Merge c95b914 into 97266e6
Browse files Browse the repository at this point in the history
  • Loading branch information
arunsureshkumar committed Mar 29, 2021
2 parents 97266e6 + c95b914 commit 3f2c4c8
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 51 deletions.
2 changes: 1 addition & 1 deletion graphene_mongo/advanced_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def resolve_data(self, info):
v = getattr(self.instance, self.key)
data = v.read()
if data is not None:
return base64.b64encode(data)
return base64.b64encode(data).decode("utf-8")
return None


Expand Down
55 changes: 37 additions & 18 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from graphene import Context
from graphene.types.utils import get_type
from graphene.utils.str_converters import to_snake_case
from graphql import ResolveInfo
from graphql import GraphQLResolveInfo
from mongoengine.base import get_document
from promise import Promise
from graphql_relay import from_global_id
Expand Down Expand Up @@ -44,12 +44,12 @@ def __init__(self, type, *args, **kwargs):

@property
def type(self):
from .types import GrapheneMongoengineObjectTypes
from .types import MongoengineObjectType

_type = super(ConnectionField, self).type
assert issubclass(
_type, GrapheneMongoengineObjectTypes
), "MongoengineConnectionField only accepts Mongoengine object types"
_type, MongoengineObjectType
), "MongoengineConnectionField only accepts MongoengineObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(
_type.__name__
)
Expand Down Expand Up @@ -79,7 +79,7 @@ def registry(self):
def args(self):
return to_arguments(
self._base_args or OrderedDict(),
dict(dict(self.field_args, **self.advance_args), **self.filter_args),
dict(dict(dict(self.field_args, **self.advance_args), **self.filter_args), **self.extended_args),
)

@args.setter
Expand All @@ -96,7 +96,8 @@ def is_filterable(k):
Returns:
bool
"""

if hasattr(self.fields[k].type, '_sdl'):
return False
if not hasattr(self.model, k):
return False
if isinstance(getattr(self.model, k), property):
Expand Down Expand Up @@ -167,7 +168,7 @@ def filter_args(self):
}
filter_type = advanced_filter_types.get(each, filter_type)
filter_args[field + "__" + each] = graphene.Argument(
type=filter_type
type_=filter_type
)
return filter_args

Expand Down Expand Up @@ -201,6 +202,14 @@ def get_advance_field(r, kv):

return reduce(get_advance_field, self.fields.items(), {})

@property
def extended_args(self):
args = OrderedDict()
for k, each in self.fields.items():
if hasattr(each.type, '_sdl'):
args.update({k: graphene.ID()})
return args

@property
def fields(self):
self._type = get_type(self._type)
Expand Down Expand Up @@ -269,6 +278,9 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non

def default_resolver(self, _root, info, required_fields=list(), **args):
args = args or {}
for key, value in dict(args).items():
if value is None:
del args[key]
if _root is not None:
field_name = to_snake_case(info.field_name)
if not hasattr(_root, "_fields_ordered"):
Expand All @@ -292,9 +304,13 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
limit = None
reverse = False
first = args.pop("first", None)
after = cursor_to_offset(args.pop("after", None))
after = args.pop("after", None)
if after:
after = cursor_to_offset(after)
last = args.pop("last", None)
before = cursor_to_offset(args.pop("before", None))
before = args.pop("before", None)
if before:
before = cursor_to_offset(before)
if callable(getattr(self.model, "objects", None)):
if "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
Expand All @@ -309,9 +325,9 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
list_length = len(iterables)
if isinstance(info, ResolveInfo):
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info.context = Context()
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
elif _root is None or args:
count = self.get_queryset(self.model, info, required_fields, **args).count()
Expand All @@ -320,9 +336,9 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
count=count)
iterables = self.get_queryset(self.model, info, required_fields, skip, limit, reverse, **args)
list_length = len(iterables)
if isinstance(info, ResolveInfo):
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info.context = Context()
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)

elif _root is not None:
Expand Down Expand Up @@ -358,6 +374,9 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
return connection

def chained_resolver(self, resolver, is_partial, root, info, **args):
for key, value in dict(args).items():
if value is None:
del args[key]
required_fields = list()
for field in self.required_fields:
if field in self.model._fields_ordered:
Expand All @@ -372,10 +391,11 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
for arg_name, arg in args.copy().items():
if arg_name not in self.model._fields_ordered + tuple(self.filter_args.keys()):
args_copy.pop(arg_name)
if isinstance(info, ResolveInfo):
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info.context = Context()
info = info._replace(context=Context())
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args_copy)

# XXX: Filter nested args
resolved = resolver(root, info, **args)
if resolved is not None:
Expand All @@ -394,9 +414,6 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
self.filter_args.keys()):
args_copy.pop(arg_name)
if arg_name == '_id' and isinstance(arg, dict):
args_copy['pk__in'] = arg['$in']
elif "$ne" in arg:
args_copy['pk__ne'] = arg['$ne']
operation = list(arg.keys())[0]
args_copy['pk' + operation.replace('$', '__')] = arg[operation]
if '.' in arg_name:
Expand All @@ -409,6 +426,8 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
args_copy[arg_name + operation.replace('$', '__')] = arg[operation]
del args_copy[arg_name]
return self.default_resolver(root, info, required_fields, **args_copy)
elif isinstance(resolved, Promise):
return resolved.value
else:
return resolved
return self.default_resolver(root, info, required_fields, **args)
Expand Down
5 changes: 2 additions & 3 deletions graphene_mongo/tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@ def mutate(self, info, article):
return CreateArticle(article=article)

class Query(graphene.ObjectType):

node = Node.Field()

class Mutation(graphene.ObjectType):

create_article = CreateArticle.Field()

query = """
Expand Down Expand Up @@ -57,7 +55,8 @@ class Arguments:
def mutate(self, info, id, editor):
editor_to_update = Editor.objects.get(id=id)
for key, value in editor.items():
setattr(editor_to_update, key, value)
if value:
setattr(editor_to_update, key, value)
editor_to_update.save()
return UpdateEditor(editor=editor_to_update)

Expand Down
2 changes: 1 addition & 1 deletion graphene_mongo/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def resolve_editors(self, *args, **kwargs):
"chunkSize": 261120,
"length": 46928,
"md5": "f3c657fd472fdc4bc2ca9056a1ae6106",
"data": str(data),
"data": data.decode("utf-8"),
},
},
"editors": [
Expand Down
8 changes: 1 addition & 7 deletions graphene_mongo/tests/test_relay_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class Query(graphene.ObjectType):
"avatar": {
"contentType": "image/jpeg",
"length": 46928,
"data": str(data),
"data": data.decode("utf-8"),
},
}
},
Expand Down Expand Up @@ -489,7 +489,6 @@ class Query(graphene.ObjectType):

def test_should_first_n(fixtures):
class Query(graphene.ObjectType):

editors = MongoengineConnectionField(nodes.EditorNode)

query = """
Expand Down Expand Up @@ -533,7 +532,6 @@ class Query(graphene.ObjectType):

def test_should_after(fixtures):
class Query(graphene.ObjectType):

players = MongoengineConnectionField(nodes.PlayerNode)

query = """
Expand Down Expand Up @@ -566,7 +564,6 @@ class Query(graphene.ObjectType):

def test_should_before(fixtures):
class Query(graphene.ObjectType):

players = MongoengineConnectionField(nodes.PlayerNode)

query = """
Expand Down Expand Up @@ -632,7 +629,6 @@ class Query(graphene.ObjectType):

def test_should_self_reference(fixtures):
class Query(graphene.ObjectType):

players = MongoengineConnectionField(nodes.PlayerNode)

query = """
Expand Down Expand Up @@ -767,7 +763,6 @@ class Query(graphene.ObjectType):

def test_should_query_with_embedded_document(fixtures):
class Query(graphene.ObjectType):

professors = MongoengineConnectionField(nodes.ProfessorVectorNode)

query = """
Expand Down Expand Up @@ -1026,7 +1021,6 @@ class Query(graphene.ObjectType):


def test_should_filter_mongoengine_queryset_by_id_and_other_fields(fixtures):

class Query(graphene.ObjectType):
players = MongoengineConnectionField(nodes.PlayerNode)

Expand Down
44 changes: 33 additions & 11 deletions graphene_mongo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import mongoengine
from graphene import Node
from graphene.utils.trim_docstring import trim_docstring
from graphql.utils.ast_to_dict import ast_to_dict
# from graphql.utils.ast_to_dict import ast_to_dict
from graphql import FieldNode
from graphql_relay.connection.arrayconnection import offset_to_cursor


Expand Down Expand Up @@ -126,21 +127,24 @@ def collect_query_fields(node, fragments):
"""

field = {}

if node.get('selection_set'):
for leaf in node['selection_set']['selections']:
if leaf['kind'] == 'Field':
selection_set = None
if type(node) == dict:
selection_set = node.get('selection_set')
else:
selection_set = node.selection_set
if selection_set:
for leaf in selection_set.selections:
if leaf.kind == 'field':
field.update({
leaf['name']['value']: collect_query_fields(leaf, fragments)
leaf.name.value: collect_query_fields(leaf, fragments)
})
elif leaf['kind'] == 'FragmentSpread':
elif leaf.kind == 'fragment_spread':
field.update(collect_query_fields(fragments[leaf['name']['value']],
fragments))
elif leaf['kind'] == 'InlineFragment':
elif leaf.kind == 'inline_fragment':
field.update({
leaf["type_condition"]["name"]['value']: collect_query_fields(leaf, fragments)
leaf.type_condition.name.value: collect_query_fields(leaf, fragments)
})
pass

return field

Expand All @@ -156,7 +160,7 @@ def get_query_fields(info):
"""

fragments = {}
node = ast_to_dict(info.field_asts[0])
node = ast_to_dict(info.field_nodes[0])

for name, value in info.fragments.items():
fragments[name] = ast_to_dict(value)
Expand All @@ -167,6 +171,24 @@ def get_query_fields(info):
return query


def ast_to_dict(node, include_loc=False):
if isinstance(node, FieldNode):
d = {"kind": node.__class__.__name__}
if hasattr(node, "keys"):
for field in node.keys:
d[field] = ast_to_dict(getattr(node, field), include_loc)

if include_loc and hasattr(node, "loc") and node.loc:
d["loc"] = {"start": node.loc.start, "end": node.loc.end}

return d

elif isinstance(node, list):
return [ast_to_dict(item, include_loc) for item in node]

return node


def find_skip_and_limit(first, last, after, before, count):
reverse = False
skip = 0
Expand Down
17 changes: 8 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
coveralls==1.11.1; python_version <= '3.5'
coveralls==2.1.2; python_version > '3.5'
coveralls==3.0.1; python_version > '3.5'
flake8==3.7.9
flake8-per-file-ignores==0.6
future==0.18.2
graphene>=2.1.8,<3
iso8601==0.1.13
graphene==3.0b7
promise==2.3
mongoengine==0.19.1; python_version <= '3.5'
mongoengine==0.22.1; python_version > '3.5'
mongomock==3.22.0
pymongo==3.11.2
mongoengine==0.23.0; python_version > '3.5'
mongomock==3.22.1
pytest==4.6.11; python_version <= '3.5'
pytest==6.2.1; python_version > '3.5'
pytest==6.2.2; python_version > '3.5'
pytest-cov==2.8.1; python_version == '3.5' or python_version == '3.4'
pytest-cov==2.10.1; python_version < '3.4' or python_version > '3.5'
singledispatch==3.4.0.3
pytest-cov==2.11.1; python_version < '3.4' or python_version > '3.5'
singledispatch==3.6.1
# https://stackoverflow.com/a/58189684/9041712
attrs==20.2.0
futures; python_version < '3.0'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
packages=find_packages(exclude=["tests"]),
install_requires=[
"graphene>=2.1.3,<3",
"mongoengine>=0.15.0",
"mongoengine>=0.23.0",
"singledispatch>=3.4.0.3",
"iso8601>=0.1.12",
'futures; python_version < "3.0"'
Expand Down

0 comments on commit 3f2c4c8

Please sign in to comment.