Skip to content

Commit

Permalink
Merge pull request #30 from akgood/master
Browse files Browse the repository at this point in the history
Fixes for handling/reporting lost-connection errors
  • Loading branch information
fiorix committed Sep 5, 2012
2 parents 565d688 + a0c2015 commit 326cb29
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 21 deletions.
90 changes: 90 additions & 0 deletions tests/test_subscriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import txredisapi as redis
from twisted.internet import defer, reactor
from twisted.trial import unittest

redis_host = "localhost"
redis_port = 6379

class TestSubscriberProtocol(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
factory = redis.SubscriberFactory()
factory.continueTrying = False
reactor.connectTCP(redis_host, redis_port, factory)
self.db = yield factory.deferred

@defer.inlineCallbacks
def tearDown(self):
yield self.db.disconnect()

@defer.inlineCallbacks
def testDisconnectErrors(self):
# Slightly dirty, but we want a reference to the actual
# protocol instance
conn = self.db._factory.getConnection

# This should return a deferred from the replyQueue; then
# loseConnection will make it do an errback with a
# ConnectionError instance
d = self.db.subscribe('foo')

conn.transport.loseConnection()
try:
yield d
self.fail()
except redis.ConnectionError:
pass

# This should immediately errback with a ConnectionError
# instance when getConnection finds 0 active instances in the
# factory
try:
yield self.db.subscribe('bar')
self.fail()
except redis.ConnectionError:
pass

# This should immediately raise a ConnectionError instance
# when execute_command() finds that the connection is not
# connected
try:
yield conn.subscribe('baz')
self.fail()
except redis.ConnectionError:
pass

@defer.inlineCallbacks
def testSubscribe(self):
reply = yield self.db.subscribe("test_subscribe1")
self.assertEqual(reply, [u"subscribe", u"test_subscribe1", 1])

reply = yield self.db.subscribe("test_subscribe2")
self.assertEqual(reply, [u"subscribe", u"test_subscribe2", 2])

@defer.inlineCallbacks
def testUnsubscribe(self):
yield self.db.subscribe("test_unsubscribe1")
yield self.db.subscribe("test_unsubscribe2")

reply = yield self.db.unsubscribe("test_unsubscribe1")
self.assertEqual(reply, [u"unsubscribe", u"test_unsubscribe1", 1])
reply = yield self.db.unsubscribe("test_unsubscribe2")
self.assertEqual(reply, [u"unsubscribe", u"test_unsubscribe2", 0])

@defer.inlineCallbacks
def testPSubscribe(self):
reply = yield self.db.psubscribe("test_psubscribe1.*")
self.assertEqual(reply, [u"psubscribe", u"test_psubscribe1.*", 1])

reply = yield self.db.psubscribe("test_psubscribe2.*")
self.assertEqual(reply, [u"psubscribe", u"test_psubscribe2.*", 2])

@defer.inlineCallbacks
def testPUnsubscribe(self):
yield self.db.psubscribe("test_punsubscribe1.*")
yield self.db.psubscribe("test_punsubscribe2.*")

reply = yield self.db.punsubscribe("test_punsubscribe1.*")
self.assertEqual(reply, [u"punsubscribe", u"test_punsubscribe1.*", 1])
reply = yield self.db.punsubscribe("test_punsubscribe2.*")
self.assertEqual(reply, [u"punsubscribe", u"test_punsubscribe2.*", 0])
38 changes: 17 additions & 21 deletions txredisapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def connectionLost(self, why):
self.connected = 0
self.factory.delConnection(self)
LineReceiver.connectionLost(self, why)
while self.replyQueue.pending:
while self.replyQueue.waiting:
self.replyReceived(ConnectionError("Lost connection"))

def lineReceived(self, line):
Expand Down Expand Up @@ -1303,9 +1303,6 @@ class MonitorProtocol(RedisProtocol):
take care with the performance impact: http://redis.io/commands/monitor
"""

def connectionLost(self, why):
pass

def messageReceived(self, message):
pass

Expand All @@ -1320,9 +1317,6 @@ def stop(self):


class SubscriberProtocol(RedisProtocol):
def connectionLost(self, why):
pass

def messageReceived(self, pattern, channel, message):
pass

Expand All @@ -1331,8 +1325,10 @@ def replyReceived(self, reply):
if reply[-3] == u"message":
self.messageReceived(None, *reply[-2:])
else:
self.replyQueue.put(reply[-3])
self.replyQueue.put(reply[-3:])
self.messageReceived(*reply[-3:])
elif isinstance(reply, Exception):
self.replyQueue.put(reply)

def subscribe(self, channels):
if isinstance(channels, (str, unicode)):
Expand All @@ -1354,19 +1350,6 @@ def punsubscribe(self, patterns):
patterns = [patterns]
return self.execute_command("PUNSUBSCRIBE", *patterns)


class SubscriberFactory(protocol.ReconnectingClientFactory):
maxDelay = 120
continueTrying = True
protocol = SubscriberProtocol


class MonitorFactory(protocol.ReconnectingClientFactory):
maxDelay = 120
continueTrying = True
protocol = MonitorProtocol


class ConnectionHandler(object):
def __init__(self, factory):
self._factory = factory
Expand Down Expand Up @@ -1687,6 +1670,19 @@ def getConnection(self):

raise RedisError("In transaction")

class SubscriberFactory(RedisFactory):
protocol = SubscriberProtocol

def __init__(self, isLazy=False, handler=ConnectionHandler):
RedisFactory.__init__(self, None, None, 1, isLazy=isLazy,
handler=handler)

class MonitorFactory(RedisFactory):
protocol = MonitorProtocol

def __init__(self, isLazy=False, handler=ConnectionHandler):
RedisFactory.__init__(self, None, None, 1, isLazy=isLazy,
handler=handler)

def makeConnection(host, port, dbid, poolsize, reconnect, isLazy):
uuid = "%s:%s" % (host, port)
Expand Down

0 comments on commit 326cb29

Please sign in to comment.