Permalink
Browse files

Merge pull request #5 from disqus/robust-roundrobin

Refactor of routers
  • Loading branch information...
2 parents b0edafc + de2c34c commit 825f52e15aa451636fca76dfdc7f244c5704b45b @dctrwatson dctrwatson committed Apr 10, 2012
View
@@ -25,24 +25,24 @@ def create_cluster(settings):
"""
# Pull in our client
if isinstance(settings['engine'], basestring):
- conn = import_string(settings['engine'])
+ Conn = import_string(settings['engine'])
else:
- conn = settings['engine']
+ Conn = settings['engine']
# Pull in our router
router = settings.get('router')
if isinstance(router, basestring):
- router = import_string(router)()
+ router = import_string(router)
elif router:
- router = router()
+ router = router
else:
- router = BaseRouter()
+ router = BaseRouter
# Build the connection cluster
return Cluster(
router=router,
hosts=dict(
- (conn_number, conn(num=conn_number, **host_settings))
+ (conn_number, Conn(num=conn_number, **host_settings))
for conn_number, host_settings
in settings['hosts'].iteritems()
),
@@ -53,10 +53,12 @@ class Cluster(object):
"""
Holds a cluster of connections.
"""
+ class MaxRetriesExceededError(Exception):
+ pass
- def __init__(self, hosts, router=None, max_connection_retries=20):
+ def __init__(self, hosts, router=BaseRouter, max_connection_retries=20):
self.hosts = hosts
- self.router = router
+ self.router = router()
self.max_connection_retries = max_connection_retries
def __len__(self):
@@ -76,22 +78,28 @@ def __iter__(self):
yield name
def _execute(self, attr, args, kwargs):
- db_nums = self._db_nums_for(*args, **kwargs)
+ connections = self._connections_for(attr, *args, **kwargs)
+
+ results = []
+ for conn in connections:
+ for retry in xrange(self.max_connection_retries):
+ try:
+ results.append(getattr(conn, attr)(*args, **kwargs))
+ except tuple(conn.retryable_exceptions), e:
+ if not self.router.retryable:
+ raise e
+ elif retry == self.max_connection_retries - 1:
+ raise self.MaxRetriesExceededError(e)
+ else:
+ conn = self._connections_for(attr, retry_for=conn.num, *args, **kwargs)[0]
+ else:
+ break
- if self.router and len(db_nums) is 1 and self.router.retryable:
- # The router supports retryable commands, so we want to run a
- # separate algorithm for how we get connections to run commands on
- # and then possibly retry
- return self._retryable_execute(db_nums, attr, *args, **kwargs)
+ # If we only had one db to query, we simply return that res
+ if len(results) == 1:
+ return results[0]
else:
- connections = self._connections_for(*args, **kwargs)
- results = [getattr(conn, attr)(*args, **kwargs) for conn in connections]
-
- # If we only had one db to query, we simply return that res
- if len(results) == 1:
- return results[0]
- else:
- return results
+ return results
def disconnect(self):
"""Disconnects all connections in cluster"""
@@ -106,7 +114,7 @@ def get_conn(self, *args, **kwargs):
during all steps of the process. An example of this would be
Redis pipelines.
"""
- connections = self._connections_for(*args, **kwargs)
+ connections = self._connections_for('get_conn', *args, **kwargs)
if len(connections) is 1:
return connections[0]
@@ -116,34 +124,8 @@ def get_conn(self, *args, **kwargs):
def map(self, workers=None):
return DistributedContextManager(self, workers)
- def _retryable_execute(self, db_nums, attr, *args, **kwargs):
- retries = 0
-
- while retries <= self.max_connection_retries:
- if len(db_nums) > 1:
- raise Exception('Retryable router returned multiple DBs')
- else:
- connection = self[db_nums[0]]
-
- try:
- return getattr(connection, attr)(*args, **kwargs)
- except tuple(connection.retryable_exceptions):
- # We had a failure, so get a new db_num and try again, noting
- # the DB number that just failed, so the backend can mark it as
- # down
- db_nums = self._db_nums_for(retry_for=db_nums[0], *args, **kwargs)
- retries += 1
- else:
- raise Exception('Maximum amount of connection retries exceeded')
-
- def _db_nums_for(self, *args, **kwargs):
- if self.router:
- return self.router.get_db(self, 'get_conn', *args, **kwargs)
- else:
- return range(len(self))
-
- def _connections_for(self, *args, **kwargs):
- return [self[n] for n in self._db_nums_for(*args, **kwargs)]
+ def _connections_for(self, attr, *args, **kwargs):
+ return [self[n] for n in self.router.get_dbs(self, attr, *args, **kwargs)]
class CallProxy(object):
@@ -325,9 +307,9 @@ def _execute(self):
command_map[cmd_ident] = command
if self._cluster.router:
- db_nums = self._cluster.router.get_db(self._cluster, command._attr, *command._args, **command._kwargs)
+ db_nums = self._cluster.router.get_dbs(self._cluster, command._attr, *command._args, **command._kwargs)
else:
- db_nums = range(len(self._cluster))
+ db_nums = self._cluster.keys()
# The number of commands is based on the total number of executable commands
num_commands += len(db_nums)
@@ -6,4 +6,4 @@
:license: Apache License 2.0, see LICENSE for more details.
"""
-from .base import BaseRouter
+from .base import BaseRouter, RoundRobinRouter, PartitionRouter
View
@@ -5,9 +5,13 @@
:copyright: (c) 2011 DISQUS.
:license: Apache License 2.0, see LICENSE for more details.
"""
+import time
+from binascii import crc32
+from collections import defaultdict
+from itertools import cycle
-__all__ = ('BaseRouter',)
+__all__ = ('BaseRouter', 'RoundRobinRouter', 'PartitionRouter')
class BaseRouter(object):
@@ -16,9 +20,163 @@ class BaseRouter(object):
"""
retryable = False
- def get_db(self, cluster, func, *args, **kwargs):
+ class UnableToSetupRouter(Exception):
+ pass
+
+ def __init__(self, *args, **kwargs):
+ self._ready = False
+
+ def get_dbs(self, cluster, attr, key=None, *args, **kwargs):
+ """
+ Perform setup and routing
+ Always return an iterable
+ Do not overload this method
+ """
+ if not self._ready:
+ if not self.setup_router(cluster, *args, **kwargs):
+ raise self.UnableToSetupRouter()
+
+ key = self._pre_routing(cluster, attr, key, *args, **kwargs)
+
+ if not key:
+ return cluster.hosts.keys()
+
+ try:
+ db_nums = self._route(cluster, attr, key, *args, **kwargs)
+ except Exception, e:
+ self._handle_exception(e)
+ db_nums = []
+
+ return self._post_routing(cluster, attr, key, db_nums, *args, **kwargs)
+
+ # Backwards compatibilty
+ get_db = get_dbs
+
+ def setup_router(self, cluster, *args, **kwargs):
+ """
+ Call method to perform any setup
+ """
+ self._ready = self._setup_router(cluster, *args, **kwargs)
+
+ return self._ready
+
+ def _setup_router(self, cluster, *args, **kwargs):
"""
- Return the first entry in the cluster
- The return value must be iterable
+ Perform any initialization for the router
+ Returns False if setup could not be completed
"""
- return range(len(cluster))
+ return True
+
+ def _pre_routing(self, cluster, attr, key, *args, **kwargs):
+ """
+ Perform any prerouting with this method and return the key
+ """
+ return key
+
+ def _route(self, cluster, attr, key, *args, **kwargs):
+ """
+ Perform routing and return db_nums
+ """
+ return [cluster.hosts.keys()[0]]
+
+ def _post_routing(self, cluster, attr, key, db_nums, *args, **kwargs):
+ """
+ Perform any postrouting actions and return db_nums
+ """
+ return db_nums
+
+ def _handle_exception(self, e):
+ """
+ Handle/transform exceptions and return it
+ """
+ raise e
+
+
+class RoundRobinRouter(BaseRouter):
+ """
+ Basic retry router that performs round robin
+ """
+
+ # Raised if all hosts in the hash have been marked as down
+ class HostListExhausted(Exception):
+ pass
+
+ class InvalidDBNum(Exception):
+ pass
+
+ # If this router can be retried on if a particular db index it gave out did
+ # not work
+ retryable = True
+
+ # How many requests to serve in a situation when a host is down before
+ # the down hosts are retried
+ attempt_reconnect_threshold = 100000
+
+ # Retry a down connection after this timeout
+ retry_timeout = 30
+
+ def __init__(self, *args, **kwargs):
+ self._get_db_attempts = 0
+ self._down_connections = {}
+
+ super(RoundRobinRouter,self).__init__(*args, **kwargs)
+
+ @classmethod
+ def ensure_db_num(cls, db_num):
+ try:
+ return int(db_num)
+ except ValueError:
+ raise cls.InvalidDBNum()
+
+ def flush_down_connections(self):
+ self._get_db_attempts = 0
+ self._down_connections = {}
+
+ def mark_connection_down(self, db_num):
+ db_num = self.ensure_db_num(db_num)
+ self._down_connections[db_num] = time.time()
+
+ def mark_connection_up(self, db_num):
+ db_num = self.ensure_db_num(db_num)
+ self._down_connections.pop(db_num, None)
+
+ def _setup_router(self, cluster, *args, **kwargs):
+ self._hosts_cycler = cycle(cluster.hosts.keys())
+
+ return True
+
+ def _pre_routing(self, cluster, attr, key, *args, **kwargs):
+ self._get_db_attempts += 1
+
+ if self._get_db_attempts > self.attempt_reconnect_threshold:
+ self.flush_down_connections()
+
+ if 'retry_for' in kwargs:
+ self.mark_connection_down(kwargs['retry_for'])
+
+ return key
+
+ def _route(self, cluster, attr, key, *args, **kwargs):
+ now = time.time()
+
+ for i in xrange(len(cluster)):
+ db_num = self._hosts_cycler.next()
+
+ marked_down_at = self._down_connections.get(db_num, False)
+
+ if not marked_down_at or (marked_down_at + self.retry_timeout <= now):
+ return [db_num]
+ else:
+ raise self.HostListExhausted()
+
+ def _post_routing(self, cluster, attr, key, db_nums, *args, **kwargs):
+ if db_nums:
+ self.mark_connection_up(db_nums[0])
+
+ return db_nums
+
+
+class PartitionRouter(BaseRouter):
+ def _route(self, cluster, attr, key, *args, **kwargs):
+ return [crc32(str(key)) % len(cluster)]
+
Oops, something went wrong.

0 comments on commit 825f52e

Please sign in to comment.