Comparing changes

  • 11 commits
  • 9 files changed
  • 1 commit comment
  • 1 contributor
Commits on Apr 02, 2012
John Watson Initial refactor ed41285
Commits on Apr 03, 2012
John Watson Add roundrobin to __init__ import 1f9b042
Commits on Apr 06, 2012
John Watson Some changes c8b93d3
John Watson Add partition router to base eccedb9
John Watson Cluster is required anyway don't need to test b3f1102
John Watson Couple fixes b9df87a
John Watson Name ensure_db_num better 00f9e8d
Commits on Apr 09, 2012
John Watson PASSING ALL THE TESTS! \o 739bd25
Commits on Apr 10, 2012
John Watson More changes per Fluxx's comments 0062905
John Watson Fix test_offers_router_interface test de2c34c
dctrwatson Merge pull request #5 from disqus/robust-roundrobin (Refactor of routers) 825f52e
nydus/db/base.py (88 lines changed)
@@ -25,24 +25,24 @@ def create_cluster(settings):
"""
# Pull in our client
if isinstance(settings['engine'], basestring):
- conn = import_string(settings['engine'])
+ Conn = import_string(settings['engine'])
else:
- conn = settings['engine']
+ Conn = settings['engine']
# Pull in our router
router = settings.get('router')
if isinstance(router, basestring):
- router = import_string(router)()
+ router = import_string(router)
elif router:
- router = router()
+ router = router
else:
- router = BaseRouter()
+ router = BaseRouter
# Build the connection cluster
return Cluster(
router=router,
hosts=dict(
- (conn_number, conn(num=conn_number, **host_settings))
+ (conn_number, Conn(num=conn_number, **host_settings))
for conn_number, host_settings
in settings['hosts'].iteritems()
),
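
For orientation (editorial illustration, not part of the diff): after this change the router reaches Cluster as a class rather than an instance, so either a dotted path or the class itself works in the settings dict. A minimal sketch, mirroring the settings shape used in the tests further down; the backend path and db numbers are illustrative only.

    from nydus.db.base import create_cluster
    from nydus.db.routers import PartitionRouter

    # Illustrative settings; 'engine' and 'router' may be dotted paths or classes.
    cluster = create_cluster({
        'engine': 'nydus.db.backends.redis.Redis',   # resolved via import_string
        'router': PartitionRouter,                   # or 'nydus.db.routers.PartitionRouter'
        'hosts': {
            0: {'db': 0},
            1: {'db': 1},
        },
    })
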
@@ -53,10 +53,12 @@ class Cluster(object):
"""
Holds a cluster of connections.
"""
+ class MaxRetriesExceededError(Exception):
+ pass
- def __init__(self, hosts, router=None, max_connection_retries=20):
+ def __init__(self, hosts, router=BaseRouter, max_connection_retries=20):
self.hosts = hosts
- self.router = router
+ self.router = router()
self.max_connection_retries = max_connection_retries
def __len__(self):
@@ -76,22 +78,28 @@ def __iter__(self):
yield name
def _execute(self, attr, args, kwargs):
- db_nums = self._db_nums_for(*args, **kwargs)
+ connections = self._connections_for(attr, *args, **kwargs)
+
+ results = []
+ for conn in connections:
+ for retry in xrange(self.max_connection_retries):
+ try:
+ results.append(getattr(conn, attr)(*args, **kwargs))
+ except tuple(conn.retryable_exceptions), e:
+ if not self.router.retryable:
+ raise e
+ elif retry == self.max_connection_retries - 1:
+ raise self.MaxRetriesExceededError(e)
+ else:
+ conn = self._connections_for(attr, retry_for=conn.num, *args, **kwargs)[0]
+ else:
+ break
- if self.router and len(db_nums) is 1 and self.router.retryable:
- # The router supports retryable commands, so we want to run a
- # separate algorithm for how we get connections to run commands on
- # and then possibly retry
- return self._retryable_execute(db_nums, attr, *args, **kwargs)
+ # If we only had one db to query, we simply return that res
+ if len(results) == 1:
+ return results[0]
else:
- connections = self._connections_for(*args, **kwargs)
- results = [getattr(conn, attr)(*args, **kwargs) for conn in connections]
-
- # If we only had one db to query, we simply return that res
- if len(results) == 1:
- return results[0]
- else:
- return results
+ return results
def disconnect(self):
"""Disconnects all connections in cluster"""
@@ -106,7 +114,7 @@ def get_conn(self, *args, **kwargs):
during all steps of the process. An example of this would be
Redis pipelines.
"""
- connections = self._connections_for(*args, **kwargs)
+ connections = self._connections_for('get_conn', *args, **kwargs)
if len(connections) is 1:
return connections[0]
@@ -116,34 +124,8 @@ def get_conn(self, *args, **kwargs):
def map(self, workers=None):
return DistributedContextManager(self, workers)
- def _retryable_execute(self, db_nums, attr, *args, **kwargs):
- retries = 0
-
- while retries <= self.max_connection_retries:
- if len(db_nums) > 1:
- raise Exception('Retryable router returned multiple DBs')
- else:
- connection = self[db_nums[0]]
-
- try:
- return getattr(connection, attr)(*args, **kwargs)
- except tuple(connection.retryable_exceptions):
- # We had a failure, so get a new db_num and try again, noting
- # the DB number that just failed, so the backend can mark it as
- # down
- db_nums = self._db_nums_for(retry_for=db_nums[0], *args, **kwargs)
- retries += 1
- else:
- raise Exception('Maximum amount of connection retries exceeded')
-
- def _db_nums_for(self, *args, **kwargs):
- if self.router:
- return self.router.get_db(self, 'get_conn', *args, **kwargs)
- else:
- return range(len(self))
-
- def _connections_for(self, *args, **kwargs):
- return [self[n] for n in self._db_nums_for(*args, **kwargs)]
+ def _connections_for(self, attr, *args, **kwargs):
+ return [self[n] for n in self.router.get_dbs(self, attr, *args, **kwargs)]
class CallProxy(object):
@@ -325,9 +307,9 @@ def _execute(self):
command_map[cmd_ident] = command
if self._cluster.router:
- db_nums = self._cluster.router.get_db(self._cluster, command._attr, *command._args, **command._kwargs)
+ db_nums = self._cluster.router.get_dbs(self._cluster, command._attr, *command._args, **command._kwargs)
else:
- db_nums = range(len(self._cluster))
+ db_nums = self._cluster.keys()
# The number of commands is based on the total number of executable commands
num_commands += len(db_nums)
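
The pipelined map() path gets the same treatment: db numbers now come from router.get_dbs() per command, and the no-router fallback uses the cluster's keys rather than range(len(...)), presumably so host numbering no longer has to be contiguous. Usage is unchanged; a minimal sketch, assuming a redis-backed cluster built with create_cluster as in the tests below:

    # Sketch only; `cluster` is assumed to be a redis-backed cluster.
    with cluster.map() as conn:
        conn.set('a', 0)
        a = conn.get('a')
    # exiting the block flushes the queued commands to the db numbers the
    # router returned for each one
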
nydus/db/routers/__init__.py (2 lines changed)
@@ -6,4 +6,4 @@
:license: Apache License 2.0, see LICENSE for more details.
"""
-from .base import BaseRouter
+from .base import BaseRouter, RoundRobinRouter, PartitionRouter
nydus/db/routers/base.py (168 lines changed)
@@ -5,9 +5,13 @@
:copyright: (c) 2011 DISQUS.
:license: Apache License 2.0, see LICENSE for more details.
"""
+import time
+from binascii import crc32
+from collections import defaultdict
+from itertools import cycle
-__all__ = ('BaseRouter',)
+__all__ = ('BaseRouter', 'RoundRobinRouter', 'PartitionRouter')
class BaseRouter(object):
@@ -16,9 +20,163 @@ class BaseRouter(object):
"""
retryable = False
- def get_db(self, cluster, func, *args, **kwargs):
+ class UnableToSetupRouter(Exception):
+ pass
+
+ def __init__(self, *args, **kwargs):
+ self._ready = False
+
+ def get_dbs(self, cluster, attr, key=None, *args, **kwargs):
+ """
+ Perform setup and routing
+ Always return an iterable
+ Do not overload this method
+ """
+ if not self._ready:
+ if not self.setup_router(cluster, *args, **kwargs):
+ raise self.UnableToSetupRouter()
+
+ key = self._pre_routing(cluster, attr, key, *args, **kwargs)
+
+ if not key:
+ return cluster.hosts.keys()
+
+ try:
+ db_nums = self._route(cluster, attr, key, *args, **kwargs)
+ except Exception, e:
+ self._handle_exception(e)
+ db_nums = []
+
+ return self._post_routing(cluster, attr, key, db_nums, *args, **kwargs)
+
+ # Backwards compatibilty
+ get_db = get_dbs
+
+ def setup_router(self, cluster, *args, **kwargs):
+ """
+ Call method to perform any setup
+ """
+ self._ready = self._setup_router(cluster, *args, **kwargs)
+
+ return self._ready
+
+ def _setup_router(self, cluster, *args, **kwargs):
"""
- Return the first entry in the cluster
- The return value must be iterable
+ Perform any initialization for the router
+ Returns False if setup could not be completed
"""
- return range(len(cluster))
+ return True
+
+ def _pre_routing(self, cluster, attr, key, *args, **kwargs):
+ """
+ Perform any prerouting with this method and return the key
+ """
+ return key
+
+ def _route(self, cluster, attr, key, *args, **kwargs):
+ """
+ Perform routing and return db_nums
+ """
+ return [cluster.hosts.keys()[0]]
+
+ def _post_routing(self, cluster, attr, key, db_nums, *args, **kwargs):
+ """
+ Perform any postrouting actions and return db_nums
+ """
+ return db_nums
+
+ def _handle_exception(self, e):
+ """
+ Handle/transform exceptions and return it
+ """
+ raise e
+
+
+class RoundRobinRouter(BaseRouter):
+ """
+ Basic retry router that performs round robin
+ """
+
+ # Raised if all hosts in the hash have been marked as down
+ class HostListExhausted(Exception):
+ pass
+
+ class InvalidDBNum(Exception):
+ pass
+
+ # If this router can be retried on if a particular db index it gave out did
+ # not work
+ retryable = True
+
+ # How many requests to serve in a situation when a host is down before
+ # the down hosts are retried
+ attempt_reconnect_threshold = 100000
+
+ # Retry a down connection after this timeout
+ retry_timeout = 30
+
+ def __init__(self, *args, **kwargs):
+ self._get_db_attempts = 0
+ self._down_connections = {}
+
+ super(RoundRobinRouter,self).__init__(*args, **kwargs)
+
+ @classmethod
+ def ensure_db_num(cls, db_num):
+ try:
+ return int(db_num)
+ except ValueError:
+ raise cls.InvalidDBNum()
+
+ def flush_down_connections(self):
+ self._get_db_attempts = 0
+ self._down_connections = {}
+
+ def mark_connection_down(self, db_num):
+ db_num = self.ensure_db_num(db_num)
+ self._down_connections[db_num] = time.time()
+
+ def mark_connection_up(self, db_num):
+ db_num = self.ensure_db_num(db_num)
+ self._down_connections.pop(db_num, None)
+
+ def _setup_router(self, cluster, *args, **kwargs):
+ self._hosts_cycler = cycle(cluster.hosts.keys())
+
+ return True
+
+ def _pre_routing(self, cluster, attr, key, *args, **kwargs):
+ self._get_db_attempts += 1
+
+ if self._get_db_attempts > self.attempt_reconnect_threshold:
+ self.flush_down_connections()
+
+ if 'retry_for' in kwargs:
+ self.mark_connection_down(kwargs['retry_for'])
+
+ return key
+
+ def _route(self, cluster, attr, key, *args, **kwargs):
+ now = time.time()
+
+ for i in xrange(len(cluster)):
+ db_num = self._hosts_cycler.next()
+
+ marked_down_at = self._down_connections.get(db_num, False)
+
+ if not marked_down_at or (marked_down_at + self.retry_timeout <= now):
+ return [db_num]
+ else:
+ raise self.HostListExhausted()
+
+ def _post_routing(self, cluster, attr, key, db_nums, *args, **kwargs):
+ if db_nums:
+ self.mark_connection_up(db_nums[0])
+
+ return db_nums
+
+
+class PartitionRouter(BaseRouter):
+ def _route(self, cluster, attr, key, *args, **kwargs):
+ return [crc32(str(key)) % len(cluster)]
+
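
As a hedged illustration of the new template-method layout (editorial, not part of the diff): get_dbs() is the single entry point and is not overloaded; a custom router only overrides the underscore hooks. PartitionRouter above is the smallest example, overriding _route() with crc32(str(key)) % len(cluster).

    from nydus.db.routers import BaseRouter

    class FirstLetterRouter(BaseRouter):
        """Hypothetical router: picks a db number from the key's first letter."""

        def _route(self, cluster, attr, key, *args, **kwargs):
            return [ord(str(key)[0]) % len(cluster)]

    # With no key, get_dbs() falls back to every host; with a key it returns
    # whatever _route() chose:
    #   router.get_dbs(cluster, 'get')         -> cluster.hosts.keys()
    #   router.get_dbs(cluster, 'get', 'foo')  -> [ord('f') % len(cluster)]
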
nydus/db/routers/redis.py (99 lines changed)
@@ -6,103 +6,50 @@
:license: Apache License 2.0, see LICENSE for more details.
"""
-from binascii import crc32
-from itertools import cycle
-
-from nydus.db.routers import BaseRouter
+from nydus.db.routers import RoundRobinRouter
from nydus.contrib.ketama import Ketama
-class PartitionRouter(BaseRouter):
- def get_db(self, cluster, func, key=None, *args, **kwargs):
- # Assume first argument is a key
- if not key:
- return range(len(cluster))
- return [crc32(str(key)) % len(cluster)]
-
-
-class RoundRobinRouter(BaseRouter):
-
- def _get_db__round_robin(self, cluster):
- c = cycle(range(len(cluster)))
- for x in c:
- yield x
-
- def get_db(self, cluster, *args, **kwargs):
- if not cluster:
- return []
- if not hasattr(self, 'cycler'):
- self.cycler = self._get_db__round_robin(cluster)
- return [self.cycler.next()]
-
-
-class ConsistentHashingRouter(BaseRouter):
+class ConsistentHashingRouter(RoundRobinRouter):
'''
Router that returns host number based on a consistent hashing algorithm.
The consistent hashing algorithm only works if a key argument is provided.
If a key is not provided, then all hosts are returned.
'''
- # Raised if all hosts in the hash have been marked as down
- class HostListExhaused(Exception):
- pass
-
- # If this router can be retried on if a particular db index it gave out did
- # not work
- retryable = True
-
- # How many requests to serve in a situation when a host is down before
- # the down hosts are retried
- attempt_reconnect_threshold = 100000
-
- def __init__(self):
- self._get_db_attempts = 0
- self._down_connections = set()
-
- # There is one instance of this class that lives inside the Cluster object
- def get_db(self, cluster, func, key=None, *args, **kwargs):
- self._setup_hash_and_connections(cluster, *args, **kwargs)
-
- if not cluster:
- return []
- elif not key:
- return range(len(cluster))
- else:
- return self._host_indexes_for(key, cluster)
+ def __init__(self, *args, **kwargs):
+ self._db_num_id_map = {}
+ super(ConsistentHashingRouter, self).__init__(*args, **kwargs)
def flush_down_connections(self):
- for connection in self._down_connections:
- self._hash.add_node(connection.identifier)
+ for db_num in self._down_connections:
+ self._hash.add_node(self._db_num_id_map[db_num])
- self._down_connections = set()
+ super(ConsistentHashingRouter, self).flush_down_connections()
- def _setup_hash_and_connections(self, cluster, *args, **kwargs):
- # Create the hash if it doesn't exist yet
- if not hasattr(self, '_hash'):
- strings = [h.identifier for (i, h) in cluster.hosts.items()]
- self._hash = Ketama(strings)
+ def mark_connection_down(self, db_num):
+ db_num = self.ensure_db_num(db_num)
+ self._hash.remove_node(self._db_num_id_map[db_num])
- self._handle_host_retries(cluster, retry_for=kwargs.get('retry_for'))
+ super(ConsistentHashingRouter, self).mark_connection_down(db_num)
- def _handle_host_retries(self, cluster, retry_for):
- self._get_db_attempts += 1
+ def mark_conenction_up(self, db_num):
+ db_num = self.ensure_db_num(db_num)
+ self._hash.add_node(self._db_num_id_map[db_num])
- if self._get_db_attempts > self.attempt_reconnect_threshold:
- self.flush_down_connections()
- self._get_db_attempts = 0
+ super(ConsistentHashingRouter, self).mark_connection_up(db_num)
- if retry_for is not None:
- self._mark_connection_as_down(cluster[retry_for])
+ def _setup_router(self, cluster, *args, **kwargs):
+ self._db_num_id_map = dict([(db_num, host.identifier) for db_num, host in cluster.hosts.iteritems()])
+ self._hash = Ketama(self._db_num_id_map.values())
- def _mark_connection_as_down(self, connection):
- self._hash.remove_node(connection.identifier)
- self._down_connections.add(connection)
+ return True
- def _host_indexes_for(self, key, cluster):
+ def _route(self, cluster, attr, key, *args, **kwargs):
found = self._hash.get_node(key)
if not found and len(self._down_connections) > 0:
- raise self.HostListExhaused
+ raise self.HostListExhausted()
- return [i for (i, h) in cluster.hosts.items()
+ return [i for i, h in cluster.hosts.iteritems()
if h.identifier == found]
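
For context, a hedged sketch of configuring the consistent-hashing router after this refactor; the backend path and db numbers follow the shape used in the tests and are illustrative only. The class still lives in nydus.db.routers.redis, but retries, the reconnect threshold and the down-host bookkeeping now come from RoundRobinRouter.

    from nydus.db.base import create_cluster

    cluster = create_cluster({
        'engine': 'nydus.db.backends.redis.Redis',
        'router': 'nydus.db.routers.redis.ConsistentHashingRouter',
        'hosts': {
            0: {'db': 0},
            1: {'db': 1},
            2: {'db': 2},
        },
    })

    # Keys hash to a stable host via the Ketama ring; a retryable failure
    # triggers a retry_for=<db_num> re-route, which removes that node from
    # the ring until it is flushed back in later.
    cluster.set('some-key', 'value')
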
tests/nydus/contrib/django/tests.py (7 lines changed)
@@ -68,13 +68,6 @@ def test_proxy(self):
cursor = self.db.execute('SELECT 1')
self.assertEquals(cursor.fetchone(), (1,))
- def test_with_cluster(self):
- p = Cluster(
- hosts={0: self.db},
- )
- cursor = p.execute('SELECT 1')
- self.assertEquals(cursor.fetchone(), (1,))
-
def test_provides_identififer(self):
self.assertEqual(
"django.db.backends.sqlite3NAME=:memory: PORT=None HOST=None USER=None TEST_NAME=None PASSWORD=None OPTIONS={}",
tests/nydus/db/backends/redis/tests.py (7 lines changed)
@@ -32,7 +32,7 @@ def test_provides_identifier(self):
def test_pipelined_map(self):
redis = create_cluster({
'engine': 'nydus.db.backends.redis.Redis',
- 'router': 'nydus.db.routers.redis.PartitionRouter',
+ 'router': 'nydus.db.routers.PartitionRouter',
'hosts': {
0: {'db': 5},
1: {'db': 6},
@@ -59,12 +59,13 @@ def test_client_instantiates_with_kwargs(self, RedisClient):
def test_map_does_pipeline(self, RedisClient):
redis = create_cluster({
'engine': 'nydus.db.backends.redis.Redis',
- 'router': 'nydus.db.routers.redis.PartitionRouter',
+ 'router': 'nydus.db.routers.PartitionRouter',
'hosts': {
0: {'db': 0},
1: {'db': 1},
}
})
+
with redis.map() as conn:
conn.set('a', 0)
conn.set('d', 1)
@@ -86,7 +87,7 @@ def test_map_does_pipeline(self, RedisClient):
def test_map_only_runs_on_required_nodes(self, RedisClient):
redis = create_cluster({
'engine': 'nydus.db.backends.redis.Redis',
- 'router': 'nydus.db.routers.redis.PartitionRouter',
+ 'router': 'nydus.db.routers.PartitionRouter',
'hosts': {
0: {'db': 0},
1: {'db': 1},
tests/nydus/db/backends/thoonk/tests.py (4 lines changed)
@@ -62,7 +62,7 @@ def test_job_with_ConsistentHashingRouter(self):
self.assertTrue(jid_found)
def test_job_with_RoundRobinRouter(self):
- pubsub = self.get_cluster('nydus.db.routers.redis.RoundRobinRouter')
+ pubsub = self.get_cluster('nydus.db.routers.RoundRobinRouter')
jobs = {}
size = 20
@@ -89,7 +89,7 @@ def test_job_with_RoundRobinRouter(self):
self.assertEqual(len(jobs), 5)
for x in range(len(pubsub)):
- ps = pubsub.get_conn()
+ ps = pubsub.get_conn('testjob')
jps = ps.job('testjob')
self.assertEqual(jps.get_ids(), [])
tests/nydus/db/connections/tests.py (27 lines changed)
@@ -26,8 +26,7 @@ def foo(self, *args, **kwargs):
class DummyRouter(BaseRouter):
- def get_db(self, cluster, func, key=None, *args, **kwargs):
- # Assume first argument is a key
+ def get_dbs(self, cluster, attr, key=None, *args, **kwargs):
if key == 'foo':
return [1]
return [0]
@@ -70,7 +69,7 @@ def test_with_router(self):
c2 = DummyConnection(num=1, resp='bar')
# test dummy router
- r = DummyRouter()
+ r = DummyRouter
p = Cluster(
hosts={0: c, 1: c2},
router=r,
@@ -83,14 +82,14 @@ def test_with_router(self):
hosts={0: c, 1: c2},
)
self.assertEquals(p.foo(), ['foo', 'bar'])
- self.assertEquals(p.foo('foo'), ['foo', 'bar'])
+ self.assertEquals(p.foo('foo'), 'foo')
def test_get_conn(self):
c = DummyConnection(alias='foo', num=0, resp='foo')
c2 = DummyConnection(alias='foo', num=1, resp='bar')
# test dummy router
- r = DummyRouter()
+ r = DummyRouter
p = Cluster(
hosts={0: c, 1: c2},
router=r,
@@ -103,14 +102,14 @@ def test_get_conn(self):
hosts={0: c, 1: c2},
)
self.assertEquals(p.get_conn(), [c, c2])
- self.assertEquals(p.get_conn('foo'), [c, c2])
+ self.assertEquals(p.get_conn('foo'), c)
def test_map(self):
c = DummyConnection(num=0, resp='foo')
c2 = DummyConnection(num=1, resp='bar')
# test dummy router
- r = DummyRouter()
+ r = DummyRouter
p = Cluster(
hosts={0: c, 1: c2},
router=r,
@@ -135,7 +134,7 @@ def test_map(self):
self.assertEquals(bar, None)
self.assertEquals(foo, ['foo', 'bar'])
- self.assertEquals(bar, ['foo', 'bar'])
+ self.assertEquals(bar, 'foo')
class FlakeyConnection(DummyConnection):
@@ -158,7 +157,7 @@ def __init__(self):
self.key_args_seen = []
super(RetryableRouter, self).__init__()
- def get_db(self, cluster, func, key=None, *args, **kwargs):
+ def get_dbs(self, cluster, func, key=None, *args, **kwargs):
self.kwargs_seen.append(kwargs)
self.key_args_seen.append(key)
return [0]
@@ -171,7 +170,7 @@ def __init__(self):
self.returned = False
super(InconsistentRouter, self).__init__()
- def get_db(self, cluster, func, key=None, *args, **kwargs):
+ def get_dbs(self, cluster, func, key=None, *args, **kwargs):
if self.returned:
return range(5)
else:
@@ -210,12 +209,8 @@ def test_retry_router_when_receives_error(self):
def test_protection_from_infinate_loops(self):
cluster = self.build_cluster(connection=ScumbagConnection)
- self.assertRaises(Exception, cluster.foo)
-
- def test_retryable_router_returning_multiple_dbs_raises_ecxeption(self):
- cluster = self.build_cluster(router=InconsistentRouter, connection=ScumbagConnection)
- self.assertRaisesRegexp(Exception, 'returned multiple DBs',
- cluster.foo)
+ with self.assertRaises(Exception):
+ cluster.foo()
class EventualCommandTest(BaseTest):
tests/nydus/db/routers/tests.py (254 lines changed)
@@ -1,10 +1,17 @@
from __future__ import absolute_import
+import time
+
+from collections import Iterable
+from inspect import getargspec
+
+from mock import patch
+
from tests import BaseTest
from nydus.db.base import Cluster
from nydus.db.backends import BaseConnection
+from nydus.db.routers import BaseRouter, RoundRobinRouter, PartitionRouter
from nydus.db.routers.redis import ConsistentHashingRouter
-from nydus.db.routers.redis import RoundRobinRouter
class DummyConnection(BaseConnection):
@@ -18,119 +25,234 @@ def identifier(self):
return "%s:%s" % (self.host, self.i)
-class RoundRobinRouterTest(BaseTest):
+class BaseRouterTest(BaseTest):
+ Router = BaseRouter
+ class TestException(Exception): pass
def setUp(self):
- self.router = RoundRobinRouter()
+ self.router = self.Router()
self.hosts = dict((i, DummyConnection(i)) for i in range(5))
- self.cluster = Cluster(router=self.router, hosts=self.hosts)
+ self.cluster = Cluster(router=self.Router, hosts=self.hosts)
- def get_db(self, *args, **kwargs):
+ def get_dbs(self, *args, **kwargs):
kwargs.setdefault('cluster', self.cluster)
- return self.router.get_db(*args, **kwargs)
+ return self.router.get_dbs(*args, **kwargs)
+ def test_not_ready(self):
+ self.assertTrue(not self.router._ready)
-class ConsistentHashingRouterTest(BaseTest):
+ def test_get_dbs_iterable(self):
+ db_nums = self.get_dbs(attr='test', key='foo')
+ self.assertIsInstance(db_nums, Iterable)
- def setUp(self):
- self.router = ConsistentHashingRouter()
- self.hosts = dict((i, DummyConnection(i)) for i in range(5))
- self.cluster = Cluster(router=self.router, hosts=self.hosts)
+ def test_get_dbs_unabletosetuproute(self):
+ with patch.object(self.router, '_setup_router', return_value=False):
+ with self.assertRaises(BaseRouter.UnableToSetupRouter):
+ self.get_dbs(attr='test', key='foo')
+
+ def test_setup_router_returns_true(self):
+ self.assertTrue(self.router.setup_router(self.cluster))
- def get_db(self, **kwargs):
- kwargs.setdefault('cluster', self.cluster)
- return self.router.get_db(func='info', **kwargs)
+ def test_offers_router_interface(self):
+ self.assertTrue(callable(self.router.get_dbs))
+ dbargs, _, _, dbdefaults = getargspec(self.router.get_dbs)
+ self.assertTrue(set(dbargs) >= set(['self', 'cluster', 'attr', 'key']))
+ self.assertIsNone(dbdefaults[0])
+
+ self.assertTrue(callable(self.router.setup_router))
+ setupargs, _, _, setupdefaults = getargspec(self.router.setup_router)
+ self.assertTrue(set(setupargs) >= set(['self', 'cluster']))
+ self.assertIsNone(setupdefaults)
+
+ def test_returns_whole_cluster_without_key(self):
+ self.assertEquals(self.hosts.keys(), self.get_dbs(attr='test'))
+
+ def test_returns_sequence_with_one_item_when_given_key(self):
+ self.assertEqual(len(self.get_dbs(attr='test', key='foo')), 1)
+
+ def test_get_dbs_handles_exception(self):
+ with patch.object(self.router, '_route') as _route:
+ with patch.object(self.router, '_handle_exception') as _handle_exception:
+ _route.side_effect = self.TestException()
+
+ self.get_dbs(attr='test', key='foo')
+ self.assertTrue(_handle_exception.called)
-class RoundRobinTest(BaseTest):
+
+class BaseBaseRouterTest(BaseRouterTest):
+ def test__setup_router_returns_true(self):
+ self.assertTrue(self.router._setup_router(self.cluster))
+
+ def test__pre_routing_returns_key(self):
+ key = 'foo'
+
+ self.assertEqual(key, self.router._pre_routing(self.cluster, 'foo', key))
+
+ def test__route_returns_first_db_num(self):
+ self.assertEqual(self.cluster.hosts.keys()[0], self.router._route(self.cluster, 'test', 'foo')[0])
+
+ def test__post_routing_returns_db_nums(self):
+ db_nums = self.hosts.keys()
+
+ self.assertEqual(db_nums, self.router._post_routing(self.cluster, 'test', 'foo', db_nums))
+
+ def test__handle_exception_raises_same_exception(self):
+ e = self.TestException()
+
+ with self.assertRaises(self.TestException):
+ self.router._handle_exception(e)
+
+
+class BaseRoundRobinRouterTest(BaseRouterTest):
+ Router = RoundRobinRouter
def setUp(self):
- self.router = RoundRobinRouter()
- self.hosts = dict((i, DummyConnection(i)) for i in range(5))
- self.cluster = Cluster(router=self.router, hosts=self.hosts)
+ super(BaseRoundRobinRouterTest, self).setUp()
+ assert self.router._setup_router(self.cluster)
- def get_db(self, *args, **kwargs):
- kwargs.setdefault('cluster', self.cluster)
- return self.router.get_db(*args, **kwargs)
+ def test_ensure_db_num(self):
+ db_num = 0
+ s_db_num = str(db_num)
- def test_cluster_of_zero_returns_zero(self):
- self.cluster.hosts = dict()
- self.assertEquals([], self.get_db())
+ self.assertEqual(self.router.ensure_db_num(db_num), db_num)
+ self.assertEqual(self.router.ensure_db_num(s_db_num), db_num)
- def test_cluster_of_one_returns_one(self):
- self.cluster.hosts = {0: DummyConnection('foo')}
- self.assertEquals([0], self.get_db())
+ def test_esnure_db_num_raises(self):
+ with self.assertRaises(RoundRobinRouter.InvalidDBNum):
+ self.router.ensure_db_num('a')
- def test_multi_node_cluster_returns_correct_host(self):
- self.cluster.hosts = {0: DummyConnection('foo'), 1: DummyConnection('bar')}
- self.assertEquals([[0], [1], [0], [1]], [self.get_db(), self.get_db(), self.get_db(), self.get_db(), ])
+ def test_flush_down_connections(self):
+ self.router._get_db_attempts = 9001
+ self._down_connections = {0: time.time()}
+ self.router.flush_down_connections()
-class InterfaceTest(ConsistentHashingRouterTest):
+ self.assertEqual(self.router._get_db_attempts, 0)
+ self.assertEqual(self.router._down_connections, {})
- def test_offers_router_interface(self):
- self.assertTrue(callable(self.router.get_db))
+ def test_mark_connection_down(self):
+ db_num = 0
- def test_get_db_returns_itereable(self):
- iter(self.get_db())
+ self.router.mark_connection_down(db_num)
- def test_returns_whole_cluster_without_key(self):
- self.assertEquals(range(5), self.get_db())
+ self.assertAlmostEqual(self.router._down_connections[db_num], time.time(), delta=10)
- def test_returns_sequence_with_one_item_when_given_key(self):
- self.assert_(len(self.get_db(key='foo')) is 1)
+ def test_mark_connection_up(self):
+ db_num = 0
+
+ self.router.mark_connection_down(db_num)
+
+ self.assertIn(db_num, self.router._down_connections)
+
+ self.router.mark_connection_up(db_num)
+
+ self.assertNotIn(db_num, self.router._down_connections)
+
+ def test__pre_routing_updates__get_db_attempts(self):
+ self.router._pre_routing(self.cluster, 'test', 'foo')
+
+ self.assertEqual(self.router._get_db_attempts, 1)
+
+ @patch('nydus.db.routers.RoundRobinRouter.flush_down_connections')
+ def test__pre_routing_flush_down_connections(self, _flush_down_connections):
+ self.router._get_db_attempts = RoundRobinRouter.attempt_reconnect_threshold + 1
+
+ self.router._pre_routing(self.cluster, 'test', 'foo')
+
+ self.assertTrue(_flush_down_connections.called)
+
+ @patch('nydus.db.routers.RoundRobinRouter.mark_connection_down')
+ def test__pre_routing_retry_for(self, _mark_connection_down):
+ db_num = 0
+
+ self.router._pre_routing(self.cluster, 'test', 'foo', retry_for=db_num)
+
+ _mark_connection_down.assert_called_with(db_num)
+ @patch('nydus.db.routers.RoundRobinRouter.mark_connection_up')
+ def test__post_routing_mark_connection_up(self, _mark_connection_up):
+ db_nums = [0]
-class HashingTest(ConsistentHashingRouterTest):
+ self.assertEqual(self.router._post_routing(self.cluster, 'test', 'foo', db_nums), db_nums)
+ _mark_connection_up.assert_called_with(db_nums[0])
- def get_db(self, **kwargs):
- kwargs['key'] = 'foo'
- return super(HashingTest, self).get_db(**kwargs)
- def test_cluster_of_zero_returns_zero(self):
- self.cluster.hosts = dict()
- self.assertEquals([], self.get_db())
+class RoundRobinRouterTest(BaseRoundRobinRouterTest):
+ def test__setup_router(self):
+ self.assertTrue(self.router._setup_router(self.cluster))
+ self.assertIsInstance(self.router._hosts_cycler, Iterable)
- def test_cluster_of_one_returns_one(self):
- self.cluster.hosts = dict(only_key=DummyConnection('foo'))
- self.assertEquals(['only_key'], self.get_db())
+ def test__route_cycles_through_keys(self):
+ db_nums = self.hosts.keys() * 2
+ results = [self.router._route(self.cluster, 'test', 'foo')[0] for _ in db_nums]
- def test_multi_node_cluster_returns_correct_host(self):
- self.assertEquals([2], self.get_db())
+ self.assertEqual(results, db_nums)
+ def test__route_retry(self):
+ self.router.retry_timeout = 0
-class RetryableTest(HashingTest):
+ db_num = 0
- def test_attempt_reconnect_threshold_is_set(self):
- self.assertEqual(self.router.attempt_reconnect_threshold, 100000)
+ self.router.mark_connection_down(db_num)
+
+ db_nums = self.router._route(self.cluster, 'test', 'foo')
+
+ self.assertEqual(db_nums, [db_num])
+
+ def test__route_skip_down(self):
+ db_num = 0
+
+ self.router.mark_connection_down(db_num)
+
+ db_nums = self.router._route(self.cluster, 'test', 'foo')
+
+ self.assertNotEqual(db_nums, [db_num])
+ self.assertEqual(db_nums, [db_num+1])
+
+ def test__route_hostlistexhausted(self):
+ [self.router.mark_connection_down(db_num) for db_num in self.hosts.keys()]
+
+ with self.assertRaises(RoundRobinRouter.HostListExhausted):
+ self.router._route(self.cluster, 'test', 'foo')
+
+
+
+class ConsistentHashingRouterTest(BaseRoundRobinRouterTest):
+ Router = ConsistentHashingRouter
+
+ def get_dbs(self, *args, **kwargs):
+ kwargs['attr'] = 'test'
+ return super(ConsistentHashingRouterTest, self).get_dbs(*args, **kwargs)
def test_retry_gives_next_host_if_primary_is_offline(self):
- self.assertEquals([2], self.get_db())
- self.assertEquals([4], self.get_db(retry_for=2))
+ self.assertEquals([2], self.get_dbs(key='foo'))
+ self.assertEquals([4], self.get_dbs(key='foo', retry_for=2))
def test_retry_host_change_is_sticky(self):
- self.assertEquals([2], self.get_db())
- self.assertEquals([4], self.get_db(retry_for=2))
+ self.assertEquals([2], self.get_dbs(key='foo'))
+ self.assertEquals([4], self.get_dbs(key='foo', retry_for=2))
- self.assertEquals([4], self.get_db())
+ self.assertEquals([4], self.get_dbs(key='foo'))
def test_adds_back_down_host_once_attempt_reconnect_threshold_is_passed(self):
ConsistentHashingRouter.attempt_reconnect_threshold = 3
- self.assertEquals([2], self.get_db())
- self.assertEquals([4], self.get_db(retry_for=2))
- self.assertEquals([4], self.get_db())
+ self.assertEquals([2], self.get_dbs(key='foo'))
+ self.assertEquals([4], self.get_dbs(key='foo', retry_for=2))
+ self.assertEquals([4], self.get_dbs(key='foo'))
# Router should add host 1 back to the pool now
- self.assertEquals([2], self.get_db())
+ self.assertEquals([2], self.get_dbs(key='foo'))
ConsistentHashingRouter.attempt_reconnect_threshold = 100000
def test_raises_host_list_exhaused_if_no_host_can_be_found(self):
# Kill the first 4
- [self.get_db(retry_for=i) for i in range(4)]
+ [self.get_dbs(retry_for=i) for i in range(4)]
# And the 5th should raise an error
self.assertRaises(
- ConsistentHashingRouter.HostListExhaused,
- self.get_db, **dict(retry_for=4))
+ ConsistentHashingRouter.HostListExhausted,
+ self.get_dbs, **dict(key='foo', retry_for=4))
+

Comments on commits in this comparison:

@Fluxx commented on c8b93d3 Apr 6, 2012

I might rename this to ensure_valid_db_num since it both verifies the DB number and converts it to an int.
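
The rename the comment suggests would look roughly like this (hypothetical; it is not applied anywhere in this comparison):

    # Suggested name only; behavior identical to ensure_db_num above.
    @classmethod
    def ensure_valid_db_num(cls, db_num):
        try:
            return int(db_num)
        except ValueError:
            raise cls.InvalidDBNum()
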
