Skip to content

Commit

Permalink
Merge pull request #81 from pitrho/master
Browse files Browse the repository at this point in the history
Implementation of pubsub methods
  • Loading branch information
jamesls committed Oct 1, 2015
2 parents fff0029 + eb57803 commit 7e68c2a
Show file tree
Hide file tree
Showing 3 changed files with 395 additions and 11 deletions.
11 changes: 0 additions & 11 deletions README.rst
Expand Up @@ -182,17 +182,6 @@ string
* bitpos


pubsub
------

* punsubscribe
* subscribe
* publish
* pubsub
* psubscribe
* unsubscribe


Contributing
============

Expand Down
259 changes: 259 additions & 0 deletions fakeredis.py
Expand Up @@ -8,11 +8,25 @@
from datetime import datetime, timedelta
import operator
import sys
import time
import re

import redis
from redis.exceptions import ResponseError
import redis.client

try:
# Python 2.6, 2.7
from Queue import Queue, Empty
except:
# Python 3
from queue import Queue, Empty

PY2 = sys.version_info[0] == 2

if not PY2:
long = int


__version__ = '0.6.2'

Expand Down Expand Up @@ -170,6 +184,7 @@ def __init__(self, db=0, charset='utf-8', errors='strict', **kwargs):
self._db_num = db
self._encoding = charset
self._encoding_errors = errors
self._pubsubs = []

def flushdb(self):
DATABASES[self._db_num].clear()
Expand All @@ -179,6 +194,8 @@ def flushall(self):
for db in DATABASES:
DATABASES[db].clear()

del self._pubsubs[:]

# Basic key commands
def append(self, key, value):
self._db.setdefault(key, b'')
Expand Down Expand Up @@ -1273,6 +1290,30 @@ def transaction(self, func, *keys):
continue
raise redis.WatchError('Could not run transaction after 5 tries')

def pubsub(self):
"""
Returns a new FakePubSub instance
"""
ps = FakePubSub()
self._pubsubs.append(ps)

return ps

def publish(self, channel, message):
"""
Loops throug all available pubsub objects and publishes the
``message`` to then for the given ``channel``.
"""
count = 0
for i, ps in enumerate(self._pubsubs):
if not ps.subscribed:
del self._pubsubs[i]
continue

count += ps.put(channel, message, 'message')

return count


class FakeRedis(FakeStrictRedis):
def setex(self, name, value, time):
Expand Down Expand Up @@ -1393,3 +1434,221 @@ def multi(self):

def reset(self):
self.need_reset = False


class FakePubSub(object):

PUBLISH_MESSAGE_TYPES = ['message', 'pmessage']
SUBSCRIBE_MESSAGE_TYPES = ['subscribe', 'psubscribe']
UNSUBSCRIBE_MESSAGE_TYPES = ['unsubscribe', 'punsubscribe']
PATTERN_MESSAGE_TYPES = ['psubscribe', 'punsubscribe']

def __init__(self, *args, **kwargs):
self.channels = {}
self.patterns = {}
self._q = Queue()
self.subscribed = False

self.ignore_subscribe_messages = kwargs['ignore_subscribe_messages']\
if 'ignore_subscribe_messages' in kwargs else False

def put(self, channel, message, message_type, pattern=None):
"""
Utility function to be used as the publishing entrypoint for this
pubsub object
"""
if message_type in self.SUBSCRIBE_MESSAGE_TYPES or\
message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
return self._send(message_type, None, channel, message)

count = 0

# Send the message on the given channel
if channel in self.channels:
count += self._send(message_type, None, channel, message)

# See if any of the patterns match the given channel
for pattern, pattern_obj in iteritems(self.patterns):
match = re.match(pattern_obj['regex'], channel)
if match:
count += self._send('pmessage', pattern, channel, message)

return count

def _send(self, message_type, pattern, channel, data):
msg = {
'type': message_type,
'pattern': pattern,
'channel': channel.encode(),
'data': data.encode() if type(data) == str else data
}

self._q.put(msg)

return 1

def psubscribe(self, *args, **kwargs):
"""
Subcribe to channel patterns.
"""

