Skip to content

Commit

Permalink
Handle Redis connection errors in result consumer (#5921)
Browse files Browse the repository at this point in the history
* Handle Redis connection errors in result consumer

* Closes #5919.

* Use context manager for Redis consumer reconnect

* Log error when result backend reconnection fails
  • Loading branch information
michamos committed Feb 28, 2020
1 parent fe0c33c commit 6ccdc7b
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 8 deletions.
50 changes: 43 additions & 7 deletions celery/backends/redis.py
Expand Up @@ -3,6 +3,7 @@
from __future__ import absolute_import, unicode_literals

import time
from contextlib import contextmanager
from functools import partial
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED

Expand Down Expand Up @@ -78,6 +79,11 @@

E_LOST = 'Connection to Redis lost: Retry (%s/%s) %s.'

E_RETRY_LIMIT_EXCEEDED = """
Retry limit exceeded while trying to reconnect to the Celery redis result \
store backend. The Celery application must be restarted.
"""

logger = get_logger(__name__)


Expand All @@ -88,6 +94,8 @@ def __init__(self, *args, **kwargs):
super(ResultConsumer, self).__init__(*args, **kwargs)
self._get_key_for_task = self.backend.get_key_for_task
self._decode_result = self.backend.decode_result
self._ensure = self.backend.ensure
self._connection_errors = self.backend.connection_errors
self.subscribed_to = set()

def on_after_fork(self):
Expand All @@ -99,6 +107,31 @@ def on_after_fork(self):
logger.warning(text_t(e))
super(ResultConsumer, self).on_after_fork()

def _reconnect_pubsub(self):
    """Rebuild the pubsub object after the Redis connection was lost.

    Drops the current pubsub, resets the client's connection pool, replays
    the stored meta of every subscribed task through ``on_state_change``
    and then creates a fresh pubsub subscribed to the keys that are still
    in ``subscribed_to``.

    Both Redis calls that take the key set are skipped when the set is
    empty, because ``MGET`` and ``SUBSCRIBE`` without arguments are
    rejected by the server and would make the reconnect itself fail.
    """
    self._pubsub = None
    self.backend.client.connection_pool.reset()
    # task state might have changed when the connection was down so we
    # retrieve meta for all subscribed tasks before going into pubsub mode
    if self.subscribed_to:
        for meta in self.backend.client.mget(self.subscribed_to):
            if meta:
                self.on_state_change(self._decode_result(meta), None)
    self._pubsub = self.backend.client.pubsub(
        ignore_subscribe_messages=True,
    )
    # Re-check: the set may have shrunk while state changes were replayed
    # (cancel_for discards keys) -- NOTE(review): confirm with callers.
    if self.subscribed_to:
        self._pubsub.subscribe(*self.subscribed_to)

@contextmanager
def reconnect_on_error(self):
    """Context manager that rebuilds the pubsub after a connection error.

    If the ``with`` body raises one of ``self._connection_errors``,
    ``self._reconnect_pubsub`` is invoked through ``self._ensure``. When
    that call returns normally the original error is swallowed; the
    interrupted operation is not re-run by this context manager. When it
    fails with a connection error as well, ``E_RETRY_LIMIT_EXCEEDED`` is
    logged at critical level and that error is re-raised.
    """
    try:
        yield
    except self._connection_errors:
        try:
            self._ensure(self._reconnect_pubsub, ())
        except self._connection_errors:
            # reconnecting failed too: report it and let the error propagate
            logger.critical(E_RETRY_LIMIT_EXCEEDED)
            raise

def _maybe_cancel_ready_task(self, meta):
if meta['status'] in states.READY_STATES:
self.cancel_for(meta['task_id'])
Expand All @@ -124,9 +157,10 @@ def stop(self):

def drain_events(self, timeout=None):
if self._pubsub:
message = self._pubsub.get_message(timeout=timeout)
if message and message['type'] == 'message':
self.on_state_change(self._decode_result(message['data']), message)
with self.reconnect_on_error():
message = self._pubsub.get_message(timeout=timeout)
if message and message['type'] == 'message':
self.on_state_change(self._decode_result(message['data']), message)
elif timeout:
time.sleep(timeout)

Expand All @@ -139,13 +173,15 @@ def _consume_from(self, task_id):
key = self._get_key_for_task(task_id)
if key not in self.subscribed_to:
self.subscribed_to.add(key)
self._pubsub.subscribe(key)
with self.reconnect_on_error():
self._pubsub.subscribe(key)

def cancel_for(self, task_id):
key = self._get_key_for_task(task_id)
self.subscribed_to.discard(key)
if self._pubsub:
key = self._get_key_for_task(task_id)
self.subscribed_to.discard(key)
self._pubsub.unsubscribe(key)
with self.reconnect_on_error():
self._pubsub.unsubscribe(key)


class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin):
Expand Down
57 changes: 56 additions & 1 deletion t/unit/backends/test_redis.py
@@ -1,5 +1,6 @@
from __future__ import absolute_import, unicode_literals

