Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Merge pull request #5 from disqus/robust-roundrobin

Refactor of routers
  • Loading branch information...
commit 825f52e15aa451636fca76dfdc7f244c5704b45b 2 parents b0edafc + de2c34c
@dctrwatson dctrwatson authored
View
88 nydus/db/base.py
@@ -25,24 +25,24 @@ def create_cluster(settings):
"""
# Pull in our client
if isinstance(settings['engine'], basestring):
- conn = import_string(settings['engine'])
+ Conn = import_string(settings['engine'])
else:
- conn = settings['engine']
+ Conn = settings['engine']
# Pull in our router
router = settings.get('router')
if isinstance(router, basestring):
- router = import_string(router)()
+ router = import_string(router)
elif router:
- router = router()
+ router = router
else:
- router = BaseRouter()
+ router = BaseRouter
# Build the connection cluster
return Cluster(
router=router,
hosts=dict(
- (conn_number, conn(num=conn_number, **host_settings))
+ (conn_number, Conn(num=conn_number, **host_settings))
for conn_number, host_settings
in settings['hosts'].iteritems()
),
@@ -53,10 +53,12 @@ class Cluster(object):
"""
Holds a cluster of connections.
"""
+ class MaxRetriesExceededError(Exception):
+ pass
- def __init__(self, hosts, router=None, max_connection_retries=20):
+ def __init__(self, hosts, router=BaseRouter, max_connection_retries=20):
self.hosts = hosts
- self.router = router
+ self.router = router()
self.max_connection_retries = max_connection_retries
def __len__(self):
@@ -76,22 +78,28 @@ def __iter__(self):
yield name
def _execute(self, attr, args, kwargs):
- db_nums = self._db_nums_for(*args, **kwargs)
+ connections = self._connections_for(attr, *args, **kwargs)
+
+ results = []
+ for conn in connections:
+ for retry in xrange(self.max_connection_retries):
+ try:
+ results.append(getattr(conn, attr)(*args, **kwargs))
+ except tuple(conn.retryable_exceptions), e:
+ if not self.router.retryable:
+ raise e
+ elif retry == self.max_connection_retries - 1:
+ raise self.MaxRetriesExceededError(e)
+ else:
+ conn = self._connections_for(attr, retry_for=conn.num, *args, **kwargs)[0]
+ else:
+ break
- if self.router and len(db_nums) is 1 and self.router.retryable:
- # The router supports retryable commands, so we want to run a
- # separate algorithm for how we get connections to run commands on
- # and then possibly retry
- return self._retryable_execute(db_nums, attr, *args, **kwargs)
+ # If we only had one db to query, we simply return that res
+ if len(results) == 1:
+ return results[0]
else:
- connections = self._connections_for(*args, **kwargs)
- results = [getattr(conn, attr)(*args, **kwargs) for conn in connections]
-
- # If we only had one db to query, we simply return that res
- if len(results) == 1:
- return results[0]
- else:
- return results
+ return results
def disconnect(self):
"""Disconnects all connections in cluster"""
@@ -106,7 +114,7 @@ def get_conn(self, *args, **kwargs):
during all steps of the process. An example of this would be
Redis pipelines.
"""
- connections = self._connections_for(*args, **kwargs)
+ connections = self._connections_for('get_conn', *args, **kwargs)
if len(connections) is 1:
return connections[0]
@@ -116,34 +124,8 @@ def get_conn(self, *args, **kwargs):
def map(self, workers=None):
return DistributedContextManager(self, workers)
- def _retryable_execute(self, db_nums, attr, *args, **kwargs):
- retries = 0
-
- while retries <= self.max_connection_retries:
- if len(db_nums) > 1:
- raise Exception('Retryable router returned multiple DBs')
- else:
- connection = self[db_nums[0]]
-
- try:
- return getattr(connection, attr)(*args, **kwargs)
- except tuple(connection.retryable_exceptions):
- # We had a failure, so get a new db_num and try again, noting
- # the DB number that just failed, so the backend can mark it as
- # down
- db_nums = self._db_nums_for(retry_for=db_nums[0], *args, **kwargs)
- retries += 1
- else:
- raise Exception('Maximum amount of connection retries exceeded')
-
- def _db_nums_for(self, *args, **kwargs):
- if self.router:
- return self.router.get_db(self, 'get_conn', *args, **kwargs)
- else:
- return range(len(self))
-
- def _connections_for(self, *args, **kwargs):
- return [self[n] for n in self._db_nums_for(*args, **kwargs)]
+ def _connections_for(self, attr, *args, **kwargs):
+ return [self[n] for n in self.router.get_dbs(self, attr, *args, **kwargs)]
class CallProxy(object):
@@ -325,9 +307,9 @@ def _execute(self):
command_map[cmd_ident] = command
if self._cluster.router:
- db_nums = self._cluster.router.get_db(self._cluster, command._attr, *command._args, **command._kwargs)
+ db_nums = self._cluster.router.get_dbs(self._cluster, command._attr, *command._args, **command._kwargs)
else:
- db_nums = range(len(self._cluster))
+ db_nums = self._cluster.keys()
# The number of commands is based on the total number of executable commands
num_commands += len(db_nums)
View
2  nydus/db/routers/__init__.py
@@ -6,4 +6,4 @@
:license: Apache License 2.0, see LICENSE for more details.
"""
-from .base import BaseRouter
+from .base import BaseRouter, RoundRobinRouter, PartitionRouter
View
168 nydus/db/routers/base.py
@@ -5,9 +5,13 @@
:copyright: (c) 2011 DISQUS.
:license: Apache License 2.0, see LICENSE for more details.
"""
+import time
+from binascii import crc32
+from collections import defaultdict
+from itertools import cycle
-__all__ = ('BaseRouter',)
+__all__ = ('BaseRouter', 'RoundRobinRouter', 'PartitionRouter')
class BaseRouter(object):
@@ -16,9 +20,163 @@ class BaseRouter(object):
"""
retryable = False
- def get_db(self, cluster, func, *args, **kwargs):
+ class UnableToSetupRouter(Exception):
+ pass
+
+ def __init__(self, *args, **kwargs):
+ self._ready = False
+
+ def get_dbs(self, cluster, attr, key=None, *args, **kwargs):
+ """
+ Perform setup and routing
+ Always return an iterable
+ Do not overload this method
+ """
+ if not self._ready:
+ if not self.setup_router(cluster, *args, **kwargs):
+ raise self.UnableToSetupRouter()
+
+ key = self._pre_routing(cluster, attr, key, *args, **kwargs)
+
+ if not key:
+ return cluster.hosts.keys()
+
+ try:
+ db_nums = self._route(cluster, attr, key, *args, **kwargs)
+ except Exception, e:
+ self._handle_exception(e)
+ db_nums = []
+
+ return self._post_routing(cluster, attr, key, db_nums, *args, **kwargs)
+
+ # Backwards compatibilty
+ get_db = get_dbs
+
+ def setup_router(self, cluster, *args, **kwargs):
+ """
+ Call method to perform any setup
+ """
+ self._ready = self._setup_router(cluster, *args, **kwargs)
+
+ return self._ready
+
+ def _setup_router(self, cluster, *args, **kwargs):
"""
- Return the first entry in the cluster
- The return value must be iterable
+ Perform any initialization for the router
+ Returns False if setup could not be completed
"""
- return range(len(cluster))
+ return True
+
+ def _pre_routing(self, cluster, attr, key, *args, **kwargs):
+ """
+ Perform any prerouting with this method and return the key
+ """
+ return key
+
+ def _route(self, cluster, attr, key, *args, **kwargs):
+ """
+ Perform routing and return db_nums
+ """
+ return [cluster.hosts.keys()[0]]
+
+ def _post_routing(self, cluster, attr, key, db_nums, *args, **kwargs):
+ """
+ Perform any postrouting actions and return db_nums
+ """
+ return db_nums
+
+ def _handle_exception(self, e):
+ """
+ Handle/transform exceptions and return it
+ """
+ raise e
+
+
+class RoundRobinRouter(BaseRouter):
+ """
+ Basic retry router that performs round robin
+ """
+
+ # Raised if all hosts in the hash have been marked as down
+ class HostListExhausted(Exception):
+ pass
+
+ class InvalidDBNum(Exception):
+ pass
+
+ # If this router can be retried on if a particular db index it gave out did
+ # not work
+ retryable = True
+
+ # How many requests to serve in a situation when a host is down before
+ # the down hosts are retried
+ attempt_reconnect_threshold = 100000
+
+ # Retry a down connection after this timeout
+ retry_timeout = 30
+
+ def __init__(self, *args, **kwargs):
+ self._get_db_attempts = 0
+ self._down_connections = {}
+
+ super(RoundRobinRouter,self).__init__(*args, **kwargs)
+
+ @classmethod
+ def ensure_db_num(cls, db_num):
+ try:
+ return int(db_num)
+ except ValueError:
+ raise cls.InvalidDBNum()
+
+ def flush_down_connections(self):
+ self._get_db_attempts = 0
+ self._down_connections = {}
+
+ def mark_connection_down(self, db_num):
+ db_num = self.ensure_db_num(db_num)
+ self._down_connections[db_num] = time.time()
+
+ def mark_connection_up(self, db_num):
+ db_num = self.ensure_db_num(db_num)
+ self._down_connections.pop(db_num, None)
+
+ def _setup_router(self, cluster, *args, **kwargs):
+ self._hosts_cycler = cycle(cluster.hosts.keys())
+
+ return True
+
+ def _pre_routing(self, cluster, attr, key, *args, **kwargs):
+ self._get_db_attempts += 1
+
+ if self._get_db_attempts > self.attempt_reconnect_threshold:
+ self.flush_down_connections()
+
+ if 'retry_for' in kwargs:
+ self.mark_connection_down(kwargs['retry_for'])
+
+ return key
+
+ def _route(self, cluster, attr, key, *args, **kwargs):
+ now = time.time()
+
+ for i in xrange(len(cluster)):
+ db_num = self._hosts_cycler.next()
+
+ marked_down_at = self._down_connections.get(db_num, False)
+
+ if not marked_down_at or (marked_down_at + self.retry_timeout <= now):
+ return [db_num]
+ else:
+ raise self.HostListExhausted()
+
+ def _post_routing(self, cluster, attr, key, db_nums, *args, **kwargs):
+ if db_nums:
+ self.mark_connection_up(db_nums[0])
+
+ return db_nums
+
+
+class PartitionRouter(BaseRouter):
+ def _route(self, cluster, attr, key, *args, **kwargs):
+ return [crc32(str(key)) % len(cluster)]
+
View
99 nydus/db/routers/redis.py
@@ -6,103 +6,50 @@
:license: Apache License 2.0, see LICENSE for more details.
"""
-from binascii import crc32
-from itertools import cycle
-
-from nydus.db.routers import BaseRouter
+from nydus.db.routers import RoundRobinRouter
from nydus.contrib.ketama import Ketama
-class PartitionRouter(BaseRouter):
- def get_db(self, cluster, func, key=None, *args, **kwargs):
- # Assume first argument is a key
- if not key:
- return range(len(cluster))
- return [crc32(str(key)) % len(cluster)]
-
-
-class RoundRobinRouter(BaseRouter):
-
- def _get_db__round_robin(self, cluster):
- c = cycle(range(len(cluster)))
- for x in c:
- yield x
-
- def get_db(self, cluster, *args, **kwargs):
- if not cluster:
- return []
- if not hasattr(self, 'cycler'):
- self.cycler = self._get_db__round_robin(cluster)
- return [self.cycler.next()]
-
-
-class ConsistentHashingRouter(BaseRouter):
+class ConsistentHashingRouter(RoundRobinRouter):
'''
Router that returns host number based on a consistent hashing algorithm.
The consistent hashing algorithm only works if a key argument is provided.
If a key is not provided, then all hosts are returned.
'''
- # Raised if all hosts in the hash have been marked as down
- class HostListExhaused(Exception):
- pass
-
- # If this router can be retried on if a particular db index it gave out did
- # not work
- retryable = True
-
- # How many requests to serve in a situation when a host is down before
- # the down hosts are retried
- attempt_reconnect_threshold = 100000
-
- def __init__(self):
- self._get_db_attempts = 0
- self._down_connections = set()
-
- # There is one instance of this class that lives inside the Cluster object
- def get_db(self, cluster, func, key=None, *args, **kwargs):
- self._setup_hash_and_connections(cluster, *args, **kwargs)
-
- if not cluster:
- return []
- elif not key:
- return range(len(cluster))
- else:
- return self._host_indexes_for(key, cluster)
+ def __init__(self, *args, **kwargs):
+ self._db_num_id_map = {}
+ super(ConsistentHashingRouter, self).__init__(*args, **kwargs)
def flush_down_connections(self):
- for connection in self._down_connections:
- self._hash.add_node(connection.identifier)
+ for db_num in self._down_connections:
+ self._hash.add_node(self._db_num_id_map[db_num])
- self._down_connections = set()
+ super(ConsistentHashingRouter, self).flush_down_connections()
- def _setup_hash_and_connections(self, cluster, *args, **kwargs):
- # Create the hash if it doesn't exist yet
- if not hasattr(self, '_hash'):
- strings = [h.identifier for (i, h) in cluster.hosts.items()]
- self._hash = Ketama(strings)
+ def mark_connection_down(self, db_num):
+ db_num = self.ensure_db_num(db_num)
+ self._hash.remove_node(self._db_num_id_map[db_num])
- self._handle_host_retries(cluster, retry_for=kwargs.get('retry_for'))
+ super(ConsistentHashingRouter, self).mark_connection_down(db_num)
- def _handle_host_retries(self, cluster, retry_for):
- self._get_db_attempts += 1
+ def mark_conenction_up(self, db_num):
+ db_num = self.ensure_db_num(db_num)
+ self._hash.add_node(self._db_num_id_map[db_num])
- if self._get_db_attempts > self.attempt_reconnect_threshold:
- self.flush_down_connections()
- self._get_db_attempts = 0
+ super(ConsistentHashingRouter, self).mark_connection_up(db_num)
- if retry_for is not None:
- self._mark_connection_as_down(cluster[retry_for])
+ def _setup_router(self, cluster, *args, **kwargs):
+ self._db_num_id_map = dict([(db_num, host.identifier) for db_num, host in cluster.hosts.iteritems()])
+ self._hash = Ketama(self._db_num_id_map.values())
- def _mark_connection_as_down(self, connection):
- self._hash.remove_node(connection.identifier)
- self._down_connections.add(connection)
+ return True
- def _host_indexes_for(self, key, cluster):
+ def _route(self, cluster, attr, key, *args, **kwargs):
found = self._hash.get_node(key)
if not found and len(self._down_connections) > 0:
- raise self.HostListExhaused
+ raise self.HostListExhausted()
- return [i for (i, h) in cluster.hosts.items()
+ return [i for i, h in cluster.hosts.iteritems()
if h.identifier == found]
View
7 tests/nydus/contrib/django/tests.py
@@ -68,13 +68,6 @@ def test_proxy(self):
cursor = self.db.execute('SELECT 1')
self.assertEquals(cursor.fetchone(), (1,))
- def test_with_cluster(self):
- p = Cluster(
- hosts={0: self.db},
- )
- cursor = p.execute('SELECT 1')
- self.assertEquals(cursor.fetchone(), (1,))
-
def test_provides_identififer(self):
self.assertEqual(
"django.db.backends.sqlite3NAME=:memory: PORT=None HOST=None USER=None TEST_NAME=None PASSWORD=None OPTIONS={}",
View
7 tests/nydus/db/backends/redis/tests.py
@@ -32,7 +32,7 @@ def test_provides_identifier(self):
def test_pipelined_map(self):
redis = create_cluster({
'engine': 'nydus.db.backends.redis.Redis',
- 'router': 'nydus.db.routers.redis.PartitionRouter',
+ 'router': 'nydus.db.routers.PartitionRouter',
'hosts': {
0: {'db': 5},
1: {'db': 6},
@@ -59,12 +59,13 @@ def test_client_instantiates_with_kwargs(self, RedisClient):
def test_map_does_pipeline(self, RedisClient):
redis = create_cluster({
'engine': 'nydus.db.backends.redis.Redis',
- 'router': 'nydus.db.routers.redis.PartitionRouter',
+ 'router': 'nydus.db.routers.PartitionRouter',
'hosts': {
0: {'db': 0},
1: {'db': 1},
}
})
+
with redis.map() as conn:
conn.set('a', 0)
conn.set('d', 1)
@@ -86,7 +87,7 @@ def test_map_does_pipeline(self, RedisClient):
def test_map_only_runs_on_required_nodes(self, RedisClient):
redis = create_cluster({
'engine': 'nydus.db.backends.redis.Redis',
- 'router': 'nydus.db.routers.redis.PartitionRouter',
+ 'router': 'nydus.db.routers.PartitionRouter',
'hosts': {
0: {'db': 0},
1: {'db': 1},
View
4 tests/nydus/db/backends/thoonk/tests.py
@@ -62,7 +62,7 @@ def test_job_with_ConsistentHashingRouter(self):
self.assertTrue(jid_found)
def test_job_with_RoundRobinRouter(self):
- pubsub = self.get_cluster('nydus.db.routers.redis.RoundRobinRouter')
+ pubsub = self.get_cluster('nydus.db.routers.RoundRobinRouter')
jobs = {}
size = 20
@@ -89,7 +89,7 @@ def test_job_with_RoundRobinRouter(self):
self.assertEqual(len(jobs), 5)
for x in range(len(pubsub)):
- ps = pubsub.get_conn()
+ ps = pubsub.get_conn('testjob')
jps = ps.job('testjob')
self.assertEqual(jps.get_ids(), [])
View
27 tests/nydus/db/connections/tests.py
@@ -26,8 +26,7 @@ def foo(self, *args, **kwargs):
class DummyRouter(BaseRouter):
- def get_db(self, cluster, func, key=None, *args, **kwargs):
- # Assume first argument is a key
+ def get_dbs(self, cluster, attr, key=None, *args, **kwargs):
if key == 'foo':
return [1]
return [0]
@@ -70,7 +69,7 @@ def test_with_router(self):
c2 = DummyConnection(num=1, resp='bar')
# test dummy router
- r = DummyRouter()
+ r = DummyRouter
p = Cluster(
hosts={0: c, 1: c2},
router=r,
@@ -83,14 +82,14 @@ def test_with_router(self):
hosts={0: c, 1: c2},
)
self.assertEquals(p.foo(), ['foo', 'bar'])
- self.assertEquals(p.foo('foo'), ['foo', 'bar'])
+ self.assertEquals(p.foo('foo'), 'foo')
def test_get_conn(self):
c = DummyConnection(alias='foo', num=0, resp='foo')
c2 = DummyConnection(alias='foo', num=1, resp='bar')
# test dummy router
- r = DummyRouter()
+ r = DummyRouter
p = Cluster(
hosts={0: c, 1: c2},
router=r,
@@ -103,14 +102,14 @@ def test_get_conn(self):
hosts={0: c, 1: c2},
)
self.assertEquals(p.get_conn(), [c, c2])
- self.assertEquals(p.get_conn('foo'), [c, c2])
+ self.assertEquals(p.get_conn('foo'), c)
def test_map(self):
c = DummyConnection(num=0, resp='foo')
c2 = DummyConnection(num=1, resp='bar')
# test dummy router
- r = DummyRouter()
+ r = DummyRouter
p = Cluster(
hosts={0: c, 1: c2},
router=r,
@@ -135,7 +134,7 @@ def test_map(self):
self.assertEquals(bar, None)
self.assertEquals(foo, ['foo', 'bar'])
- self.assertEquals(bar, ['foo', 'bar'])
+ self.assertEquals(bar, 'foo')
class FlakeyConnection(DummyConnection):
@@ -158,7 +157,7 @@ def __init__(self):
self.key_args_seen = []
super(RetryableRouter, self).__init__()
- def get_db(self, cluster, func, key=None, *args, **kwargs):
+ def get_dbs(self, cluster, func, key=None, *args, **kwargs):
self.kwargs_seen.append(kwargs)
self.key_args_seen.append(key)
return [0]
@@ -171,7 +170,7 @@ def __init__(self):
self.returned = False
super(InconsistentRouter, self).__init__()
- def get_db(self, cluster, func, key=None, *args, **kwargs):
+ def get_dbs(self, cluster, func, key=None, *args, **kwargs):
if self.returned:
return range(5)
else:
@@ -210,12 +209,8 @@ def test_retry_router_when_receives_error(self):
def test_protection_from_infinate_loops(self):
cluster = self.build_cluster(connection=ScumbagConnection)
- self.assertRaises(Exception, cluster.foo)
-
- def test_retryable_router_returning_multiple_dbs_raises_ecxeption(self):
- cluster = self.build_cluster(router=InconsistentRouter, connection=ScumbagConnection)
- self.assertRaisesRegexp(Exception, 'returned multiple DBs',
- cluster.foo)
+ with self.assertRaises(Exception):
+ cluster.foo()
class EventualCommandTest(BaseTest):
View
254 tests/nydus/db/routers/tests.py
@@ -1,10 +1,17 @@
from __future__ import absolute_import
+import time
+
+from collections import Iterable
+from inspect import getargspec
+
+from mock import patch
+
from tests import BaseTest
from nydus.db.base import Cluster
from nydus.db.backends import BaseConnection
+from nydus.db.routers import BaseRouter, RoundRobinRouter, PartitionRouter
from nydus.db.routers.redis import ConsistentHashingRouter
-from nydus.db.routers.redis import RoundRobinRouter
class DummyConnection(BaseConnection):
@@ -18,119 +25,234 @@ def identifier(self):
return "%s:%s" % (self.host, self.i)
-class RoundRobinRouterTest(BaseTest):
+class BaseRouterTest(BaseTest):
+ Router = BaseRouter
+ class TestException(Exception): pass
def setUp(self):
- self.router = RoundRobinRouter()
+ self.router = self.Router()
self.hosts = dict((i, DummyConnection(i)) for i in range(5))
- self.cluster = Cluster(router=self.router, hosts=self.hosts)
+ self.cluster = Cluster(router=self.Router, hosts=self.hosts)
- def get_db(self, *args, **kwargs):
+ def get_dbs(self, *args, **kwargs):
kwargs.setdefault('cluster', self.cluster)
- return self.router.get_db(*args, **kwargs)
+ return self.router.get_dbs(*args, **kwargs)
+ def test_not_ready(self):
+ self.assertTrue(not self.router._ready)
-class ConsistentHashingRouterTest(BaseTest):
+ def test_get_dbs_iterable(self):
+ db_nums = self.get_dbs(attr='test', key='foo')
+ self.assertIsInstance(db_nums, Iterable)
- def setUp(self):
- self.router = ConsistentHashingRouter()
- self.hosts = dict((i, DummyConnection(i)) for i in range(5))
- self.cluster = Cluster(router=self.router, hosts=self.hosts)
+ def test_get_dbs_unabletosetuproute(self):
+ with patch.object(self.router, '_setup_router', return_value=False):
+ with self.assertRaises(BaseRouter.UnableToSetupRouter):
+ self.get_dbs(attr='test', key='foo')
+
+ def test_setup_router_returns_true(self):
+ self.assertTrue(self.router.setup_router(self.cluster))
- def get_db(self, **kwargs):
- kwargs.setdefault('cluster', self.cluster)
- return self.router.get_db(func='info', **kwargs)
+ def test_offers_router_interface(self):
+ self.assertTrue(callable(self.router.get_dbs))
+ dbargs, _, _, dbdefaults = getargspec(self.router.get_dbs)
+ self.assertTrue(set(dbargs) >= set(['self', 'cluster', 'attr', 'key']))
+ self.assertIsNone(dbdefaults[0])
+
+ self.assertTrue(callable(self.router.setup_router))
+ setupargs, _, _, setupdefaults = getargspec(self.router.setup_router)
+ self.assertTrue(set(setupargs) >= set(['self', 'cluster']))
+ self.assertIsNone(setupdefaults)
+
+ def test_returns_whole_cluster_without_key(self):
+ self.assertEquals(self.hosts.keys(), self.get_dbs(attr='test'))
+
+ def test_returns_sequence_with_one_item_when_given_key(self):
+ self.assertEqual(len(self.get_dbs(attr='test', key='foo')), 1)
+
+ def test_get_dbs_handles_exception(self):
+ with patch.object(self.router, '_route') as _route:
+ with patch.object(self.router, '_handle_exception') as _handle_exception:
+ _route.side_effect = self.TestException()
+
+ self.get_dbs(attr='test', key='foo')
+ self.assertTrue(_handle_exception.called)
-class RoundRobinTest(BaseTest):
+
+class BaseBaseRouterTest(BaseRouterTest):
+ def test__setup_router_returns_true(self):
+ self.assertTrue(self.router._setup_router(self.cluster))
+
+ def test__pre_routing_returns_key(self):
+ key = 'foo'
+
+ self.assertEqual(key, self.router._pre_routing(self.cluster, 'foo', key))
+
+ def test__route_returns_first_db_num(self):
+ self.assertEqual(self.cluster.hosts.keys()[0], self.router._route(self.cluster, 'test', 'foo')[0])
+
+ def test__post_routing_returns_db_nums(self):
+ db_nums = self.hosts.keys()
+
+ self.assertEqual(db_nums, self.router._post_routing(self.cluster, 'test', 'foo', db_nums))
+
+ def test__handle_exception_raises_same_exception(self):
+ e = self.TestException()
+
+ with self.assertRaises(self.TestException):
+ self.router._handle_exception(e)
+
+
+class BaseRoundRobinRouterTest(BaseRouterTest):
+ Router = RoundRobinRouter
def setUp(self):
- self.router = RoundRobinRouter()
- self.hosts = dict((i, DummyConnection(i)) for i in range(5))
- self.cluster = Cluster(router=self.router, hosts=self.hosts)
+ super(BaseRoundRobinRouterTest, self).setUp()
+ assert self.router._setup_router(self.cluster)
- def get_db(self, *args, **kwargs):
- kwargs.setdefault('cluster', self.cluster)
- return self.router.get_db(*args, **kwargs)
+ def test_ensure_db_num(self):
+ db_num = 0
+ s_db_num = str(db_num)
- def test_cluster_of_zero_returns_zero(self):
- self.cluster.hosts = dict()
- self.assertEquals([], self.get_db())
+ self.assertEqual(self.router.ensure_db_num(db_num), db_num)
+ self.assertEqual(self.router.ensure_db_num(s_db_num), db_num)
- def test_cluster_of_one_returns_one(self):
- self.cluster.hosts = {0: DummyConnection('foo')}
- self.assertEquals([0], self.get_db())
+ def test_esnure_db_num_raises(self):
+ with self.assertRaises(RoundRobinRouter.InvalidDBNum):
+ self.router.ensure_db_num('a')
- def test_multi_node_cluster_returns_correct_host(self):
- self.cluster.hosts = {0: DummyConnection('foo'), 1: DummyConnection('bar')}
- self.assertEquals([[0], [1], [0], [1]], [self.get_db(), self.get_db(), self.get_db(), self.get_db(), ])
+ def test_flush_down_connections(self):
+ self.router._get_db_attempts = 9001
+ self._down_connections = {0: time.time()}
+ self.router.flush_down_connections()
-class InterfaceTest(ConsistentHashingRouterTest):
+ self.assertEqual(self.router._get_db_attempts, 0)
+ self.assertEqual(self.router._down_connections, {})
- def test_offers_router_interface(self):
- self.assertTrue(callable(self.router.get_db))
+ def test_mark_connection_down(self):
+ db_num = 0
- def test_get_db_returns_itereable(self):
- iter(self.get_db())
+ self.router.mark_connection_down(db_num)
- def test_returns_whole_cluster_without_key(self):
- self.assertEquals(range(5), self.get_db())
+ self.assertAlmostEqual(self.router._down_connections[db_num], time.time(), delta=10)
- def test_returns_sequence_with_one_item_when_given_key(self):
- self.assert_(len(self.get_db(key='foo')) is 1)
+ def test_mark_connection_up(self):
+ db_num = 0
+
+ self.router.mark_connection_down(db_num)
+
+ self.assertIn(db_num, self.router._down_connections)
+
+ self.router.mark_connection_up(db_num)
+
+ self.assertNotIn(db_num, self.router._down_connections)
+
+ def test__pre_routing_updates__get_db_attempts(self):
+ self.router._pre_routing(self.cluster, 'test', 'foo')
+
+ self.assertEqual(self.router._get_db_attempts, 1)
+
+ @patch('nydus.db.routers.RoundRobinRouter.flush_down_connections')
+ def test__pre_routing_flush_down_connections(self, _flush_down_connections):
+ self.router._get_db_attempts = RoundRobinRouter.attempt_reconnect_threshold + 1
+
+ self.router._pre_routing(self.cluster, 'test', 'foo')
+
+ self.assertTrue(_flush_down_connections.called)
+
+ @patch('nydus.db.routers.RoundRobinRouter.mark_connection_down')
+ def test__pre_routing_retry_for(self, _mark_connection_down):
+ db_num = 0
+
+ self.router._pre_routing(self.cluster, 'test', 'foo', retry_for=db_num)
+
+ _mark_connection_down.assert_called_with(db_num)
+ @patch('nydus.db.routers.RoundRobinRouter.mark_connection_up')
+ def test__post_routing_mark_connection_up(self, _mark_connection_up):
+ db_nums = [0]
-class HashingTest(ConsistentHashingRouterTest):
+ self.assertEqual(self.router._post_routing(self.cluster, 'test', 'foo', db_nums), db_nums)
+ _mark_connection_up.assert_called_with(db_nums[0])
- def get_db(self, **kwargs):
- kwargs['key'] = 'foo'
- return super(HashingTest, self).get_db(**kwargs)
- def test_cluster_of_zero_returns_zero(self):
- self.cluster.hosts = dict()
- self.assertEquals([], self.get_db())
+class RoundRobinRouterTest(BaseRoundRobinRouterTest):
+ def test__setup_router(self):
+ self.assertTrue(self.router._setup_router(self.cluster))
+ self.assertIsInstance(self.router._hosts_cycler, Iterable)
- def test_cluster_of_one_returns_one(self):
- self.cluster.hosts = dict(only_key=DummyConnection('foo'))
- self.assertEquals(['only_key'], self.get_db())
+ def test__route_cycles_through_keys(self):
+ db_nums = self.hosts.keys() * 2
+ results = [self.router._route(self.cluster, 'test', 'foo')[0] for _ in db_nums]
- def test_multi_node_cluster_returns_correct_host(self):
- self.assertEquals([2], self.get_db())
+ self.assertEqual(results, db_nums)
+ def test__route_retry(self):
+ self.router.retry_timeout = 0
-class RetryableTest(HashingTest):
+ db_num = 0
- def test_attempt_reconnect_threshold_is_set(self):
- self.assertEqual(self.router.attempt_reconnect_threshold, 100000)
+ self.router.mark_connection_down(db_num)
+
+ db_nums = self.router._route(self.cluster, 'test', 'foo')
+
+ self.assertEqual(db_nums, [db_num])
+
+ def test__route_skip_down(self):
+ db_num = 0
+
+ self.router.mark_connection_down(db_num)
+
+ db_nums = self.router._route(self.cluster, 'test', 'foo')
+
+ self.assertNotEqual(db_nums, [db_num])
+ self.assertEqual(db_nums, [db_num+1])
+
+ def test__route_hostlistexhausted(self):
+ [self.router.mark_connection_down(db_num) for db_num in self.hosts.keys()]
+
+ with self.assertRaises(RoundRobinRouter.HostListExhausted):
+ self.router._route(self.cluster, 'test', 'foo')
+
+
+
+class ConsistentHashingRouterTest(BaseRoundRobinRouterTest):
+ Router = ConsistentHashingRouter
+
+ def get_dbs(self, *args, **kwargs):
+ kwargs['attr'] = 'test'
+ return super(ConsistentHashingRouterTest, self).get_dbs(*args, **kwargs)
def test_retry_gives_next_host_if_primary_is_offline(self):
- self.assertEquals([2], self.get_db())
- self.assertEquals([4], self.get_db(retry_for=2))
+ self.assertEquals([2], self.get_dbs(key='foo'))
+ self.assertEquals([4], self.get_dbs(key='foo', retry_for=2))
def test_retry_host_change_is_sticky(self):
- self.assertEquals([2], self.get_db())
- self.assertEquals([4], self.get_db(retry_for=2))
+ self.assertEquals([2], self.get_dbs(key='foo'))
+ self.assertEquals([4], self.get_dbs(key='foo', retry_for=2))
- self.assertEquals([4], self.get_db())
+ self.assertEquals([4], self.get_dbs(key='foo'))
def test_adds_back_down_host_once_attempt_reconnect_threshold_is_passed(self):
ConsistentHashingRouter.attempt_reconnect_threshold = 3
- self.assertEquals([2], self.get_db())
- self.assertEquals([4], self.get_db(retry_for=2))
- self.assertEquals([4], self.get_db())
+ self.assertEquals([2], self.get_dbs(key='foo'))
+ self.assertEquals([4], self.get_dbs(key='foo', retry_for=2))
+ self.assertEquals([4], self.get_dbs(key='foo'))
# Router should add host 1 back to the pool now
- self.assertEquals([2], self.get_db())
+ self.assertEquals([2], self.get_dbs(key='foo'))
ConsistentHashingRouter.attempt_reconnect_threshold = 100000
def test_raises_host_list_exhaused_if_no_host_can_be_found(self):
# Kill the first 4
- [self.get_db(retry_for=i) for i in range(4)]
+ [self.get_dbs(retry_for=i) for i in range(4)]
# And the 5th should raise an error
self.assertRaises(
- ConsistentHashingRouter.HostListExhaused,
- self.get_db, **dict(retry_for=4))
+ ConsistentHashingRouter.HostListExhausted,
+ self.get_dbs, **dict(key='foo', retry_for=4))
+
Please sign in to comment.
Something went wrong with that request. Please try again.