def _subscriber(pattern, handler):
regex = self._parse_pattern(pattern)
return {
'regex': regex,
'handler': handler
}

total_subscriptions =\
len(self.channels.keys()) + len(self.patterns.keys())
self._subscribe(self.patterns, 'psubscribe', total_subscriptions,
_subscriber, *args, **kwargs)

def punsubscribe(self, *args):
"""
Unsubscribes from one or more patterns.
"""
total_subscriptions =\
len(self.channels.keys()) + len(self.patterns.keys())
self._usubscribe(self.patterns, 'punsubscribe', total_subscriptions,
*args)

def _parse_pattern(self, pattern):
temp_pattern = pattern
if '?' in temp_pattern:
temp_pattern = temp_pattern.replace('?', '.')

if '*' in temp_pattern:
temp_pattern = temp_pattern.replace('*', '.*')

if ']' in temp_pattern:
temp_pattern = temp_pattern.replace(']', ']?')

return temp_pattern

def subscribe(self, *args, **kwargs):
"""
Subscribes to one or more given ``channels``.
"""

def _subscriber(channel, handler):
return handler

total_subscriptions =\
len(self.channels.keys()) + len(self.patterns.keys())
self._subscribe(self.channels, 'subscribe', total_subscriptions,
_subscriber, *args, **kwargs)

def _subscribe(self, subscribed_dict, message_type, total_subscriptions,
subscriber, *args, **kwargs):

new_channels = {}
if args:
for arg in args:
new_channels[arg] = subscriber(arg, None)

for channel, handler in iteritems(kwargs):
new_channels[channel] = handler

subscribed_dict.update(new_channels)
self.subscribed = True

for channel in new_channels:
total_subscriptions += 1
self.put(channel, long(total_subscriptions), message_type)

def unsubscribe(self, *args):
"""
Unsubscribes from one or more given ``channels``.
"""
total_subscriptions =\
len(self.channels.keys()) + len(self.patterns.keys())
self._usubscribe(self.channels, 'unsubscribe', total_subscriptions,
*args)

def _usubscribe(self, subscribed_dict, message_type, total_subscriptions,
*args):

if args:
for channel in args:
if channel in subscribed_dict:
total_subscriptions -= 1
self.put(channel, long(total_subscriptions), message_type)
else:
for channel in subscribed_dict:
total_subscriptions -= 1
self.put(channel, long(total_subscriptions), message_type)
subscribed_dict.clear()

if total_subscriptions == 0:
self.subscribed = False

def listen(self):
"""
Listens for queued messages and yields the to the calling process
"""
while self.subscribed:
message = self.get_message()
if message:
yield message

time.sleep(1)

def close(self):
"""
Stops the listen function by calling unsubscribe
"""
self.unsubscribe()
self.punsubscribe()

def get_message(self, ignore_subscribe_messages=False, timeout=0):
"""
Returns the next available message.
"""

try:
message = self._q.get(True, timeout)
return self.handle_message(message, ignore_subscribe_messages)
except Empty:
return None

def handle_message(self, message, ignore_subscribe_messages=False):
"""
Parses a pubsub message. It invokes the handler of a message type,
if the handler is avaialble. If the message is of type ``subscribe``
and ignore_subscribe_messages if True, then it returns None. Otherwise,
it returns the message.
"""
message_type = message['type']
if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
subscribed_dict = None
if message_type == 'punsubscribe':
subscribed_dict = self.patterns
else:
subscribed_dict = self.channels

try:
channel = message['channel'].decode('utf-8')
del subscribed_dict[channel]
except:
pass

if message_type in self.PUBLISH_MESSAGE_TYPES:
# if there's a message handler, invoke it
handler = None
if message_type == 'pmessage':
pattern = self.patterns.get(message['pattern'], None)
if pattern:
handler = pattern['handler']
else:
handler = self.channels.get(message['channel'], None)
if handler:
handler(message)
return None
else:
# this is a subscribe/unsubscribe message. ignore if we don't
# want them
if ignore_subscribe_messages or self.ignore_subscribe_messages:
return None

return message

0 comments on commit 7e68c2a

Please sign in to comment.