Skip to content

Commit

Permalink
Subscription revamp (#1235)
Browse files Browse the repository at this point in the history
* Integrate async tests into main code

* Added full support for subscriptions

* Fixed syntax using black

* Fixed typo
  • Loading branch information
syrusakbary committed Jul 28, 2020
1 parent 2130005 commit d085c88
Show file tree
Hide file tree
Showing 11 changed files with 140 additions and 64 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ install-dev:
pip install -e ".[dev]"

test:
py.test graphene examples tests_asyncio
py.test graphene examples

.PHONY: docs ## Generate docs
docs: install-dev
Expand All @@ -20,8 +20,8 @@ docs-live: install-dev

.PHONY: format
format:
black graphene examples setup.py tests_asyncio
black graphene examples setup.py

.PHONY: lint
lint:
flake8 graphene examples setup.py tests_asyncio
flake8 graphene examples setup.py
4 changes: 2 additions & 2 deletions graphene/relay/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ def connection_resolver(cls, resolver, connection_type, root, info, **args):
on_resolve = partial(cls.resolve_connection, connection_type, args)
return maybe_thenable(resolved, on_resolve)

def get_resolver(self, parent_resolver):
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)
def wrap_resolve(self, parent_resolver):
resolver = super(IterableConnectionField, self).wrap_resolve(parent_resolver)
return partial(self.connection_resolver, resolver, self.type)