import json
import random
import ssl
from contextlib import contextmanager
Expand All @@ -26,6 +27,10 @@ def on_first_call(*args, **kwargs):
mock.return_value, = retval


class ConnectionError(Exception):
    """Stand-in connection error used as a mock side effect in these tests."""


class Connection(object):
connected = True

Expand Down Expand Up @@ -55,9 +60,27 @@ def execute(self):
return [step(*a, **kw) for step, a, kw in self.steps]


class PubSub(mock.MockCallbacks):
    """Test double for a pubsub object that only tracks subscribed channels."""

    def __init__(self, ignore_subscribe_messages=False):
        # the flag is accepted but not used by this double
        self._subscribed_to = set()

    def close(self):
        """Forget every subscription by starting over with an empty set."""
        self._subscribed_to = set()

    def subscribe(self, *args):
        """Add the given channels to the subscription set."""
        self._subscribed_to |= set(args)

    def unsubscribe(self, *args):
        """Remove the given channels from the subscription set."""
        self._subscribed_to -= set(args)

    def get_message(self, timeout=None):
        """Return nothing: this double never delivers a message."""
        return None


class Redis(mock.MockCallbacks):
Connection = Connection
Pipeline = Pipeline
pubsub = PubSub

def __init__(self, host=None, port=None, db=None, password=None, **kw):
self.host = host
Expand All @@ -71,6 +94,9 @@ def __init__(self, host=None, port=None, db=None, password=None, **kw):
def get(self, key):
return self.keyspace.get(key)

def mget(self, keys):
    """Return the result of ``self.get`` for each key, in the given order."""
    values = []
    for key in keys:
        values.append(self.get(key))
    return values

def setex(self, key, expires, value):
self.set(key, value)
self.expire(key, expires)
Expand Down Expand Up @@ -144,7 +170,9 @@ class _RedisBackend(RedisBackend):
return _RedisBackend(app=self.app)

def get_consumer(self):
return self.get_backend().result_consumer
consumer = self.get_backend().result_consumer
consumer._connection_errors = (ConnectionError,)
return consumer

@patch('celery.backends.asynchronous.BaseResultConsumer.on_after_fork')
def test_on_after_fork(self, parent_method):
Expand Down Expand Up @@ -194,6 +222,33 @@ def test_drain_events_before_start(self):
# drain_events shouldn't crash when called before start
consumer.drain_events(0.001)

def test_consume_from_connection_error(self):
    """A connection error while subscribing leaves both keys subscribed."""
    expected_keys = {
        b'celery-task-meta-initial',
        b'celery-task-meta-some-task',
    }
    consumer = self.get_consumer()
    consumer.start('initial')
    # make the current pubsub's subscribe raise once, then return None
    consumer._pubsub.subscribe.side_effect = (ConnectionError(), None)

    consumer.consume_from('some-task')

    assert consumer._pubsub._subscribed_to == expected_keys

def test_cancel_for_connection_error(self):
    """A connection error while unsubscribing still drops the cancelled key."""
    expected_keys = {b'celery-task-meta-initial'}
    consumer = self.get_consumer()
    consumer.start('initial')
    # unsubscribe on the current pubsub raises a connection error
    consumer._pubsub.unsubscribe.side_effect = ConnectionError()

    consumer.consume_from('some-task')
    consumer.cancel_for('some-task')

    assert consumer._pubsub._subscribed_to == expected_keys

@patch('celery.backends.redis.ResultConsumer.cancel_for')
@patch('celery.backends.asynchronous.BaseResultConsumer.on_state_change')
def test_drain_events_connection_error(self, parent_on_state_change, cancel_for):
    """A connection error in drain_events replays stored state and resubscribes."""
    task_key = b'celery-task-meta-initial'
    meta = {'task_id': 'initial', 'status': states.SUCCESS}
    consumer = self.get_consumer()
    consumer.start('initial')
    # store the result so it can be fetched again after the reconnect
    consumer.backend.set(task_key, json.dumps(meta))
    consumer._pubsub.get_message.side_effect = ConnectionError()

    consumer.drain_events()

    parent_on_state_change.assert_called_with(meta, None)
    assert consumer._pubsub._subscribed_to == {task_key}


class test_RedisBackend:
def get_backend(self):
Expand Down

0 comments on commit 6ccdc7b

Please sign in to comment.