Skip to content

Commit d022e66

Browse files
committed
Modify pubsub/sub_server classes for other executors
1 parent e8ebc23 commit d022e66

File tree

6 files changed

+81
-56
lines changed

6 files changed

+81
-56
lines changed

graphql_subscriptions/executors/asyncio.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# source: https://github.com/graphql-python/graphql-core/blob/master/graphql/execution/executors/asyncio.py
21
from __future__ import absolute_import
32

43
import asyncio
@@ -7,6 +6,7 @@
76
from asyncio import ensure_future
87
except ImportError:
98
# ensure_future is only implemented in Python 3.4.4+
9+
# Reference: https://github.com/graphql-python/graphql-core/blob/master/graphql/execution/executors/asyncio.py
1010
def ensure_future(coro_or_future, loop=None):
1111
"""Wrap a coroutine or an awaitable in a future.
1212
If the argument is a Future, it is returned directly.
@@ -23,8 +23,8 @@ def ensure_future(coro_or_future, loop=None):
2323
del task._source_traceback[-1]
2424
return task
2525
else:
26-
raise TypeError('A Future, a coroutine or an awaitable is\
27-
required')
26+
raise TypeError(
27+
'A Future, a coroutine or an awaitable is required')
2828

2929

3030
class AsyncioMixin(object):
@@ -44,21 +44,31 @@ def __init__(self, loop=None):
4444
def sleep(time):
4545
yield from asyncio.sleep(time)
4646

47+
@staticmethod
48+
@asyncio.coroutine
49+
def timer(callback, period):
50+
while True:
51+
callback()
52+
yield from asyncio.sleep(period)
53+
4754
@staticmethod
4855
def kill(future):
4956
future.cancel()
5057

51-
@staticmethod
52-
def join(future):
58+
def join(self, future):
5359
self.loop.run_until_complete(asyncio.wait_for(future))
5460

5561
def join_all(self):
56-
futures = self.futures
57-
self.futures = []
58-
self.loop.run_until_complete(asyncio.wait(futures))
62+
while self.futures:
63+
futures = self.futures
64+
self.futures = []
65+
self.loop.run_until_complete(asyncio.wait(futures))
5966
return futures
6067

6168
def execute(self, fn, *args, **kwargs):
62-
future = ensure_future(result, loop=self.loop)
63-
self.futures.append(future)
64-
return future
69+
result = fn(*args, **kwargs)
70+
if isinstance(result, asyncio.Future) or asyncio.iscoroutine(result):
71+
future = ensure_future(result, loop=self.loop)
72+
self.futures.append(future)
73+
return future
74+
return result

graphql_subscriptions/executors/gevent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ def __init__(self):
1919
def sleep(time):
2020
gevent.sleep(time)
2121

22+
@staticmethod
23+
def timer(callback, period):
24+
while True:
25+
callback()
26+
gevent.sleep(period)
27+
2228
@staticmethod
2329
def kill(greenlet):
2430
greenlet.kill()

graphql_subscriptions/subscription_manager/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, schema, pubsub, setup_funcs={}):
2323
self.pubsub = pubsub
2424
self.setup_funcs = setup_funcs
2525
self.subscriptions = {}
26-
self.max_subscription_id = 0
26+
self.max_subscription_id = 1
2727

2828
def publish(self, trigger_name, payload):
2929
self.pubsub.publish(trigger_name, payload)

graphql_subscriptions/subscription_manager/pubsub.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import redis
88

99
from ..executors.gevent import GeventExecutor
10+
from ..executors.asyncio import AsyncioExecutor
1011

1112

1213
class RedisPubsub(object):
@@ -17,44 +18,57 @@ def __init__(self,
1718
*args,
1819
**kwargs):
1920

20-
if hasattr(executor, 'socket'):
21-
redis.connection.socket = executor.socket
22-
self.coro = None
21+
if executor == AsyncioExecutor:
22+
try:
23+
import aredis
24+
except ImportError:
25+
print('You need the redis_client "aredis" for use w/ asyncio')
26+
redis_client = aredis
27+
else:
28+
redis_client = redis
29+
30+
if executor == GeventExecutor:
31+
redis_client.connection.socket = executor.socket
32+
2333
self.executor = executor()
34+
self.get_message_task = None
2435

25-
self.redis = redis.StrictRedis(host, port, *args, **kwargs)
26-
self.pubsub = self.redis.pubsub()
2736
self.subscriptions = {}
28-
self.sub_id_counter = 0
37+
self.sub_id_counter = 1
38+
39+
self.redis = redis_client.StrictRedis(host, port, *args, **kwargs)
40+
self.pubsub = self.redis.pubsub()
2941

3042
def publish(self, trigger_name, message):
31-
self.redis.publish(trigger_name, pickle.dumps(message))
43+
self.executor.execute(
44+
self.redis.publish, trigger_name, pickle.dumps(message))
3245
return True
3346

3447
def subscribe(self, trigger_name, on_message_handler, options):
3548
self.sub_id_counter += 1
3649
try:
3750
if trigger_name not in list(self.subscriptions.values())[0]:
38-
self.pubsub.subscribe(trigger_name)
51+
self.executor.execute(self.pubsub.subscribe, trigger_name)
3952
except IndexError:
40-
self.pubsub.subscribe(trigger_name)
53+
self.executor.execute(self.pubsub.subscribe, trigger_name)
4154
self.subscriptions[self.sub_id_counter] = [
4255
trigger_name, on_message_handler
4356
]
44-
if not self.coro:
45-
self.coro = self.executor.execute(self.wait_and_get_message)
57+
if not self.get_message_task:
58+
self.get_message_task = self.executor.execute(
59+
self.wait_and_get_message)
4660
return Promise.resolve(self.sub_id_counter)
4761

4862
def unsubscribe(self, sub_id):
4963
trigger_name, on_message_handler = self.subscriptions[sub_id]
5064
del self.subscriptions[sub_id]
5165
try:
5266
if trigger_name not in list(self.subscriptions.values())[0]:
53-
self.pubsub.unsubscribe(trigger_name)
67+
self.executor.execute(self.pubsub.unsubscribe, trigger_name)
5468
except IndexError:
55-
self.pubsub.unsubscribe(trigger_name)
69+
self.executor.execute(self.pubsub.unsubscribe, trigger_name)
5670
if not self.subscriptions:
57-
self.coro = self.executor.kill(self.coro)
71+
self.get_message_task = self.executor.kill(self.get_message_task)
5872

5973
def wait_and_get_message(self):
6074
while True:

graphql_subscriptions/subscription_transport_ws/base.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from promise import Promise
33
import json
44

5-
from ..executors.gevent import GeventExecutor
65
from .protocols import (SUBSCRIPTION_FAIL, SUBSCRIPTION_END, SUBSCRIPTION_DATA,
76
SUBSCRIPTION_START, SUBSCRIPTION_SUCCESS, KEEPALIVE,
87
INIT, INIT_SUCCESS, INIT_FAIL, GRAPHQL_SUBSCRIPTIONS)
@@ -38,11 +37,6 @@ def __init__(self,
3837

3938
super(BaseSubscriptionServer, self).__init__(websocket)
4039

41-
def timer(self, callback, period):
42-
while True:
43-
callback()
44-
self.executor.sleep(period)
45-
4640
def unsubscribe(self, graphql_sub_id):
4741
self.subscription_manager.unsubscribe(graphql_sub_id)
4842

@@ -62,7 +56,7 @@ def keep_alive_callback():
6256

6357
if self.keep_alive:
6458
keep_alive_timer = self.executor.execute(
65-
self.timer,
59+
self.executor.timer,
6660
keep_alive_callback,
6761
self.keep_alive)
6862

@@ -252,21 +246,22 @@ def subscription_end_promise_handler(result):
252246

253247
def send_subscription_data(self, sub_id, payload):
254248
message = {'type': SUBSCRIPTION_DATA, 'id': sub_id, 'payload': payload}
255-
self.ws.send(json.dumps(message))
249+
self.executor.execute(self.ws.send, json.dumps(message))
256250

257251
def send_subscription_fail(self, sub_id, payload):
258252
message = {'type': SUBSCRIPTION_FAIL, 'id': sub_id, 'payload': payload}
259-
self.ws.send(json.dumps(message))
253+
self.executor.execute(self.ws.send, json.dumps(message))
254+
# self.ws.send(json.dumps(message))
260255

261256
def send_subscription_success(self, sub_id):
262257
message = {'type': SUBSCRIPTION_SUCCESS, 'id': sub_id}
263-
self.ws.send(json.dumps(message))
258+
self.executor.execute(self.ws.send, json.dumps(message))
264259

265260
def send_init_result(self, result):
266-
self.ws.send(json.dumps(result))
261+
self.executor.execute(self.ws.send, json.dumps(result))
267262
if result.get('type') == INIT_FAIL:
268263
self.ws.close(1011)
269264

270265
def send_keep_alive(self):
271266
message = {'type': KEEPALIVE}
272-
self.ws.send(json.dumps(message))
267+
self.executor.execute(self.ws.send, json.dumps(message))

tests/test_subscription_manager.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ def test_pubsub_subscribe_and_publish(pubsub, executor, test_input, expected):
3131
def message_callback(message):
3232
try:
3333
assert message == expected
34-
executor.kill(pubsub.coro)
34+
executor.kill(pubsub.get_message_task)
3535
except AssertionError as e:
3636
sys.exit(e)
3737

3838
def publish_callback(sub_id):
3939
assert pubsub.publish('a', test_input)
40-
executor.join(pubsub.coro)
40+
executor.join(pubsub.get_message_task)
4141

4242
p1 = pubsub.subscribe('a', message_callback, {})
4343
p2 = p1.then(publish_callback)
@@ -52,7 +52,7 @@ def unsubscribe_publish_callback(sub_id):
5252
pubsub.unsubscribe(sub_id)
5353
assert pubsub.publish('a', 'test')
5454
try:
55-
executor.join(pubsub.coro)
55+
executor.join(pubsub.get_message_task)
5656
except AttributeError:
5757
return
5858

@@ -190,13 +190,13 @@ def test_subscribe_with_valid_query_and_return_root_value(sub_mgr, executor):
190190
def callback(e, payload):
191191
try:
192192
assert payload.data.get('testSubscription') == 'good'
193-
executor.kill(sub_mgr.pubsub.coro)
193+
executor.kill(sub_mgr.pubsub.get_message_task)
194194
except AssertionError as e:
195195
sys.exit(e)
196196

197197
def publish_and_unsubscribe_handler(sub_id):
198198
sub_mgr.publish('testSubscription', 'good')
199-
executor.join(sub_mgr.pubsub.coro)
199+
executor.join(sub_mgr.pubsub.get_message_task)
200200
sub_mgr.unsubscribe(sub_id)
201201

202202
p1 = sub_mgr.subscribe(query, 'X', callback, {}, {}, None, None)
@@ -217,14 +217,14 @@ def callback(err, payload):
217217
assert True
218218
else:
219219
assert payload.data.get('testFilter') == 'good_filter'
220-
executor.kill(sub_mgr.pubsub.coro)
220+
executor.kill(sub_mgr.pubsub.get_message_task)
221221
except AssertionError as e:
222222
sys.exit(e)
223223

224224
def publish_and_unsubscribe_handler(sub_id):
225225
sub_mgr.publish('filter_1', {'filterBoolean': False})
226226
sub_mgr.publish('filter_1', {'filterBoolean': True})
227-
executor.join(sub_mgr.pubsub.coro)
227+
executor.join(sub_mgr.pubsub.get_message_task)
228228
sub_mgr.unsubscribe(sub_id)
229229

230230
p1 = sub_mgr.subscribe(query, 'Filter1', callback, {'filterBoolean': True},
@@ -246,15 +246,15 @@ def callback(err, payload):
246246
assert True
247247
else:
248248
assert payload.data.get('testFilter') == 'good_filter'
249-
executor.kill(sub_mgr.pubsub.coro)
249+
executor.kill(sub_mgr.pubsub.get_message_task)
250250
except AssertionError as e:
251251
sys.exit(e)
252252

253253
def publish_and_unsubscribe_handler(sub_id):
254254
sub_mgr.publish('filter_2', {'filterBoolean': False})
255255
sub_mgr.publish('filter_2', {'filterBoolean': True})
256256
try:
257-
executor.join(sub_mgr.pubsub.coro)
257+
executor.join(sub_mgr.pubsub.get_message_task)
258258
except:
259259
raise
260260
sub_mgr.unsubscribe(sub_id)
@@ -285,13 +285,13 @@ def callback(err, payload):
285285
except AssertionError as e:
286286
sys.exit(e)
287287
if non_local['trigger_count'] == 2:
288-
executor.kill(sub_mgr.pubsub.coro)
288+
executor.kill(sub_mgr.pubsub.get_message_task)
289289

290290
def publish_and_unsubscribe_handler(sub_id):
291291
sub_mgr.publish('not_a_trigger', {'filterBoolean': False})
292292
sub_mgr.publish('trigger_1', {'filterBoolean': True})
293293
sub_mgr.publish('trigger_2', {'filterBoolean': True})
294-
executor.join(sub_mgr.pubsub.coro)
294+
executor.join(sub_mgr.pubsub.get_message_task)
295295
sub_mgr.unsubscribe(sub_id)
296296

297297
p1 = sub_mgr.subscribe(query, 'multiTrigger', callback,
@@ -342,7 +342,7 @@ def unsubscribe_and_publish_handler(sub_id):
342342
sub_mgr.unsubscribe(sub_id)
343343
sub_mgr.publish('testSubscription', 'good')
344344
try:
345-
executor.join(sub_mgr.pubsub.coro)
345+
executor.join(sub_mgr.pubsub.get_message_task)
346346
except AttributeError:
347347
return
348348

@@ -395,14 +395,14 @@ def callback(err, payload):
395395
assert err.message == 'Variable "$uga" of required type\
396396
"Boolean!" was not provided.'
397397

398-
executor.kill(sub_mgr.pubsub.coro)
398+
executor.kill(sub_mgr.pubsub.get_message_task)
399399
except AssertionError as e:
400400
sys.exit(e)
401401

402402
def unsubscribe_and_publish_handler(sub_id):
403403
sub_mgr.publish('testSubscription', 'good')
404404
try:
405-
executor.join(sub_mgr.pubsub.coro)
405+
executor.join(sub_mgr.pubsub.get_message_task)
406406
except AttributeError:
407407
return
408408
sub_mgr.unsubscribe(sub_id)
@@ -426,14 +426,14 @@ def callback(err, payload):
426426
try:
427427
assert err is None
428428
assert payload.data.get('testContext') == 'trigger'
429-
executor.kill(sub_mgr.pubsub.coro)
429+
executor.kill(sub_mgr.pubsub.get_message_task)
430430
except AssertionError as e:
431431
sys.exit(e)
432432

433433
def unsubscribe_and_publish_handler(sub_id):
434434
sub_mgr.publish('context_trigger', 'ignored')
435435
try:
436-
executor.join(sub_mgr.pubsub.coro)
436+
executor.join(sub_mgr.pubsub.get_message_task)
437437
except AttributeError:
438438
return
439439
sub_mgr.unsubscribe(sub_id)
@@ -458,14 +458,14 @@ def callback(err, payload):
458458
try:
459459
assert payload is None
460460
assert str(err) == 'context error'
461-
executor.kill(sub_mgr.pubsub.coro)
461+
executor.kill(sub_mgr.pubsub.get_message_task)
462462
except AssertionError as e:
463463
sys.exit(e)
464464

465465
def unsubscribe_and_publish_handler(sub_id):
466466
sub_mgr.publish('context_trigger', 'ignored')
467467
try:
468-
executor.join(sub_mgr.pubsub.coro)
468+
executor.join(sub_mgr.pubsub.get_message_task)
469469
except AttributeError:
470470
return
471471
sub_mgr.unsubscribe(sub_id)

0 commit comments

Comments
 (0)