Skip to content
This repository has been archived by the owner on Sep 6, 2022. It is now read-only.

Commit

Permalink
Externalize GraphQL Global IDs instead of internal NDB keys
Browse files Browse the repository at this point in the history
  • Loading branch information
ekampf committed Jun 14, 2016
1 parent 022baa4 commit 3dfa734
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 11 deletions.
4 changes: 4 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
History
-------

0.1.7 (TBD)
---------------------


0.1.6 (2016-06-10)
---------------------
* Changing development status to Beta
Expand Down
2 changes: 1 addition & 1 deletion graphene_gae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)

__author__ = 'Eran Kampf'
__version__ = '0.1.6'
__version__ = '0.1.7'

__all__ = [
NdbObjectType,
Expand Down
36 changes: 34 additions & 2 deletions graphene_gae/ndb/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from graphene.core.exceptions import SkipField
from graphene.core.types.base import FieldType
from graphene.core.types.scalars import Boolean, Int, String
from graphql_relay import to_global_id

__author__ = 'ekampf'

Expand Down Expand Up @@ -101,11 +102,42 @@ def __init__(self, name, *args, **kwargs):
def default_resolver(self, node, args, info):
entity = node.instance
key = getattr(entity, self.name)
if not key:
return None

if isinstance(key, list):
return [k.urlsafe() for k in key]
t = self.__get_key_internal_type(key[0], info.schema.graphene_schema)
return [to_global_id(t.name, k.urlsafe()) for k in key]

return key.urlsafe() if key else None
t = self.__get_key_internal_type(key, info.schema.graphene_schema)
return to_global_id(t.name, key.urlsafe()) if key else None

def __get_key_internal_type(self, key, schema):
_type = self.__find_key_object_type(key, schema)
if not _type and self.parent._meta.only_fields:
raise Exception(
"Model %r is not accessible by the schema. "
"You can either register the type manually "
"using @schema.register. "
"Or disable the field in %s" % (
key.kind(),
self.parent,
)
)

if not _type:
raise SkipField()

return schema.T(_type)

def __find_key_object_type(self, key, schema):
for _type in schema.types.values():
type_model = hasattr(_type, '_meta') and getattr(_type._meta, 'model', None)
if not type_model:
continue

if key.kind() == type_model or key.kind() == type_model.__name__:
return _type


class NdbKeyField(FieldType):
Expand Down
14 changes: 7 additions & 7 deletions tests/_ndb/test_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from graphql_relay import to_global_id
from tests.base_test import BaseTest

import graphene
Expand Down Expand Up @@ -280,7 +281,7 @@ def testQuery_keyProperty(self):
query ArticleWithAuthorID {
articles {
headline
authorKey
authorId
author {
name, email
}
Expand All @@ -293,7 +294,8 @@ def testQuery_keyProperty(self):
article = dict(result.data['articles'][0])
author = dict(article['author'])
self.assertDictEqual(author, {'name': u'john dow', 'email': u'john@dow.com'})
self.assertDictContainsSubset(dict(headline='h1', authorKey=author_key.urlsafe()), article)
self.assertEqual('h1', article['headline'])
self.assertEqual(to_global_id('AuthorType', author_key.urlsafe()), article['authorId'])

def testQuery_repeatedKeyProperty(self):
tk1 = Tag(name="t1").put()
Expand All @@ -302,14 +304,12 @@ def testQuery_repeatedKeyProperty(self):
tk4 = Tag(name="t4").put()
Article(headline="h1", summary="s1", tags=[tk1, tk2, tk3, tk4]).put()

print str(schema)

result = schema.execute('''
query ArticleWithAuthorID {
articles {
headline
authorKey
tagKeys
authorId
tagIds
tags {
name
}
Expand All @@ -320,7 +320,7 @@ def testQuery_repeatedKeyProperty(self):
self.assertEmpty(result.errors)

article = dict(result.data['articles'][0])
self.assertListEqual(map(lambda k: k.urlsafe(), [tk1, tk2, tk3, tk4]), article['tagKeys'])
self.assertListEqual(map(lambda k: to_global_id('TagType', k.urlsafe()), [tk1, tk2, tk3, tk4]), article['tagIds'])

self.assertLength(article['tags'], 4)
for i in range(0, 3):
Expand Down
2 changes: 1 addition & 1 deletion tests/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_filtered_tasks(self, url=None, name=None, queue_names=None):

# region Extra Assertions
def assertEmpty(self, l, msg=None):
self.assertEqual(0, len(list(l)), msg=msg)
self.assertEqual(0, len(list(l)), msg=msg or str(l))

def assertLength(self, l, expectation, msg=None):
self.assertEqual(len(l), expectation, msg)
Expand Down

0 comments on commit 3dfa734

Please sign in to comment.