Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: subscription returns incorrect result when data loader is used during fields resolving. #287

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions graphql/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,16 +428,23 @@ def subscribe_field(
)
)

return result.map(
functools.partial(
complete_value_catching_error,
exe_context,
return_type,
field_asts,
info,
path,
)
)
def complete_subscription_result(result):
def promise_executor(v):
return complete_value_catching_error(
exe_context,
return_type,
field_asts,
info,
path,
result,
)

promise = Promise.resolve(None).then(promise_executor)
exe_context.executor.wait_until_finished()

return promise.get()

return result.map(complete_subscription_result)


def resolve_or_error(
Expand Down
137 changes: 137 additions & 0 deletions graphql/execution/tests/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# type: ignore
from collections import namedtuple

import pytest
from promise import Promise
from promise.dataloader import DataLoader
from rx.subjects import Subject

from graphql import (
GraphQLObjectType,
Expand All @@ -10,6 +13,9 @@
GraphQLArgument,
GraphQLNonNull,
GraphQLSchema,
GraphQLString,
GraphQLList,
GraphQLInt,
parse,
execute,
)
Expand Down Expand Up @@ -169,3 +175,134 @@ class Context(object):
}
assert business_load_calls == [["1", "2"]]
assert location_load_calls == [["location-1", "location-2"]]


@pytest.mark.parametrize(
"executor",
[
SyncExecutor(),
# ThreadExecutor(),
],
)
def test_batches_subscription_result(executor):
# type: (SyncExecutor) -> None
Tag = namedtuple("Tag", "id,name")
Post = namedtuple("Post", "id,tag_id")

tags = {
1: Tag(id=1, name="#music"),
2: Tag(id=2, name="#beautiful"),
}

TagType = GraphQLObjectType(
"Tag",
lambda: {
"id": GraphQLField(GraphQLInt),
"name": GraphQLField(GraphQLString),
},
)

PostType = GraphQLObjectType(
"Post",
lambda: {
"id": GraphQLField(GraphQLInt),
"tag": GraphQLField(
TagType,
resolver=lambda root, info: info.context.tags_data_loader.load(
root.tag_id
),
),
},
)

new_posts_in_stream = Subject()

Subscription = GraphQLObjectType(
"Subscription",
lambda: {
"newPosts": GraphQLField(
GraphQLList(PostType),
resolver=lambda root, info: new_posts_in_stream,
),
},
)

schema = GraphQLSchema(
query=GraphQLObjectType(
"Query",
lambda: {"posts": GraphQLField(GraphQLList(PostType))},
),
subscription=Subscription,
)

doc = """
subscription {
newPosts {
id
tag {
id
name
}
}
}
"""
doc_ast = parse(doc)

load_calls = []

class TagsDataLoader(DataLoader):
def batch_load_fn(self, keys):
# type: (List[str]) -> Promise
load_calls.append(keys)
return Promise.resolve([tags[key] for key in keys])

class Context(object):
tags_data_loader = TagsDataLoader()

new_posts_out_stream = execute(
schema,
doc_ast,
None,
context_value=Context(),
allow_subscriptions=True,
executor=executor,
)

subscription_results = []
new_posts_out_stream.subscribe(subscription_results.append)

def create_new_posts(posts):
Context.tags_data_loader.clear_all()
new_posts_in_stream.on_next(posts)

create_new_posts(
[
Post(id=1, tag_id=1),
Post(id=2, tag_id=2),
Post(id=3, tag_id=1),
]
)
create_new_posts(
[
Post(id=4, tag_id=1),
Post(id=5, tag_id=1),
]
)

expected_data_1 = {
"newPosts": [
{"id": 1, "tag": {"id": 1, "name": "#music"}},
{"id": 2, "tag": {"id": 2, "name": "#beautiful"}},
{"id": 3, "tag": {"id": 1, "name": "#music"}},
]
}
expected_data_2 = {
"newPosts": [
{"id": 4, "tag": {"id": 1, "name": "#music"}},
{"id": 5, "tag": {"id": 1, "name": "#music"}},
]
}

assert subscription_results[0].data == expected_data_1
assert subscription_results[1].data == expected_data_2
assert load_calls == [[1, 2], [1]]