Expand Down
4 changes: 2 additions & 2 deletions graphene/relay/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def id_resolver(parent_resolver, node, root, info, parent_type_name=None, **args
parent_type_name = parent_type_name or info.parent_type.name
return node.to_global_id(parent_type_name, type_id) # root._meta.name

def get_resolver(self, parent_resolver):
def wrap_resolve(self, parent_resolver):
return partial(
self.id_resolver,
parent_resolver,
Expand All @@ -60,7 +60,7 @@ def __init__(self, node, type_=False, **kwargs):
**kwargs,
)

def get_resolver(self, parent_resolver):
def wrap_resolve(self, parent_resolver):
return partial(self.node_type.node_resolver, get_type(self.field_type))


Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions graphene/relay/tests/test_global_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ def test_global_id_allows_overriding_of_node_and_required():
def test_global_id_defaults_to_info_parent_type():
my_id = "1"
gid = GlobalID()
id_resolver = gid.get_resolver(lambda *_: my_id)
id_resolver = gid.wrap_resolve(lambda *_: my_id)
my_global_id = id_resolver(None, Info(User))
assert my_global_id == to_global_id(User._meta.name, my_id)


def test_global_id_allows_setting_customer_parent_type():
my_id = "1"
gid = GlobalID(parent_type=User)
id_resolver = gid.get_resolver(lambda *_: my_id)
id_resolver = gid.wrap_resolve(lambda *_: my_id)
my_global_id = id_resolver(None, None)
assert my_global_id == to_global_id(User._meta.name, my_id)
File renamed without changes.
22 changes: 21 additions & 1 deletion graphene/types/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .structures import NonNull
from .unmountedtype import UnmountedType
from .utils import get_type
from ..utils.deprecated import warn_deprecation

base_type = type

Expand Down Expand Up @@ -114,5 +115,24 @@ def __init__(
def type(self):
return get_type(self._type)

def get_resolver(self, parent_resolver):
get_resolver = None

def wrap_resolve(self, parent_resolver):
"""
Wraps a function resolver, using the ObjectType resolve_{FIELD_NAME}
(parent_resolver) if the Field definition has no resolver.
"""
if self.get_resolver is not None:
warn_deprecation(
"The get_resolver method is being deprecated, please rename it to wrap_resolve."
)
return self.get_resolver(parent_resolver)

return self.resolver or parent_resolver

def wrap_subscribe(self, parent_subscribe):
"""
Wraps a function subscribe, using the ObjectType subscribe_{FIELD_NAME}
(parent_subscribe) if the Field definition has no subscribe.
"""
return parent_subscribe
73 changes: 53 additions & 20 deletions graphene/types/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
parse,
print_schema,
subscribe,
validate,
ExecutionResult,
GraphQLArgument,
GraphQLBoolean,
GraphQLError,
GraphQLEnumValue,
GraphQLField,
GraphQLFloat,
Expand Down Expand Up @@ -76,6 +79,11 @@ def is_type_of_from_possible_types(possible_types, root, _info):
return isinstance(root, possible_types)


# We use this resolver for subscriptions
def identity_resolve(root, info):
return root


class TypeMap(dict):
def __init__(
self,
Expand Down Expand Up @@ -307,30 +315,48 @@ def create_fields_for_type(self, graphene_type, is_input_type=False):
if isinstance(arg.type, NonNull)
else arg.default_value,
)
subscribe = field.wrap_subscribe(
self.get_function_for_type(
graphene_type, f"subscribe_{name}", name, field.default_value,
)
)

# If we are in a subscription, we use (by default) an
# identity-based resolver for the root, rather than the
# default resolver for objects/dicts.
if subscribe:
field_default_resolver = identity_resolve
elif issubclass(graphene_type, ObjectType):
default_resolver = (
graphene_type._meta.default_resolver or get_default_resolver()
)
field_default_resolver = partial(
default_resolver, name, field.default_value
)
else:
field_default_resolver = None

resolve = field.wrap_resolve(
self.get_function_for_type(
graphene_type, f"resolve_{name}", name, field.default_value
)
or field_default_resolver
)

_field = GraphQLField(
field_type,
args=args,
resolve=field.get_resolver(
self.get_resolver_for_type(
graphene_type, f"resolve_{name}", name, field.default_value
)
),
subscribe=field.get_resolver(
self.get_resolver_for_type(
graphene_type,
f"subscribe_{name}",
name,
field.default_value,
)
),
resolve=resolve,
subscribe=subscribe,
deprecation_reason=field.deprecation_reason,
description=field.description,
)
field_name = field.name or self.get_name(name)
fields[field_name] = _field
return fields

def get_resolver_for_type(self, graphene_type, func_name, name, default_value):
def get_function_for_type(self, graphene_type, func_name, name, default_value):
"""Gets a resolve or subscribe function for a given ObjectType"""
if not issubclass(graphene_type, ObjectType):
return
resolver = getattr(graphene_type, func_name, None)
Expand All @@ -350,11 +376,6 @@ def get_resolver_for_type(self, graphene_type, func_name, name, default_value):
if resolver:
return get_unbound_function(resolver)

default_resolver = (
graphene_type._meta.default_resolver or get_default_resolver()
)
return partial(default_resolver, name, default_value)

def resolve_type(self, resolve_type_func, type_name, root, info, _type):
type_ = resolve_type_func(root, info)

Expand Down Expand Up @@ -476,7 +497,19 @@ async def execute_async(self, *args, **kwargs):
return await graphql(self.graphql_schema, *args, **kwargs)

async def subscribe(self, query, *args, **kwargs):
document = parse(query)
"""Execute a GraphQL subscription on the schema asynchronously."""
# Do parsing
try:
document = parse(query)
except GraphQLError as error:
return ExecutionResult(data=None, errors=[error])

# Do validation
validation_errors = validate(self.graphql_schema, document)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)

# Execute the query
kwargs = normalize_execute_kwargs(kwargs)
return await subscribe(self.graphql_schema, document, *args, **kwargs)

Expand Down
56 changes: 56 additions & 0 deletions graphene/types/tests/test_subscribe_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from pytest import mark

from graphene import ObjectType, Int, String, Schema, Field


class Query(ObjectType):
hello = String()

def resolve_hello(root, info):
return "Hello, world!"


class Subscription(ObjectType):
count_to_ten = Field(Int)

async def subscribe_count_to_ten(root, info):
count = 0
while count < 10:
count += 1
yield count


schema = Schema(query=Query, subscription=Subscription)


@mark.asyncio
async def test_subscription():
subscription = "subscription { countToTen }"
result = await schema.subscribe(subscription)
count = 0
async for item in result:
count = item.data["countToTen"]
assert count == 10


@mark.asyncio
async def test_subscription_fails_with_invalid_query():
# It fails if the provided query is invalid
subscription = "subscription { "
result = await schema.subscribe(subscription)
assert not result.data
assert result.errors
assert "Syntax Error: Expected Name, found <EOF>" in str(result.errors[0])


@mark.asyncio
async def test_subscription_fails_when_query_is_not_valid():
# It can't subscribe to two fields at the same time, triggering a
# validation error.
subscription = "subscription { countToTen, b: countToTen }"
result = await schema.subscribe(subscription)
assert not result.data
assert result.errors
assert "Anonymous Subscription must select only one top level field." in str(
result.errors[0]
)
33 changes: 0 additions & 33 deletions tests_asyncio/test_subscribe.py

This file was deleted.

2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ deps =
setenv =
PYTHONPATH = .:{envdir}
commands =
py{36,37}: pytest --cov=graphene graphene examples tests_asyncio {posargs}
py{36,37}: pytest --cov=graphene graphene examples {posargs}

[testenv:pre-commit]
basepython=python3.7
Expand Down

0 comments on commit d085c88

Please sign in to comment.