diff --git a/stdnet/apps/pubsub.py b/stdnet/apps/pubsub.py index 7b6d279..e63481f 100644 --- a/stdnet/apps/pubsub.py +++ b/stdnet/apps/pubsub.py @@ -1,54 +1,76 @@ +import logging + from inspect import isclass +from collections import deque from stdnet import getdb, AsyncObject from stdnet.utils.encoders import Json from stdnet.utils import is_string + + +logger = logging.getLogger('stdnet.pubsub') + + +class PubSubBase(object): + pickler = Json - -class Publisher(object): - '''Class which publish messages to message queues.''' - def __init__(self, server = None, pickler = Json): + def __init__(self, server=None, pickler=None): + pickler = pickler or self.pickler if isclass(pickler): pickler = pickler() self.pickler = pickler - self.client = getdb(server).client + self.server = getdb(server) + +class Publisher(PubSubBase): + '''Class which publish messages to message queues.''' def publish(self, channel, data): data = self.pickler.dumps(data) - #return self.backend.publish(channel, data) - return self.client.execute_command('PUBLISH', channel, data) + return self.server.publish(channel, data) -class Subscriber(AsyncObject): - '''Subscribe to ''' - def __init__(self, server = None, pickler = Json): - if isclass(pickler): - pickler = pickler() - self.pickler = pickler - self.client = getdb(server).subscriber() +class Subscriber(PubSubBase): + '''A subscriber to channels''' + def __init__(self, server=None, pickler=None): + super(Subscriber, self).__init__(server, pickler) + self.channels = {} + self.patterns = {} + self._subscriber = self.server.subscriber( + message_callback=self.message_callback) def disconnect(self): - self.client.disconnect() + self._subscriber.disconnect() def subscription_count(self): - return self.client.subscription_count() + return self._subscriber.subscription_count() - def subscribe(self, channels): - return self.client.subscribe(channels) + def subscribe(self, *channels): + return self._subscriber.subscribe(self.channel_list(channels)) - def unsubscribe(self, channels = None): - return self.client.unsubscribe(channels) - - def psubscribe(self, channels): - return self.client.psubscribe(channels) + def unsubscribe(self, *channels): + return self._subscriber.unsubscribe(self.channel_list(channels)) - def punsubscribe(self, channels = None): - return self.client.punsubscribe(channels) + def psubscribe(self, *channels): + return self._subscriber.psubscribe(self.channel_list(channels)) - def pull(self, timeout = None, count = None): - '''Retrieve new messages from the subscribed channels. + def punsubscribe(self, *channels): + return self._subscriber.punsubscribe(self.channel_list(channels)) -:parameter timeout: Optional timeout in seconds. -:parameter count: Optional number of messages to retrieve.''' - return self.client.pull(timeout, count, self.pickler.loads) - \ No newline at end of file + def message_callback(self, command, channel, message=None): + if command == 'subscribe': + self.channels[channel] = deque() + elif command == 'unsubscribe': + self.channels.pop(channel, None) + elif channel in self.channels: + self.channels.append(message) + else: + logger.warn('Got message for unsubscribed channel "%s"' % channel) + + def channel_list(self, channels): + ch = [] + for channel in channels: + if not isinstance(channel, (list, tuple)): + ch.append(channel) + else: + ch.extend(channel) + return ch diff --git a/stdnet/backends/base.py b/stdnet/backends/base.py index 1416203..a4640f7 100755 --- a/stdnet/backends/base.py +++ b/stdnet/backends/base.py @@ -203,7 +203,7 @@ class BackendDataServer(object): structure_module = None struct_map = {} - def __init__(self, name, address, pickler = None, + def __init__(self, name, address, pickler=None, charset='utf-8', connection_string='', prefix=None, **params): self.__name = name @@ -233,9 +233,6 @@ def __eq__(self, other): def issame(self, other): return self.client == other.client - def cursor(self, pipelined=False): - return self - def disconnect(self): '''Disconnect the connection.''' pass @@ -244,13 +241,6 @@ def __repr__(self): return self.connection_string __str__ = __repr__ - def isempty(self): - '''Returns ``True`` if the database has no keys.''' - keys = self.keys() - if not hasattr(keys,'__len__'): - keys = list(keys) - return len(keys) - def make_objects(self, meta, data, related_fields=None): '''Generator of :class:`stdnet.odm.StdModel` instances with data from database. @@ -289,8 +279,11 @@ def make_objects(self, meta, data, related_fields=None): def objects_from_db(self, meta, data, related_fields=None): return list(self.make_objects(meta, data, related_fields)) - def structure(self, instance, client = None): - '''Create a backend :class:`stdnet.odm.Structure` handler.''' + def structure(self, instance, client=None): + '''Create a backend :class:`stdnet.odm.Structure` handler. + +:parameter instance: a :class:`stdnet.odm.Structure` +:parameter client: Optional client handler''' struct = self.struct_map.get(instance._meta.name) if struct is None: raise ModelNotAvailable('structure "{0}" is not available for\ @@ -315,36 +308,40 @@ def basekey(self, meta, *args): # PURE VIRTUAL METHODS - def setup_connection(self, address): # pragma: no cover + def setup_connection(self, address): '''Callback during initialization. Implementation should override this function for customizing their handling of connection parameters. It must return a instance of the backend handler.''' raise NotImplementedError() - def execute_session(self, session, callback): # pragma: no cover + def execute_session(self, session, callback): '''Execute a :class:`stdnet.odm.Session` in the backend server.''' raise NotImplementedError() - def model_keys(self, meta): # pragma: no cover + def model_keys(self, meta): '''Return a list of database keys used by model *model*''' raise NotImplementedError() - def instance_keys(self, obj): # pragma: no cover + def instance_keys(self, obj): '''Return a list of database keys used by instance *obj*''' raise NotImplementedError() - def as_cache(self): # pragma: no cover + def as_cache(self): raise NotImplementedError('This backend cannot be used as cache') - def clear(self): # pragma: no cover + def clear(self): """Remove *all* values from the database at once.""" raise NotImplementedError() - def flush(self, meta=None, pattern=None): # pragma: no cover + def flush(self, meta=None, pattern=None): '''Flush all model keys from the database''' raise NotImplementedError() - def subscriber(self): # pragma: no cover + def publish(self, channel, message): + '''Publish a message to a *channel*''' + raise NotImplementedError('This backend cannot publish messages') + + def subscriber(self, **kwargs): raise NotImplementedError() diff --git a/stdnet/backends/main.py b/stdnet/backends/main.py index 3502ee0..8516595 100755 --- a/stdnet/backends/main.py +++ b/stdnet/backends/main.py @@ -64,7 +64,7 @@ def get_connection_string(scheme, address, params): def getdb(backend_uri=None, **kwargs): '''get a backend database''' - if isinstance(backend_uri,BackendDataServer): + if isinstance(backend_uri, BackendDataServer): return backend_uri backend_uri = backend_uri or settings.DEFAULT_BACKEND if not backend_uri: diff --git a/stdnet/backends/redisb.py b/stdnet/backends/redisb.py index aafa495..1646534 100755 --- a/stdnet/backends/redisb.py +++ b/stdnet/backends/redisb.py @@ -782,9 +782,6 @@ def get(self, id, default = None): return self.pickler.loads(v) else: return default - - def cursor(self, pipelined = False): - return self.client.pipeline() if pipelined else self.client def disconnect(self): self.client.connection_pool.disconnect() @@ -969,5 +966,8 @@ def flush_structure(self, sm, pipe): pipe.add_callback( partial(structure_session_callback,sm)) - def subscriber(self): - return redis.Subscriber(self.client) \ No newline at end of file + def publish(self, channel, message): + return self.client.execute_command('PUBLISH', channel, message) + + def subscriber(self, **kwargs): + return redis.Subscriber(self.client, **kwargs) \ No newline at end of file diff --git a/stdnet/lib/redis/connection.py b/stdnet/lib/redis/connection.py index 5fe6536..02e7b73 100644 --- a/stdnet/lib/redis/connection.py +++ b/stdnet/lib/redis/connection.py @@ -98,7 +98,7 @@ def __init__(self, client, connection, command_name, args, self.connection = connection self.command_name = command_name self.args = args - self.release_connection = release_connection + self._release_connection = release_connection self.options = options self.tried = 0 self._raw_response = [] @@ -116,6 +116,13 @@ def __init__(self, client, connection, command_name, args, else: self.command = None + @property + def release_connection(self): + if self.connection.streaming: + return False + else: + return self._release_connection + @property def num_responses(self): if self.command_name: @@ -265,12 +272,14 @@ class Connection(object): def __init__(self, pool, password=None, socket_timeout=None, encoding='utf-8', encoding_errors='strict', reader_class=None, - decode = False, **kwargs): + decode = False, streaming=False, + **kwargs): self.pool = pool self.password = password self.socket_timeout = socket_timeout self.encoding = encoding self.encoding_errors = encoding_errors + self.streaming = streaming self.__sock = None if reader_class is None: if settings.REDIS_PY_PARSER: diff --git a/stdnet/lib/redis/pubsub.py b/stdnet/lib/redis/pubsub.py index df04744..39a86c6 100644 --- a/stdnet/lib/redis/pubsub.py +++ b/stdnet/lib/redis/pubsub.py @@ -1,41 +1,25 @@ -from inspect import isclass +from collections import deque -from stdnet import getdb -from stdnet.utils.encoders import Json from stdnet.utils import is_string from .client import RedisProxy -__all__ = ['Publisher','Subscriber'] - -class Publisher(object): - '''Class which publish messages to message queues.''' - def __init__(self, server = None, pickler = Json): - if isclass(pickler): - pickler = pickler() - self.pickler = pickler - self.client = getdb(server).client - - def publish(self, channel, data): - data = self.pickler.dumps(data) - #return self.backend.publish(channel, data) - return self.client.execute_command('PUBLISH', channel, data) +__all__ = ['Subscriber'] class Subscriber(RedisProxy): '''Subscribe to ''' - def __init__(self, client): + subscribe_commands = frozenset(('subscribe', 'psubscribe')) + unsubscribe_commands = frozenset(('unsubscribe', 'punsubscribe')) + message_commands = frozenset(('message', 'pmessage')) + + def __init__(self, client, message_callback=None): super(Subscriber,self).__init__(client) self.connection = None + self.command_queue = deque() + self.message_callback = message_callback self._subscription_count = 0 - self._request = None - self.channels = set() - self.patterns = set() - self.options = {'release_connection': False} - self.subscribe_commands = set( - (b'subscribe', b'psubscribe', b'unsubscribe', b'punsubscribe') - ) def __del__(self): try: @@ -54,98 +38,52 @@ def subscription_count(self): return self._subscription_count def subscribe(self, channels): - return self.execute_command('SUBSCRIBE', channels, self.channels) + return self.execute_command('subscribe', channels) def unsubscribe(self, channels): - return self.execute_command('UNSUBSCRIBE', channels, self.channels, - False) + return self.execute_command('unsubscribe', channels) def psubscribe(self, channels): - return self.execute_command('PSUBSCRIBE', channels, self.patterns) + return self.execute_command('psubscribe', channels) def punsubscribe(self, channels): - return self.execute_command('PUNSUBSCRIBE', channels, self.patterns, - False) + return self.execute_command('punsubscribe', channels) - def request(self): - if self._request is None: - if self.connection is None: - raise ValueErrior('No connection') - self._request = self.connection.request_class(self.client, - self.connection, False, (), - release_connection = False) - return self._request + # INTERNALS - def execute_command(self, command, channels, container, add = True): - "Internal function which execute a publish/subscribe command." - channels = channels or () - if is_string(channels): - channels = [channels] - if add: - for c in channels: - container.add(c) - else: - if not channels: - container.clear() - else: - for c in channels: - try: - container.remove(c) - except KeyError: - pass + def execute_command(self, command, channels): + cmd = (command, channels) if self.connection is None: - self.connection = self.client.connection_pool.get_connection() - connection = self.connection - try: - return connection.execute_command(self, command, - *channels, **self.options) - except redis.ConnectionError: - connection.disconnect() - # Connect manually here. If the Redis server is down, this will - # fail and raise a ConnectionError as desired. - connection.connect() - # resubscribe to all channels and patterns before - # resending the current command - for channel in self.channels: - self.subscribe(channel) - for pattern in self.patterns: - self.psubscribe(pattern) - connection.send_command(command, channels) - return self.parse_response() - + if command in self.subscribe_commands: + self.connection = self.client.connection_pool.get_connection() + self.command_queue.append(cmd) + return self._execute_next() + else: + self.command_queue.append(cmd) + return cmd + + def _execute_next(self): + if self.command_queue: + command, channels = self.command_queue.popleft() + return self.connection.execute_command(self, command, *channels) + def parse_response(self, request): "Parse the response from a publish/subscribe command" response = request.response - if response[0] in self.subscribe_commands: + #request.connection.streaming = True + #request.command = None + command = response[0].decode() + if command in self.subscribe_commands: + self.message_callback('subscribe', response[1].decode()) self._subscription_count = response[2] - # if we've just unsubscribed from the remaining channels, - # release the connection back to the pool - if not self._subscription_count: - self.disconnect() - #data = self.pickler.dumps(data) + elif command in self.unsubscribe_commands: + self.message_callback('unsubscribe', response[1].decode()) + self._subscription_count = response[2] + elif command in self.message_commands: + self.message_callback('message', self.get_message(response)) + if not self._subscription_count: + self.disconnect() return response - def pull(self, timeout, count, loads): - '''Retrieve new messages from the subscribed channels. - -:parameter timeout: a timeout in seconds''' - c = 0 - request = self.request() - while self.subscription_count and (count and c < count): - r = request.read_response() - c += 1 - if r[0] == b'pmessage': - msg = { - 'type': 'pmessage', - 'pattern': r[1].decode('utf-8'), - 'channel': r[2].decode('utf-8'), - 'data': loads(r[3]) - } - else: - msg = { - 'type': 'message', - 'pattern': None, - 'channel': r[1].decode('utf-8'), - 'data': loads(r[2]) - } - yield msg + def get_message(self, response): + return response \ No newline at end of file diff --git a/stdnet/odm/session.py b/stdnet/odm/session.py index 52739ce..ca63206 100644 --- a/stdnet/odm/session.py +++ b/stdnet/odm/session.py @@ -500,7 +500,7 @@ class Session(object): class for querying. Default is :class:`Query`. ''' _structures = {} - def __init__(self, backend, query_class = None): + def __init__(self, backend, query_class=None): self.backend = getdb(backend) self.transaction = None self._models = OrderedDict() diff --git a/tests/regression/backend.py b/tests/regression/backend.py new file mode 100644 index 0000000..d7f7bbe --- /dev/null +++ b/tests/regression/backend.py @@ -0,0 +1,40 @@ +from stdnet import test, BackendDataServer, ModelNotAvailable,\ + SessionNotAvailable +from stdnet import odm + + +class DummyBackendDataServer(BackendDataServer): + + def setup_connection(self, address): + pass + + +class TestBackend(test.TestCase): + + def get_backend(self, name='?', address=('',0)): + return DummyBackendDataServer(name, address) + + def testVirtuals(self): + self.assertRaises(NotImplementedError, BackendDataServer, '', '') + b = self.get_backend() + self.assertEqual(str(b), '') + self.assertFalse(b.clean(None)) + self.assertRaises(NotImplementedError, b.execute_session, None, None) + self.assertRaises(NotImplementedError, b.model_keys, None) + self.assertRaises(NotImplementedError, b.instance_keys, None) + self.assertRaises(NotImplementedError, b.as_cache) + self.assertRaises(NotImplementedError, b.clear) + self.assertRaises(NotImplementedError, b.flush) + self.assertRaises(NotImplementedError, b.publish, '', '') + + def testMissingStructure(self): + l = odm.List() + self.assertRaises(SessionNotAvailable, l.backend_structure) + session = odm.Session(backend=self.get_backend()) + session.begin() + session.add(l) + self.assertRaises(ModelNotAvailable, l.backend_structure) + + + + \ No newline at end of file diff --git a/tests/regression/pubsub.py b/tests/regression/pubsub.py index bd72304..a8710c6 100644 --- a/tests/regression/pubsub.py +++ b/tests/regression/pubsub.py @@ -1,26 +1,47 @@ -from stdnet import test +import os + +from stdnet import test, getdb from stdnet.apps.pubsub import Publisher, Subscriber +from .backend import DummyBackendDataServer + +@test.skipUnless(os.environ['stdnet_test_suite'] == 'pulsar', 'Requires Pulsar') class TestPubSub(test.TestCase): def setUp(self): self.s = Subscriber() + + #def tearDown(self): + # self.s.unsubscribe() + # self.s.punsubscribe() + + def publisher(self): + p = Publisher(self.backend) + self.assertTrue(p.pickler) + self.assertTrue(p.server) + return p - def __tearDown(self): - self.s.unsubscribe() - self.s.punsubscribe() + def subscriber(self): + from stdnet.lib.redis.async import RedisConnection + b = getdb(self.backend.connection_string, + connection_class=RedisConnection) + s = Subscriber(b) + self.assertTrue(s.server) + return s + + def testDummy(self): + p = Publisher(DummyBackendDataServer('','')) + self.assertRaises(NotImplementedError, p.publish, '', '') - def __testPublisher(self): - p = Publisher() - self.assertTrue(p.pickler) - self.assertTrue(p.client) + def testClasses(self): + p = self.publisher() + s = self.subscriber() - def __testSimple(self): - s = self.s - p = Publisher() - # subscribe to 'test' message queue - s.subscribe('test') + def testSimple(self): + p = self.publisher() + s = self.subscriber() + yield s.subscribe('test') self.assertEqual(s.subscription_count(), 1) self.assertEqual(p.publish('test','hello world!'), 1) res = list(s.pull(count = 1))