Permalink
Browse files

working on pubsub app

  • Loading branch information...
1 parent deeb138 commit c3c4b63d510eb47d5ca4bcfa44732f72343617f9 quantmind committed Sep 13, 2012
View
84 stdnet/apps/pubsub.py
@@ -1,54 +1,76 @@
+import logging
+
from inspect import isclass
+from collections import deque
from stdnet import getdb, AsyncObject
from stdnet.utils.encoders import Json
from stdnet.utils import is_string
+
+
+logger = logging.getLogger('stdnet.pubsub')
+
+
+class PubSubBase(object):
+ pickler = Json
-
-class Publisher(object):
- '''Class which publish messages to message queues.'''
- def __init__(self, server = None, pickler = Json):
+ def __init__(self, server=None, pickler=None):
+ pickler = pickler or self.pickler
if isclass(pickler):
pickler = pickler()
self.pickler = pickler
- self.client = getdb(server).client
+ self.server = getdb(server)
+
+class Publisher(PubSubBase):
+ '''Class which publish messages to message queues.'''
def publish(self, channel, data):
data = self.pickler.dumps(data)
- #return self.backend.publish(channel, data)
- return self.client.execute_command('PUBLISH', channel, data)
+ return self.server.publish(channel, data)
-class Subscriber(AsyncObject):
- '''Subscribe to '''
- def __init__(self, server = None, pickler = Json):
- if isclass(pickler):
- pickler = pickler()
- self.pickler = pickler
- self.client = getdb(server).subscriber()
+class Subscriber(PubSubBase):
+ '''A subscriber to channels'''
+ def __init__(self, server=None, pickler=None):
+ super(Subscriber, self).__init__(server, pickler)
+ self.channels = {}
+ self.patterns = {}
+ self._subscriber = self.server.subscriber(
+ message_callback=self.message_callback)
def disconnect(self):
- self.client.disconnect()
+ self._subscriber.disconnect()
def subscription_count(self):
- return self.client.subscription_count()
+ return self._subscriber.subscription_count()
- def subscribe(self, channels):
- return self.client.subscribe(channels)
+ def subscribe(self, *channels):
+ return self._subscriber.subscribe(self.channel_list(channels))
- def unsubscribe(self, channels = None):
- return self.client.unsubscribe(channels)
-
- def psubscribe(self, channels):
- return self.client.psubscribe(channels)
+ def unsubscribe(self, *channels):
+ return self._subscriber.unsubscribe(self.channel_list(channels))
- def punsubscribe(self, channels = None):
- return self.client.punsubscribe(channels)
+ def psubscribe(self, *channels):
+ return self._subscriber.psubscribe(self.channel_list(channels))
- def pull(self, timeout = None, count = None):
- '''Retrieve new messages from the subscribed channels.
+ def punsubscribe(self, *channels):
+ return self._subscriber.punsubscribe(self.channel_list(channels))
-:parameter timeout: Optional timeout in seconds.
-:parameter count: Optional number of messages to retrieve.'''
- return self.client.pull(timeout, count, self.pickler.loads)
-
+ def message_callback(self, command, channel, message=None):
+ if command == 'subscribe':
+ self.channels[channel] = deque()
+ elif command == 'unsubscribe':
+ self.channels.pop(channel, None)
+ elif channel in self.channels:
+ self.channels.append(message)
+ else:
+ logger.warn('Got message for unsubscribed channel "%s"' % channel)
+
+ def channel_list(self, channels):
+ ch = []
+ for channel in channels:
+ if not isinstance(channel, (list, tuple)):
+ ch.append(channel)
+ else:
+ ch.extend(channel)
+ return ch
View
39 stdnet/backends/base.py
@@ -203,7 +203,7 @@ class BackendDataServer(object):
structure_module = None
struct_map = {}
- def __init__(self, name, address, pickler = None,
+ def __init__(self, name, address, pickler=None,
charset='utf-8', connection_string='',
prefix=None, **params):
self.__name = name
@@ -233,9 +233,6 @@ def __eq__(self, other):
def issame(self, other):
return self.client == other.client
- def cursor(self, pipelined=False):
- return self
-
def disconnect(self):
'''Disconnect the connection.'''
pass
@@ -244,13 +241,6 @@ def __repr__(self):
return self.connection_string
__str__ = __repr__
- def isempty(self):
- '''Returns ``True`` if the database has no keys.'''
- keys = self.keys()
- if not hasattr(keys,'__len__'):
- keys = list(keys)
- return len(keys)
-
def make_objects(self, meta, data, related_fields=None):
'''Generator of :class:`stdnet.odm.StdModel` instances with data
from database.
@@ -289,8 +279,11 @@ def make_objects(self, meta, data, related_fields=None):
def objects_from_db(self, meta, data, related_fields=None):
return list(self.make_objects(meta, data, related_fields))
- def structure(self, instance, client = None):
- '''Create a backend :class:`stdnet.odm.Structure` handler.'''
+ def structure(self, instance, client=None):
+ '''Create a backend :class:`stdnet.odm.Structure` handler.
+
+:parameter instance: a :class:`stdnet.odm.Structure`
+:parameter client: Optional client handler'''
struct = self.struct_map.get(instance._meta.name)
if struct is None:
raise ModelNotAvailable('structure "{0}" is not available for\
@@ -315,36 +308,40 @@ def basekey(self, meta, *args):
# PURE VIRTUAL METHODS
- def setup_connection(self, address): # pragma: no cover
+ def setup_connection(self, address):
'''Callback during initialization. Implementation should override
this function for customizing their handling of connection parameters. It
must return a instance of the backend handler.'''
raise NotImplementedError()
- def execute_session(self, session, callback): # pragma: no cover
+ def execute_session(self, session, callback):
'''Execute a :class:`stdnet.odm.Session` in the backend server.'''
raise NotImplementedError()
- def model_keys(self, meta): # pragma: no cover
+ def model_keys(self, meta):
'''Return a list of database keys used by model *model*'''
raise NotImplementedError()
- def instance_keys(self, obj): # pragma: no cover
+ def instance_keys(self, obj):
'''Return a list of database keys used by instance *obj*'''
raise NotImplementedError()
- def as_cache(self): # pragma: no cover
+ def as_cache(self):
raise NotImplementedError('This backend cannot be used as cache')
- def clear(self): # pragma: no cover
+ def clear(self):
"""Remove *all* values from the database at once."""
raise NotImplementedError()
- def flush(self, meta=None, pattern=None): # pragma: no cover
+ def flush(self, meta=None, pattern=None):
'''Flush all model keys from the database'''
raise NotImplementedError()
- def subscriber(self): # pragma: no cover
+ def publish(self, channel, message):
+ '''Publish a message to a *channel*'''
+ raise NotImplementedError('This backend cannot publish messages')
+
+ def subscriber(self, **kwargs):
raise NotImplementedError()
View
2 stdnet/backends/main.py
@@ -64,7 +64,7 @@ def get_connection_string(scheme, address, params):
def getdb(backend_uri=None, **kwargs):
'''get a backend database'''
- if isinstance(backend_uri,BackendDataServer):
+ if isinstance(backend_uri, BackendDataServer):
return backend_uri
backend_uri = backend_uri or settings.DEFAULT_BACKEND
if not backend_uri:
View
10 stdnet/backends/redisb.py
@@ -782,9 +782,6 @@ def get(self, id, default = None):
return self.pickler.loads(v)
else:
return default
-
- def cursor(self, pipelined = False):
- return self.client.pipeline() if pipelined else self.client
def disconnect(self):
self.client.connection_pool.disconnect()
@@ -969,5 +966,8 @@ def flush_structure(self, sm, pipe):
pipe.add_callback(
partial(structure_session_callback,sm))
- def subscriber(self):
- return redis.Subscriber(self.client)
+ def publish(self, channel, message):
+ return self.client.execute_command('PUBLISH', channel, message)
+
+ def subscriber(self, **kwargs):
+ return redis.Subscriber(self.client, **kwargs)
View
13 stdnet/lib/redis/connection.py
@@ -98,7 +98,7 @@ def __init__(self, client, connection, command_name, args,
self.connection = connection
self.command_name = command_name
self.args = args
- self.release_connection = release_connection
+ self._release_connection = release_connection
self.options = options
self.tried = 0
self._raw_response = []
@@ -117,6 +117,13 @@ def __init__(self, client, connection, command_name, args,
self.command = None
@property
+ def release_connection(self):
+ if self.connection.streaming:
+ return False
+ else:
+ return self._release_connection
+
+ @property
def num_responses(self):
if self.command_name:
return 1
@@ -265,12 +272,14 @@ class Connection(object):
def __init__(self, pool, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', reader_class=None,
- decode = False, **kwargs):
+ decode = False, streaming=False,
+ **kwargs):
self.pool = pool
self.password = password
self.socket_timeout = socket_timeout
self.encoding = encoding
self.encoding_errors = encoding_errors
+ self.streaming = streaming
self.__sock = None
if reader_class is None:
if settings.REDIS_PY_PARSER:
View
148 stdnet/lib/redis/pubsub.py
@@ -1,41 +1,25 @@
-from inspect import isclass
+from collections import deque
-from stdnet import getdb
-from stdnet.utils.encoders import Json
from stdnet.utils import is_string
from .client import RedisProxy
-__all__ = ['Publisher','Subscriber']
-
-class Publisher(object):
- '''Class which publish messages to message queues.'''
- def __init__(self, server = None, pickler = Json):
- if isclass(pickler):
- pickler = pickler()
- self.pickler = pickler
- self.client = getdb(server).client
-
- def publish(self, channel, data):
- data = self.pickler.dumps(data)
- #return self.backend.publish(channel, data)
- return self.client.execute_command('PUBLISH', channel, data)
+__all__ = ['Subscriber']
class Subscriber(RedisProxy):
'''Subscribe to '''
- def __init__(self, client):
+ subscribe_commands = frozenset(('subscribe', 'psubscribe'))
+ unsubscribe_commands = frozenset(('unsubscribe', 'punsubscribe'))
+ message_commands = frozenset(('message', 'pmessage'))
+
+ def __init__(self, client, message_callback=None):
super(Subscriber,self).__init__(client)
self.connection = None
+ self.command_queue = deque()
+ self.message_callback = message_callback
self._subscription_count = 0
- self._request = None
- self.channels = set()
- self.patterns = set()
- self.options = {'release_connection': False}
- self.subscribe_commands = set(
- (b'subscribe', b'psubscribe', b'unsubscribe', b'punsubscribe')
- )
def __del__(self):
try:
@@ -54,98 +38,52 @@ def subscription_count(self):
return self._subscription_count
def subscribe(self, channels):
- return self.execute_command('SUBSCRIBE', channels, self.channels)
+ return self.execute_command('subscribe', channels)
def unsubscribe(self, channels):
- return self.execute_command('UNSUBSCRIBE', channels, self.channels,
- False)
+ return self.execute_command('unsubscribe', channels)
def psubscribe(self, channels):
- return self.execute_command('PSUBSCRIBE', channels, self.patterns)
+ return self.execute_command('psubscribe', channels)
def punsubscribe(self, channels):
- return self.execute_command('PUNSUBSCRIBE', channels, self.patterns,
- False)
+ return self.execute_command('punsubscribe', channels)
- def request(self):
- if self._request is None:
- if self.connection is None:
- raise ValueErrior('No connection')
- self._request = self.connection.request_class(self.client,
- self.connection, False, (),
- release_connection = False)
- return self._request
+ # INTERNALS
- def execute_command(self, command, channels, container, add = True):
- "Internal function which execute a publish/subscribe command."
- channels = channels or ()
- if is_string(channels):
- channels = [channels]
- if add:
- for c in channels:
- container.add(c)
- else:
- if not channels:
- container.clear()
- else:
- for c in channels:
- try:
- container.remove(c)
- except KeyError:
- pass
+ def execute_command(self, command, channels):
+ cmd = (command, channels)
if self.connection is None:
- self.connection = self.client.connection_pool.get_connection()
- connection = self.connection
- try:
- return connection.execute_command(self, command,
- *channels, **self.options)
- except redis.ConnectionError:
- connection.disconnect()
- # Connect manually here. If the Redis server is down, this will
- # fail and raise a ConnectionError as desired.
- connection.connect()
- # resubscribe to all channels and patterns before
- # resending the current command
- for channel in self.channels:
- self.subscribe(channel)
- for pattern in self.patterns:
- self.psubscribe(pattern)
- connection.send_command(command, channels)
- return self.parse_response()
-
+ if command in self.subscribe_commands:
+ self.connection = self.client.connection_pool.get_connection()
+ self.command_queue.append(cmd)
+ return self._execute_next()
+ else:
+ self.command_queue.append(cmd)
+ return cmd
+
+ def _execute_next(self):
+ if self.command_queue:
+ command, channels = self.command_queue.popleft()
+ return self.connection.execute_command(self, command, *channels)
+
def parse_response(self, request):
"Parse the response from a publish/subscribe command"
response = request.response
- if response[0] in self.subscribe_commands:
+ #request.connection.streaming = True
+ #request.command = None
+ command = response[0].decode()
+ if command in self.subscribe_commands:
+ self.message_callback('subscribe', response[1].decode())
self._subscription_count = response[2]
- # if we've just unsubscribed from the remaining channels,
- # release the connection back to the pool
- if not self._subscription_count:
- self.disconnect()
- #data = self.pickler.dumps(data)
+ elif command in self.unsubscribe_commands:
+ self.message_callback('unsubscribe', response[1].decode())
+ self._subscription_count = response[2]
+ elif command in self.message_commands:
+ self.message_callback('message', self.get_message(response))
+ if not self._subscription_count:
+ self.disconnect()
return response
- def pull(self, timeout, count, loads):
- '''Retrieve new messages from the subscribed channels.
-
-:parameter timeout: a timeout in seconds'''
- c = 0
- request = self.request()
- while self.subscription_count and (count and c < count):
- r = request.read_response()
- c += 1
- if r[0] == b'pmessage':
- msg = {
- 'type': 'pmessage',
- 'pattern': r[1].decode('utf-8'),
- 'channel': r[2].decode('utf-8'),
- 'data': loads(r[3])
- }
- else:
- msg = {
- 'type': 'message',
- 'pattern': None,
- 'channel': r[1].decode('utf-8'),
- 'data': loads(r[2])
- }
- yield msg
+ def get_message(self, response):
+ return response
View
2 stdnet/odm/session.py
@@ -500,7 +500,7 @@ class Session(object):
class for querying. Default is :class:`Query`.
'''
_structures = {}
- def __init__(self, backend, query_class = None):
+ def __init__(self, backend, query_class=None):
self.backend = getdb(backend)
self.transaction = None
self._models = OrderedDict()
View
40 tests/regression/backend.py
@@ -0,0 +1,40 @@
+from stdnet import test, BackendDataServer, ModelNotAvailable,\
+ SessionNotAvailable
+from stdnet import odm
+
+
+class DummyBackendDataServer(BackendDataServer):
+
+ def setup_connection(self, address):
+ pass
+
+
+class TestBackend(test.TestCase):
+
+ def get_backend(self, name='?', address=('',0)):
+ return DummyBackendDataServer(name, address)
+
+ def testVirtuals(self):
+ self.assertRaises(NotImplementedError, BackendDataServer, '', '')
+ b = self.get_backend()
+ self.assertEqual(str(b), '')
+ self.assertFalse(b.clean(None))
+ self.assertRaises(NotImplementedError, b.execute_session, None, None)
+ self.assertRaises(NotImplementedError, b.model_keys, None)
+ self.assertRaises(NotImplementedError, b.instance_keys, None)
+ self.assertRaises(NotImplementedError, b.as_cache)
+ self.assertRaises(NotImplementedError, b.clear)
+ self.assertRaises(NotImplementedError, b.flush)
+ self.assertRaises(NotImplementedError, b.publish, '', '')
+
+ def testMissingStructure(self):
+ l = odm.List()
+ self.assertRaises(SessionNotAvailable, l.backend_structure)
+ session = odm.Session(backend=self.get_backend())
+ session.begin()
+ session.add(l)
+ self.assertRaises(ModelNotAvailable, l.backend_structure)
+
+
+
+
View
47 tests/regression/pubsub.py
@@ -1,26 +1,47 @@
-from stdnet import test
+import os
+
+from stdnet import test, getdb
from stdnet.apps.pubsub import Publisher, Subscriber
+from .backend import DummyBackendDataServer
+
+@test.skipUnless(os.environ['stdnet_test_suite'] == 'pulsar', 'Requires Pulsar')
class TestPubSub(test.TestCase):
def setUp(self):
self.s = Subscriber()
+
+ #def tearDown(self):
+ # self.s.unsubscribe()
+ # self.s.punsubscribe()
+
+ def publisher(self):
+ p = Publisher(self.backend)
+ self.assertTrue(p.pickler)
+ self.assertTrue(p.server)
+ return p
- def __tearDown(self):
- self.s.unsubscribe()
- self.s.punsubscribe()
+ def subscriber(self):
+ from stdnet.lib.redis.async import RedisConnection
+ b = getdb(self.backend.connection_string,
+ connection_class=RedisConnection)
+ s = Subscriber(b)
+ self.assertTrue(s.server)
+ return s
+
+ def testDummy(self):
+ p = Publisher(DummyBackendDataServer('',''))
+ self.assertRaises(NotImplementedError, p.publish, '', '')
- def __testPublisher(self):
- p = Publisher()
- self.assertTrue(p.pickler)
- self.assertTrue(p.client)
+ def testClasses(self):
+ p = self.publisher()
+ s = self.subscriber()
- def __testSimple(self):
- s = self.s
- p = Publisher()
- # subscribe to 'test' message queue
- s.subscribe('test')
+ def testSimple(self):
+ p = self.publisher()
+ s = self.subscriber()
+ yield s.subscribe('test')
self.assertEqual(s.subscription_count(), 1)
self.assertEqual(p.publish('test','hello world!'), 1)
res = list(s.pull(count = 1))

0 comments on commit c3c4b63

Please sign in to comment.