@@ -0,0 +1,299 @@
"""Async I/O backend support utilities."""
from __future__ import absolute_import, unicode_literals

import socket
import threading

from collections import deque
from time import sleep
from weakref import WeakKeyDictionary

from kombu.utils.compat import detect_environment
from kombu.utils.objects import cached_property

from celery import states
from celery.exceptions import TimeoutError
from celery.five import Empty, monotonic
from celery.utils.threads import THREAD_TIMEOUT_MAX

__all__ = (
    'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
    'register_drainer',
)

drainers = {}


def register_drainer(name):
    """Decorator used to register a new result drainer type."""
    def _inner(cls):
        drainers[name] = cls
        return cls
    return _inner


@register_drainer('default')
class Drainer(object):
    """Result draining service."""

    def __init__(self, result_consumer):
        self.result_consumer = result_consumer

    def start(self):
        pass

    def stop(self):
        pass

    def drain_events_until(self, p, timeout=None, on_interval=None,
                           wait=None):
        wait = wait or self.result_consumer.drain_events
        time_start = monotonic()

        while 1:
            # Total time spent may exceed a single call to wait()
            if timeout and monotonic() - time_start >= timeout:
                raise socket.timeout()
            try:
                yield self.wait_for(p, wait, timeout=1)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if p.ready:  # got event on the wanted channel.
                break

    def wait_for(self, p, wait, timeout=None):
        wait(timeout=timeout)


class greenletDrainer(Drainer):
    spawn = None
    _g = None

    def __init__(self, *args, **kwargs):
        super(greenletDrainer, self).__init__(*args, **kwargs)
        self._started = threading.Event()
        self._stopped = threading.Event()
        self._shutdown = threading.Event()

    def run(self):
        self._started.set()
        while not self._stopped.is_set():
            try:
                self.result_consumer.drain_events(timeout=1)
            except socket.timeout:
                pass
        self._shutdown.set()

    def start(self):
        if not self._started.is_set():
            self._g = self.spawn(self.run)
            self._started.wait()

    def stop(self):
        self._stopped.set()
        self._shutdown.wait(THREAD_TIMEOUT_MAX)

    def wait_for(self, p, wait, timeout=None):
        self.start()
        if not p.ready:
            sleep(0)


@register_drainer('eventlet')
class eventletDrainer(greenletDrainer):

    @cached_property
    def spawn(self):
        from eventlet import spawn
        return spawn


@register_drainer('gevent')
class geventDrainer(greenletDrainer):

    @cached_property
    def spawn(self):
        from gevent import spawn
        return spawn
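
# ``BaseResultConsumer`` below picks its drainer with
# ``drainers[detect_environment()]``, so a third-party event loop can plug
# in by registering under the name kombu's ``detect_environment()`` would
# return for it.  A minimal sketch -- the 'myloop' name and the class body
# are hypothetical, not part of this change:

@register_drainer('myloop')
class MyLoopDrainer(Drainer):
    """Hypothetical drainer for an imaginary 'myloop' environment."""

    def wait_for(self, p, wait, timeout=None):
        # hand control back to the event loop instead of blocking here.
        wait(timeout=timeout)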
class AsyncBackendMixin(object):
    """Mixin for backends that enables the async API."""

    def _collect_into(self, result, bucket):
        self.result_consumer.buckets[result] = bucket

    def iter_native(self, result, no_ack=True, **kwargs):
        self._ensure_not_eager()

        results = result.results
        if not results:
            # raising StopIteration inside a generator is an error under
            # PEP 479, so end the generator with a plain return instead.
            return

        # we tell the result consumer to put consumed results
        # into these buckets.
        bucket = deque()
        for node in results:
            if node._cache:
                bucket.append(node)
            else:
                self._collect_into(node, bucket)

        for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs):
            while bucket:
                node = bucket.popleft()
                yield node.id, node._cache
        while bucket:
            node = bucket.popleft()
            yield node.id, node._cache
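
# ``iter_native`` streams ``(task_id, meta)`` pairs as results arrive,
# instead of fetching each member of a group one by one.  A hedged usage
# sketch (``group_result`` is assumed to be a ``GroupResult`` whose backend
# mixes in ``AsyncBackendMixin``):

def collect_states(group_result):
    """Map every task id in a group to its final state."""
    backend = group_result.backend
    return {
        task_id: meta['status']
        for task_id, meta in backend.iter_native(group_result)
    }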
    def add_pending_result(self, result, weak=False, start_drainer=True):
        if start_drainer:
            self.result_consumer.drainer.start()
        try:
            self._maybe_resolve_from_buffer(result)
        except Empty:
            self._add_pending_result(result.id, result, weak=weak)
        return result

    def _maybe_resolve_from_buffer(self, result):
        result._maybe_set_cache(self._pending_messages.take(result.id))

    def _add_pending_result(self, task_id, result, weak=False):
        concrete, weak_ = self._pending_results
        if task_id not in weak_ and result.id not in concrete:
            (weak_ if weak else concrete)[task_id] = result
            self.result_consumer.consume_from(task_id)

    def add_pending_results(self, results, weak=False):
        self.result_consumer.drainer.start()
        return [self.add_pending_result(result, weak=weak,
                                        start_drainer=False)
                for result in results]

    def remove_pending_result(self, result):
        self._remove_pending_result(result.id)
        self.on_result_fulfilled(result)
        return result

    def _remove_pending_result(self, task_id):
        for map in self._pending_results:
            map.pop(task_id, None)

    def on_result_fulfilled(self, result):
        self.result_consumer.cancel_for(result.id)

    def wait_for_pending(self, result,
                         callback=None, propagate=True, **kwargs):
        self._ensure_not_eager()
        for _ in self._wait_for_pending(result, **kwargs):
            pass
        return result.maybe_throw(callback=callback, propagate=propagate)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        return self.result_consumer._wait_for_pending(
            result, timeout=timeout,
            on_interval=on_interval, on_message=on_message,
        )

    @property
    def is_async(self):
        return True
class BaseResultConsumer(object):
    """Manager responsible for consuming result messages."""

    def __init__(self, backend, app, accept,
                 pending_results, pending_messages):
        self.backend = backend
        self.app = app
        self.accept = accept
        self._pending_results = pending_results
        self._pending_messages = pending_messages
        self.on_message = None
        self.buckets = WeakKeyDictionary()
        self.drainer = drainers[detect_environment()](self)

    def start(self, initial_task_id, **kwargs):
        raise NotImplementedError()

    def stop(self):
        pass

    def drain_events(self, timeout=None):
        raise NotImplementedError()

    def consume_from(self, task_id):
        raise NotImplementedError()

    def cancel_for(self, task_id):
        raise NotImplementedError()

    def _after_fork(self):
        self.buckets.clear()
        self.buckets = WeakKeyDictionary()
        self.on_message = None
        self.on_after_fork()

    def on_after_fork(self):
        pass

    def drain_events_until(self, p, timeout=None, on_interval=None):
        return self.drainer.drain_events_until(
            p, timeout=timeout, on_interval=on_interval)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        self.on_wait_for_pending(result, timeout=timeout, **kwargs)
        prev_on_m, self.on_message = self.on_message, on_message
        try:
            for _ in self.drain_events_until(
                    result.on_ready, timeout=timeout,
                    on_interval=on_interval):
                yield
                sleep(0)
        except socket.timeout:
            raise TimeoutError('The operation timed out.')
        finally:
            self.on_message = prev_on_m

    def on_wait_for_pending(self, result, timeout=None, **kwargs):
        pass

    def on_out_of_band_result(self, message):
        self.on_state_change(message.payload, message)

    def _get_pending_result(self, task_id):
        for mapping in self._pending_results:
            try:
                return mapping[task_id]
            except KeyError:
                pass
        raise KeyError(task_id)

    def on_state_change(self, meta, message):
        if self.on_message:
            self.on_message(meta)
        if meta['status'] in states.READY_STATES:
            task_id = meta['task_id']
            try:
                result = self._get_pending_result(task_id)
            except KeyError:
                # send to buffer in case we received this result
                # before it was added to _pending_results.
                self._pending_messages.put(task_id, meta)
            else:
                result._maybe_set_cache(meta)
                buckets = self.buckets
                try:
                    # remove bucket for this result, since it's fulfilled
                    bucket = buckets.pop(result)
                except KeyError:
                    pass
                else:
                    # send to waiter via bucket
                    bucket.append(result)
        sleep(0)
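
# A concrete backend enables the async API by mixing ``AsyncBackendMixin``
# into a regular backend and supplying a result consumer; the RPC backend in
# the last hunk of this diff does exactly that.  A minimal sketch -- the
# class name is hypothetical, and a real backend would use a concrete
# ``BaseResultConsumer`` subclass rather than the abstract base:

from celery.backends.base import Backend


class SketchAsyncBackend(Backend, AsyncBackendMixin):
    ResultConsumer = BaseResultConsumer  # override with a real consumer

    def __init__(self, app, **kwargs):
        super(SketchAsyncBackend, self).__init__(app, **kwargs)
        self.result_consumer = self.ResultConsumer(
            self, self.app, self.accept,
            self._pending_results, self._pending_messages,
        )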
| @@ -1,193 +1,236 @@ | ||
| # -* coding: utf-8 -*- | ||
| """ | ||
| celery.backends.cassandra | ||
| ~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
| Apache Cassandra result store backend. | ||
| """ | ||
| from __future__ import absolute_import | ||
| try: # pragma: no cover | ||
| import pycassa | ||
| from thrift import Thrift | ||
| C = pycassa.cassandra.ttypes | ||
| except ImportError: # pragma: no cover | ||
| pycassa = None # noqa | ||
| """Apache Cassandra result store backend using the DataStax driver.""" | ||
| from __future__ import absolute_import, unicode_literals | ||
| import socket | ||
| import time | ||
| import sys | ||
| from celery import states | ||
| from celery.exceptions import ImproperlyConfigured | ||
| from celery.five import monotonic | ||
| from celery.utils.log import get_logger | ||
| from celery.utils.timeutils import maybe_timedelta | ||
| from .base import BaseBackend | ||
| __all__ = ['CassandraBackend'] | ||
| try: # pragma: no cover | ||
| import cassandra | ||
| import cassandra.auth | ||
| import cassandra.cluster | ||
| except ImportError: # pragma: no cover | ||
| cassandra = None # noqa | ||
| __all__ = ('CassandraBackend',) | ||
| logger = get_logger(__name__) | ||
| E_NO_CASSANDRA = """ | ||
| You need to install the cassandra-driver library to | ||
| use the Cassandra backend. See https://github.com/datastax/python-driver | ||
| """ | ||
| class CassandraBackend(BaseBackend): | ||
| """Highly fault tolerant Cassandra backend. | ||
| E_NO_SUCH_CASSANDRA_AUTH_PROVIDER = """ | ||
| CASSANDRA_AUTH_PROVIDER you provided is not a valid auth_provider class. | ||
| See https://datastax.github.io/python-driver/api/cassandra/auth.html. | ||
| """ | ||
| Q_INSERT_RESULT = """ | ||
| INSERT INTO {table} ( | ||
| task_id, status, result, date_done, traceback, children) VALUES ( | ||
| %s, %s, %s, %s, %s, %s) {expires}; | ||
| """ | ||
| .. attribute:: servers | ||
| Q_SELECT_RESULT = """ | ||
| SELECT status, result, date_done, traceback, children | ||
| FROM {table} | ||
| WHERE task_id=%s | ||
| LIMIT 1 | ||
| """ | ||
| List of Cassandra servers with format: ``hostname:port``. | ||
| Q_CREATE_RESULT_TABLE = """ | ||
| CREATE TABLE {table} ( | ||
| task_id text, | ||
| status text, | ||
| result blob, | ||
| date_done timestamp, | ||
| traceback blob, | ||
| children blob, | ||
| PRIMARY KEY ((task_id), date_done) | ||
| ) WITH CLUSTERING ORDER BY (date_done DESC); | ||
| """ | ||
| Q_EXPIRES = """ | ||
| USING TTL {0} | ||
| """ | ||
| :raises celery.exceptions.ImproperlyConfigured: if | ||
| module :mod:`pycassa` is not available. | ||
| if sys.version_info[0] == 3: | ||
| def buf_t(x): | ||
| return bytes(x, 'utf8') | ||
| else: | ||
| buf_t = buffer # noqa | ||
| class CassandraBackend(BaseBackend): | ||
| """Cassandra backend utilizing DataStax driver. | ||
| Raises: | ||
| celery.exceptions.ImproperlyConfigured: | ||
| if module :pypi:`cassandra-driver` is not available, | ||
| or if the :setting:`cassandra_servers` setting is not set. | ||
| """ | ||
| servers = [] | ||
| keyspace = None | ||
| column_family = None | ||
| detailed_mode = False | ||
| _retry_timeout = 300 | ||
| _retry_wait = 3 | ||
| supports_autoexpire = True | ||
| def __init__(self, servers=None, keyspace=None, column_family=None, | ||
| cassandra_options=None, detailed_mode=False, **kwargs): | ||
| """Initialize Cassandra backend. | ||
| #: List of Cassandra servers with format: ``hostname``. | ||
| servers = None | ||
| Raises :class:`celery.exceptions.ImproperlyConfigured` if | ||
| the :setting:`CASSANDRA_SERVERS` setting is not set. | ||
| supports_autoexpire = True # autoexpire supported via entry_ttl | ||
| """ | ||
| def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None, | ||
| port=9042, **kwargs): | ||
| super(CassandraBackend, self).__init__(**kwargs) | ||
| self.expires = kwargs.get('expires') or maybe_timedelta( | ||
| self.app.conf.CELERY_TASK_RESULT_EXPIRES) | ||
| if not pycassa: | ||
| raise ImproperlyConfigured( | ||
| 'You need to install the pycassa library to use the ' | ||
| 'Cassandra backend. See https://github.com/pycassa/pycassa') | ||
| if not cassandra: | ||
| raise ImproperlyConfigured(E_NO_CASSANDRA) | ||
| conf = self.app.conf | ||
| self.servers = (servers or | ||
| conf.get('CASSANDRA_SERVERS') or | ||
| self.servers) | ||
| self.keyspace = (keyspace or | ||
| conf.get('CASSANDRA_KEYSPACE') or | ||
| self.keyspace) | ||
| self.column_family = (column_family or | ||
| conf.get('CASSANDRA_COLUMN_FAMILY') or | ||
| self.column_family) | ||
| self.cassandra_options = dict(conf.get('CASSANDRA_OPTIONS') or {}, | ||
| **cassandra_options or {}) | ||
| self.detailed_mode = (detailed_mode or | ||
| conf.get('CASSANDRA_DETAILED_MODE') or | ||
| self.detailed_mode) | ||
| read_cons = conf.get('CASSANDRA_READ_CONSISTENCY') or 'LOCAL_QUORUM' | ||
| write_cons = conf.get('CASSANDRA_WRITE_CONSISTENCY') or 'LOCAL_QUORUM' | ||
| try: | ||
| self.read_consistency = getattr(pycassa.ConsistencyLevel, | ||
| read_cons) | ||
| except AttributeError: | ||
| self.read_consistency = pycassa.ConsistencyLevel.LOCAL_QUORUM | ||
| try: | ||
| self.write_consistency = getattr(pycassa.ConsistencyLevel, | ||
| write_cons) | ||
| except AttributeError: | ||
| self.write_consistency = pycassa.ConsistencyLevel.LOCAL_QUORUM | ||
| if not self.servers or not self.keyspace or not self.column_family: | ||
| raise ImproperlyConfigured( | ||
| 'Cassandra backend not configured.') | ||
| self._column_family = None | ||
| def _retry_on_error(self, fun, *args, **kwargs): | ||
| ts = monotonic() + self._retry_timeout | ||
| while 1: | ||
| try: | ||
| return fun(*args, **kwargs) | ||
| except (pycassa.InvalidRequestException, | ||
| pycassa.TimedOutException, | ||
| pycassa.UnavailableException, | ||
| pycassa.AllServersUnavailable, | ||
| socket.error, | ||
| socket.timeout, | ||
| Thrift.TException) as exc: | ||
| if monotonic() > ts: | ||
| raise | ||
| logger.warning('Cassandra error: %r. Retrying...', exc) | ||
| time.sleep(self._retry_wait) | ||
| def _get_column_family(self): | ||
| if self._column_family is None: | ||
| conn = pycassa.ConnectionPool(self.keyspace, | ||
| server_list=self.servers, | ||
| **self.cassandra_options) | ||
| self._column_family = pycassa.ColumnFamily( | ||
| conn, self.column_family, | ||
| read_consistency_level=self.read_consistency, | ||
| write_consistency_level=self.write_consistency, | ||
| ) | ||
| return self._column_family | ||
| self.servers = servers or conf.get('cassandra_servers', None) | ||
| self.port = port or conf.get('cassandra_port', None) | ||
| self.keyspace = keyspace or conf.get('cassandra_keyspace', None) | ||
| self.table = table or conf.get('cassandra_table', None) | ||
| self.cassandra_options = conf.get('cassandra_options', {}) | ||
| if not self.servers or not self.keyspace or not self.table: | ||
| raise ImproperlyConfigured('Cassandra backend not configured.') | ||
| expires = entry_ttl or conf.get('cassandra_entry_ttl', None) | ||
| self.cqlexpires = ( | ||
| Q_EXPIRES.format(expires) if expires is not None else '') | ||
| read_cons = conf.get('cassandra_read_consistency') or 'LOCAL_QUORUM' | ||
| write_cons = conf.get('cassandra_write_consistency') or 'LOCAL_QUORUM' | ||
| self.read_consistency = getattr( | ||
| cassandra.ConsistencyLevel, read_cons, | ||
| cassandra.ConsistencyLevel.LOCAL_QUORUM) | ||
| self.write_consistency = getattr( | ||
| cassandra.ConsistencyLevel, write_cons, | ||
| cassandra.ConsistencyLevel.LOCAL_QUORUM) | ||
| self.auth_provider = None | ||
| auth_provider = conf.get('cassandra_auth_provider', None) | ||
| auth_kwargs = conf.get('cassandra_auth_kwargs', None) | ||
| if auth_provider and auth_kwargs: | ||
| auth_provider_class = getattr(cassandra.auth, auth_provider, None) | ||
| if not auth_provider_class: | ||
| raise ImproperlyConfigured(E_NO_SUCH_CASSANDRA_AUTH_PROVIDER) | ||
| self.auth_provider = auth_provider_class(**auth_kwargs) | ||
| self._connection = None | ||
| self._session = None | ||
| self._write_stmt = None | ||
| self._read_stmt = None | ||
| self._make_stmt = None | ||
| def process_cleanup(self): | ||
| if self._column_family is not None: | ||
| self._column_family = None | ||
| if self._connection is not None: | ||
| self._connection.shutdown() # also shuts down _session | ||
| self._connection = None | ||
| self._session = None | ||
| def _store_result(self, task_id, result, status, | ||
| def _get_connection(self, write=False): | ||
| """Prepare the connection for action. | ||
| Arguments: | ||
| write (bool): are we a writer? | ||
| """ | ||
| if self._connection is not None: | ||
| return | ||
| try: | ||
| self._connection = cassandra.cluster.Cluster( | ||
| self.servers, port=self.port, | ||
| auth_provider=self.auth_provider, | ||
| **self.cassandra_options) | ||
| self._session = self._connection.connect(self.keyspace) | ||
| # We're forced to do concatenation below, as formatting would | ||
| # blow up on superficial %s that'll be processed by Cassandra | ||
| self._write_stmt = cassandra.query.SimpleStatement( | ||
| Q_INSERT_RESULT.format( | ||
| table=self.table, expires=self.cqlexpires), | ||
| ) | ||
| self._write_stmt.consistency_level = self.write_consistency | ||
| self._read_stmt = cassandra.query.SimpleStatement( | ||
| Q_SELECT_RESULT.format(table=self.table), | ||
| ) | ||
| self._read_stmt.consistency_level = self.read_consistency | ||
| if write: | ||
| # Only possible writers "workers" are allowed to issue | ||
| # CREATE TABLE. This is to prevent conflicting situations | ||
| # where both task-creator and task-executor would issue it | ||
| # at the same time. | ||
| # Anyway; if you're doing anything critical, you should | ||
| # have created this table in advance, in which case | ||
| # this query will be a no-op (AlreadyExists) | ||
| self._make_stmt = cassandra.query.SimpleStatement( | ||
| Q_CREATE_RESULT_TABLE.format(table=self.table), | ||
| ) | ||
| self._make_stmt.consistency_level = self.write_consistency | ||
| try: | ||
| self._session.execute(self._make_stmt) | ||
| except cassandra.AlreadyExists: | ||
| pass | ||
| except cassandra.OperationTimedOut: | ||
| # a heavily loaded or gone Cassandra cluster failed to respond. | ||
| # leave this class in a consistent state | ||
| if self._connection is not None: | ||
| self._connection.shutdown() # also shuts down _session | ||
| self._connection = None | ||
| self._session = None | ||
| raise # we did fail after all - reraise | ||
| def _store_result(self, task_id, result, state, | ||
| traceback=None, request=None, **kwargs): | ||
| """Store return value and status of an executed task.""" | ||
| def _do_store(): | ||
| cf = self._get_column_family() | ||
| date_done = self.app.now() | ||
| meta = {'status': status, | ||
| 'date_done': date_done.strftime('%Y-%m-%dT%H:%M:%SZ'), | ||
| 'traceback': self.encode(traceback), | ||
| 'children': self.encode( | ||
| self.current_task_children(request), | ||
| )} | ||
| ttl = self.expires and max(self.expires.total_seconds(), 0) | ||
| if self.detailed_mode: | ||
| meta['result'] = result | ||
| cf.insert(task_id, {date_done: self.encode(meta)}, ttl=ttl) | ||
| else: | ||
| meta['result'] = self.encode(result) | ||
| cf.insert(task_id, meta, ttl=ttl) | ||
| return self._retry_on_error(_do_store) | ||
| """Store return value and state of an executed task.""" | ||
| self._get_connection(write=True) | ||
| self._session.execute(self._write_stmt, ( | ||
| task_id, | ||
| state, | ||
| buf_t(self.encode(result)), | ||
| self.app.now(), | ||
| buf_t(self.encode(traceback)), | ||
| buf_t(self.encode(self.current_task_children(request))) | ||
| )) | ||
| def as_uri(self, include_password=True): | ||
| return 'cassandra://' | ||
| def _get_task_meta_for(self, task_id): | ||
| """Get task metadata for a task by id.""" | ||
| def _do_get(): | ||
| cf = self._get_column_family() | ||
| try: | ||
| if self.detailed_mode: | ||
| row = cf.get(task_id, column_reversed=True, column_count=1) | ||
| meta = self.decode(list(row.values())[0]) | ||
| meta['task_id'] = task_id | ||
| else: | ||
| obj = cf.get(task_id) | ||
| meta = { | ||
| 'task_id': task_id, | ||
| 'status': obj['status'], | ||
| 'result': self.decode(obj['result']), | ||
| 'date_done': obj['date_done'], | ||
| 'traceback': self.decode(obj['traceback']), | ||
| 'children': self.decode(obj['children']), | ||
| } | ||
| except (KeyError, pycassa.NotFoundException): | ||
| meta = {'status': states.PENDING, 'result': None} | ||
| return meta | ||
| return self._retry_on_error(_do_get) | ||
| """Get task meta-data for a task by id.""" | ||
| self._get_connection() | ||
| res = self._session.execute(self._read_stmt, (task_id, )) | ||
| if not res: | ||
| return {'status': states.PENDING, 'result': None} | ||
| status, result, date_done, traceback, children = res[0] | ||
| return self.meta_from_decoded({ | ||
| 'task_id': task_id, | ||
| 'status': status, | ||
| 'result': self.decode(result), | ||
| 'date_done': date_done.strftime('%Y-%m-%dT%H:%M:%SZ'), | ||
| 'traceback': self.decode(traceback), | ||
| 'children': self.decode(children), | ||
| }) | ||
| def __reduce__(self, args=(), kwargs={}): | ||
| kwargs.update( | ||
| dict(servers=self.servers, | ||
| keyspace=self.keyspace, | ||
| column_family=self.column_family, | ||
| cassandra_options=self.cassandra_options)) | ||
| {'servers': self.servers, | ||
| 'keyspace': self.keyspace, | ||
| 'table': self.table}) | ||
| return super(CassandraBackend, self).__reduce__(args, kwargs) |
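
# Hedged configuration sketch for this backend: the setting names are the
# ones read in ``__init__`` above; hosts, keyspace, table and TTL values
# are illustrative only.

from celery import Celery

app = Celery(backend='cassandra://')
app.conf.cassandra_servers = ['cass1.example.com', 'cass2.example.com']
app.conf.cassandra_keyspace = 'celery'
app.conf.cassandra_table = 'tasks'
app.conf.cassandra_entry_ttl = 86400  # seconds, rendered as "USING TTL"
app.conf.cassandra_read_consistency = 'LOCAL_QUORUM'
app.conf.cassandra_write_consistency = 'LOCAL_QUORUM'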
| @@ -0,0 +1,103 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """Consul result store backend. | ||
| - :class:`ConsulBackend` implements KeyValueStoreBackend to store results | ||
| in the key-value store of Consul. | ||
| """ | ||
| from __future__ import absolute_import, unicode_literals | ||
| from kombu.utils.encoding import bytes_to_str | ||
| from kombu.utils.url import parse_url | ||
| from celery.backends.base import KeyValueStoreBackend | ||
| from celery.exceptions import ImproperlyConfigured | ||
| from celery.utils.log import get_logger | ||
| try: | ||
| import consul | ||
| except ImportError: | ||
| consul = None | ||
| logger = get_logger(__name__) | ||
| __all__ = ('ConsulBackend',) | ||
| CONSUL_MISSING = """\ | ||
| You need to install the python-consul library in order to use \ | ||
| the Consul result store backend.""" | ||
| class ConsulBackend(KeyValueStoreBackend): | ||
| """Consul.io K/V store backend for Celery.""" | ||
| consul = consul | ||
| supports_autoexpire = True | ||
| client = None | ||
| consistency = 'consistent' | ||
| path = None | ||
| def __init__(self, *args, **kwargs): | ||
| super(ConsulBackend, self).__init__(*args, **kwargs) | ||
| if self.consul is None: | ||
| raise ImproperlyConfigured(CONSUL_MISSING) | ||
| self._init_from_params(**parse_url(self.url)) | ||
| def _init_from_params(self, hostname, port, virtual_host, **params): | ||
| logger.debug('Setting on Consul client to connect to %s:%d', | ||
| hostname, port) | ||
| self.path = virtual_host | ||
| self.client = consul.Consul(host=hostname, port=port, | ||
| consistency=self.consistency) | ||
| def _key_to_consul_key(self, key): | ||
| key = bytes_to_str(key) | ||
| return key if self.path is None else '{0}/{1}'.format(self.path, key) | ||
| def get(self, key): | ||
| key = self._key_to_consul_key(key) | ||
| logger.debug('Trying to fetch key %s from Consul', key) | ||
| try: | ||
| _, data = self.client.kv.get(key) | ||
| return data['Value'] | ||
| except TypeError: | ||
| pass | ||
| def mget(self, keys): | ||
| for key in keys: | ||
| yield self.get(key) | ||
| def set(self, key, value): | ||
| """Set a key in Consul. | ||
| Before creating the key it will create a session inside Consul | ||
| where it creates a session with a TTL | ||
| The key created afterwards will reference to the session's ID. | ||
| If the session expires it will remove the key so that results | ||
| can auto expire from the K/V store | ||
| """ | ||
| session_name = bytes_to_str(key) | ||
| key = self._key_to_consul_key(key) | ||
| logger.debug('Trying to create Consul session %s with TTL %d', | ||
| session_name, self.expires) | ||
| session_id = self.client.session.create(name=session_name, | ||
| behavior='delete', | ||
| ttl=self.expires) | ||
| logger.debug('Created Consul session %s', session_id) | ||
| logger.debug('Writing key %s to Consul', key) | ||
| return self.client.kv.put(key=key, | ||
| value=value, | ||
| acquire=session_id) | ||
| def delete(self, key): | ||
| key = self._key_to_consul_key(key) | ||
| logger.debug('Removing key %s from Consul', key) | ||
| return self.client.kv.delete(key) |
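
# Hedged configuration sketch: the URL is parsed with kombu's ``parse_url``,
# so it carries host, port and an optional key prefix (the "virtual host");
# all values illustrative only.

from celery import Celery

app = Celery(backend='consul://localhost:8500/celery')
app.conf.result_expires = 60  # used above as the Consul session TTL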
| @@ -0,0 +1,104 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """CouchDB result store backend.""" | ||
| from __future__ import absolute_import, unicode_literals | ||
| from kombu.utils.encoding import bytes_to_str | ||
| from kombu.utils.url import _parse_url | ||
| from celery.exceptions import ImproperlyConfigured | ||
| from .base import KeyValueStoreBackend | ||
| try: | ||
| import pycouchdb | ||
| except ImportError: | ||
| pycouchdb = None # noqa | ||
| __all__ = ('CouchBackend',) | ||
| ERR_LIB_MISSING = """\ | ||
| You need to install the pycouchdb library to use the CouchDB result backend\ | ||
| """ | ||
| class CouchBackend(KeyValueStoreBackend): | ||
| """CouchDB backend. | ||
| Raises: | ||
| celery.exceptions.ImproperlyConfigured: | ||
| if module :pypi:`pycouchdb` is not available. | ||
| """ | ||
| container = 'default' | ||
| scheme = 'http' | ||
| host = 'localhost' | ||
| port = 5984 | ||
| username = None | ||
| password = None | ||
| def __init__(self, url=None, *args, **kwargs): | ||
| super(CouchBackend, self).__init__(*args, **kwargs) | ||
| self.url = url | ||
| if pycouchdb is None: | ||
| raise ImproperlyConfigured(ERR_LIB_MISSING) | ||
| uscheme = uhost = uport = uname = upass = ucontainer = None | ||
| if url: | ||
| _, uhost, uport, uname, upass, ucontainer, _ = _parse_url(url) # noqa | ||
| ucontainer = ucontainer.strip('/') if ucontainer else None | ||
| self.scheme = uscheme or self.scheme | ||
| self.host = uhost or self.host | ||
| self.port = int(uport or self.port) | ||
| self.container = ucontainer or self.container | ||
| self.username = uname or self.username | ||
| self.password = upass or self.password | ||
| self._connection = None | ||
| def _get_connection(self): | ||
| """Connect to the CouchDB server.""" | ||
| if self.username and self.password: | ||
| conn_string = '%s://%s:%s@%s:%s' % ( | ||
| self.scheme, self.username, self.password, | ||
| self.host, str(self.port)) | ||
| server = pycouchdb.Server(conn_string, authmethod='basic') | ||
| else: | ||
| conn_string = '%s://%s:%s' % ( | ||
| self.scheme, self.host, str(self.port)) | ||
| server = pycouchdb.Server(conn_string) | ||
| try: | ||
| return server.database(self.container) | ||
| except pycouchdb.exceptions.NotFound: | ||
| return server.create(self.container) | ||
| @property | ||
| def connection(self): | ||
| if self._connection is None: | ||
| self._connection = self._get_connection() | ||
| return self._connection | ||
| def get(self, key): | ||
| try: | ||
| return self.connection.get(key)['value'] | ||
| except pycouchdb.exceptions.NotFound: | ||
| return None | ||
| def set(self, key, value): | ||
| key = bytes_to_str(key) | ||
| data = {'_id': key, 'value': value} | ||
| try: | ||
| self.connection.save(data) | ||
| except pycouchdb.exceptions.Conflict: | ||
| # document already exists, update it | ||
| data = self.connection.get(key) | ||
| data['value'] = value | ||
| self.connection.save(data) | ||
| def mget(self, keys): | ||
| return [self.get(key) for key in keys] | ||
| def delete(self, key): | ||
| self.connection.delete(key) |
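
# Hedged usage sketch (credentials and database name illustrative only):
# the path component of the URL selects the database ("container"), while
# the stored connection string keeps the 'http' scheme.

from celery import Celery

app = Celery(backend='couchdb://user:password@localhost:5984/celery_results')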
| @@ -0,0 +1,285 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """AWS DynamoDB result store backend.""" | ||
| from __future__ import absolute_import, unicode_literals | ||
| from collections import namedtuple | ||
| from time import sleep, time | ||
| from kombu.utils.url import _parse_url as parse_url | ||
| from celery.exceptions import ImproperlyConfigured | ||
| from celery.five import string | ||
| from celery.utils.log import get_logger | ||
| from .base import KeyValueStoreBackend | ||
| try: | ||
| import boto3 | ||
| from botocore.exceptions import ClientError | ||
| except ImportError: # pragma: no cover | ||
| boto3 = ClientError = None # noqa | ||
| __all__ = ('DynamoDBBackend',) | ||
| # Helper class that describes a DynamoDB attribute | ||
| DynamoDBAttribute = namedtuple('DynamoDBAttribute', ('name', 'data_type')) | ||
| logger = get_logger(__name__) | ||
| class DynamoDBBackend(KeyValueStoreBackend): | ||
| """AWS DynamoDB result backend. | ||
| Raises: | ||
| celery.exceptions.ImproperlyConfigured: | ||
| if module :pypi:`boto3` is not available. | ||
| """ | ||
| #: default DynamoDB table name (`default`) | ||
| table_name = 'celery' | ||
| #: Read Provisioned Throughput (`default`) | ||
| read_capacity_units = 1 | ||
| #: Write Provisioned Throughput (`default`) | ||
| write_capacity_units = 1 | ||
| #: AWS region (`default`) | ||
| aws_region = None | ||
| #: The endpoint URL that is passed to boto3 (local DynamoDB) (`default`) | ||
| endpoint_url = None | ||
| _key_field = DynamoDBAttribute(name='id', data_type='S') | ||
| _value_field = DynamoDBAttribute(name='result', data_type='B') | ||
| _timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N') | ||
| _available_fields = None | ||
| def __init__(self, url=None, table_name=None, *args, **kwargs): | ||
| super(DynamoDBBackend, self).__init__(*args, **kwargs) | ||
| self.url = url | ||
| self.table_name = table_name or self.table_name | ||
| if not boto3: | ||
| raise ImproperlyConfigured( | ||
| 'You need to install the boto3 library to use the ' | ||
| 'DynamoDB backend.') | ||
| aws_credentials_given = False | ||
| aws_access_key_id = None | ||
| aws_secret_access_key = None | ||
| if url is not None: | ||
| scheme, region, port, username, password, table, query = \ | ||
| parse_url(url) | ||
| aws_access_key_id = username | ||
| aws_secret_access_key = password | ||
| access_key_given = aws_access_key_id is not None | ||
| secret_key_given = aws_secret_access_key is not None | ||
| if access_key_given != secret_key_given: | ||
| raise ImproperlyConfigured( | ||
| 'You need to specify both the Access Key ID ' | ||
| 'and Secret.') | ||
| aws_credentials_given = access_key_given | ||
| if region == 'localhost': | ||
| # We are using the downloadable, local version of DynamoDB | ||
| self.endpoint_url = 'http://localhost:{}'.format(port) | ||
| self.aws_region = 'us-east-1' | ||
| logger.warning( | ||
| 'Using local-only DynamoDB endpoint URL: {}'.format( | ||
| self.endpoint_url | ||
| ) | ||
| ) | ||
| else: | ||
| self.aws_region = region | ||
| # If endpoint_url is explicitly set use it instead | ||
| _get = self.app.conf.get | ||
| config_endpoint_url = _get('dynamodb_endpoint_url') | ||
| if config_endpoint_url: | ||
| self.endpoint_url = config_endpoint_url | ||
| self.read_capacity_units = int( | ||
| query.get( | ||
| 'read', | ||
| self.read_capacity_units | ||
| ) | ||
| ) | ||
| self.write_capacity_units = int( | ||
| query.get( | ||
| 'write', | ||
| self.write_capacity_units | ||
| ) | ||
| ) | ||
| self.table_name = table or self.table_name | ||
| self._available_fields = ( | ||
| self._key_field, | ||
| self._value_field, | ||
| self._timestamp_field | ||
| ) | ||
| self._client = None | ||
| if aws_credentials_given: | ||
| self._get_client( | ||
| access_key_id=aws_access_key_id, | ||
| secret_access_key=aws_secret_access_key | ||
| ) | ||
| def _get_client(self, access_key_id=None, secret_access_key=None): | ||
| """Get client connection.""" | ||
| if self._client is None: | ||
| client_parameters = { | ||
| 'region_name': self.aws_region | ||
| } | ||
| if access_key_id is not None: | ||
| client_parameters.update({ | ||
| 'aws_access_key_id': access_key_id, | ||
| 'aws_secret_access_key': secret_access_key | ||
| }) | ||
| if self.endpoint_url is not None: | ||
| client_parameters['endpoint_url'] = self.endpoint_url | ||
| self._client = boto3.client( | ||
| 'dynamodb', | ||
| **client_parameters | ||
| ) | ||
| self._get_or_create_table() | ||
| return self._client | ||
| def _get_table_schema(self): | ||
| """Get the boto3 structure describing the DynamoDB table schema.""" | ||
| return { | ||
| 'AttributeDefinitions': [ | ||
| { | ||
| 'AttributeName': self._key_field.name, | ||
| 'AttributeType': self._key_field.data_type | ||
| } | ||
| ], | ||
| 'TableName': self.table_name, | ||
| 'KeySchema': [ | ||
| { | ||
| 'AttributeName': self._key_field.name, | ||
| 'KeyType': 'HASH' | ||
| } | ||
| ], | ||
| 'ProvisionedThroughput': { | ||
| 'ReadCapacityUnits': self.read_capacity_units, | ||
| 'WriteCapacityUnits': self.write_capacity_units | ||
| } | ||
| } | ||
| def _get_or_create_table(self): | ||
| """Create table if not exists, otherwise return the description.""" | ||
| table_schema = self._get_table_schema() | ||
| try: | ||
| table_description = self._client.create_table(**table_schema) | ||
| logger.info( | ||
| 'DynamoDB Table {} did not exist, creating.'.format( | ||
| self.table_name | ||
| ) | ||
| ) | ||
| # In case we created the table, wait until it becomes available. | ||
| self._wait_for_table_status('ACTIVE') | ||
| logger.info( | ||
| 'DynamoDB Table {} is now available.'.format( | ||
| self.table_name | ||
| ) | ||
| ) | ||
| return table_description | ||
| except ClientError as e: | ||
| error_code = e.response['Error'].get('Code', 'Unknown') | ||
| # If table exists, do not fail, just return the description. | ||
| if error_code == 'ResourceInUseException': | ||
| return self._client.describe_table( | ||
| TableName=self.table_name | ||
| ) | ||
| else: | ||
| raise e | ||
| def _wait_for_table_status(self, expected='ACTIVE'): | ||
| """Poll for the expected table status.""" | ||
| achieved_state = False | ||
| while not achieved_state: | ||
| table_description = self.client.describe_table( | ||
| TableName=self.table_name | ||
| ) | ||
| logger.debug( | ||
| 'Waiting for DynamoDB table {} to become {}.'.format( | ||
| self.table_name, | ||
| expected | ||
| ) | ||
| ) | ||
| current_status = table_description['Table']['TableStatus'] | ||
| achieved_state = current_status == expected | ||
| sleep(1) | ||
| def _prepare_get_request(self, key): | ||
| """Construct the item retrieval request parameters.""" | ||
| return { | ||
| 'TableName': self.table_name, | ||
| 'Key': { | ||
| self._key_field.name: { | ||
| self._key_field.data_type: key | ||
| } | ||
| } | ||
| } | ||
| def _prepare_put_request(self, key, value): | ||
| """Construct the item creation request parameters.""" | ||
| return { | ||
| 'TableName': self.table_name, | ||
| 'Item': { | ||
| self._key_field.name: { | ||
| self._key_field.data_type: key | ||
| }, | ||
| self._value_field.name: { | ||
| self._value_field.data_type: value | ||
| }, | ||
| self._timestamp_field.name: { | ||
| self._timestamp_field.data_type: str(time()) | ||
| } | ||
| } | ||
| } | ||
| def _item_to_dict(self, raw_response): | ||
| """Convert get_item() response to field-value pairs.""" | ||
| if 'Item' not in raw_response: | ||
| return {} | ||
| return { | ||
| field.name: raw_response['Item'][field.name][field.data_type] | ||
| for field in self._available_fields | ||
| } | ||
| @property | ||
| def client(self): | ||
| return self._get_client() | ||
| def get(self, key): | ||
| key = string(key) | ||
| request_parameters = self._prepare_get_request(key) | ||
| item_response = self.client.get_item(**request_parameters) | ||
| item = self._item_to_dict(item_response) | ||
| return item.get(self._value_field.name) | ||
| def set(self, key, value): | ||
| key = string(key) | ||
| request_parameters = self._prepare_put_request(key, value) | ||
| self.client.put_item(**request_parameters) | ||
| def mget(self, keys): | ||
| return [self.get(key) for key in keys] | ||
| def delete(self, key): | ||
| key = string(key) | ||
| request_parameters = self._prepare_get_request(key) | ||
| self.client.delete_item(**request_parameters) |
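
# Hedged configuration sketch: the URL userinfo carries the AWS key pair,
# the host selects the region (or 'localhost' for the downloadable
# DynamoDB), and the ``read``/``write`` query parameters tune provisioned
# throughput; all values illustrative only.

from celery import Celery

app = Celery(
    backend='dynamodb://aws_key_id:aws_secret@us-east-1/celery_results'
            '?read=5&write=5',
)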
| @@ -0,0 +1,142 @@ | ||
| # -* coding: utf-8 -*- | ||
| """Elasticsearch result store backend.""" | ||
| from __future__ import absolute_import, unicode_literals | ||
| from datetime import datetime | ||
| from kombu.utils.encoding import bytes_to_str | ||
| from kombu.utils.url import _parse_url | ||
| from celery.exceptions import ImproperlyConfigured | ||
| from celery.five import items | ||
| from .base import KeyValueStoreBackend | ||
| try: | ||
| import elasticsearch | ||
| except ImportError: | ||
| elasticsearch = None # noqa | ||
| __all__ = ('ElasticsearchBackend',) | ||
| E_LIB_MISSING = """\ | ||
| You need to install the elasticsearch library to use the Elasticsearch \ | ||
| result backend.\ | ||
| """ | ||
| class ElasticsearchBackend(KeyValueStoreBackend): | ||
| """Elasticsearch Backend. | ||
| Raises: | ||
| celery.exceptions.ImproperlyConfigured: | ||
| if module :pypi:`elasticsearch` is not available. | ||
| """ | ||
| index = 'celery' | ||
| doc_type = 'backend' | ||
| scheme = 'http' | ||
| host = 'localhost' | ||
| port = 9200 | ||
| es_retry_on_timeout = False | ||
| es_timeout = 10 | ||
| es_max_retries = 3 | ||
| def __init__(self, url=None, *args, **kwargs): | ||
| super(ElasticsearchBackend, self).__init__(*args, **kwargs) | ||
| self.url = url | ||
| _get = self.app.conf.get | ||
| if elasticsearch is None: | ||
| raise ImproperlyConfigured(E_LIB_MISSING) | ||
| index = doc_type = scheme = host = port = None | ||
| if url: | ||
| scheme, host, port, _, _, path, _ = _parse_url(url) # noqa | ||
| if path: | ||
| path = path.strip('/') | ||
| index, _, doc_type = path.partition('/') | ||
| self.index = index or self.index | ||
| self.doc_type = doc_type or self.doc_type | ||
| self.scheme = scheme or self.scheme | ||
| self.host = host or self.host | ||
| self.port = port or self.port | ||
| self.es_retry_on_timeout = ( | ||
| _get('elasticsearch_retry_on_timeout') or self.es_retry_on_timeout | ||
| ) | ||
| es_timeout = _get('elasticsearch_timeout') | ||
| if es_timeout is not None: | ||
| self.es_timeout = es_timeout | ||
| es_max_retries = _get('elasticsearch_max_retries') | ||
| if es_max_retries is not None: | ||
| self.es_max_retries = es_max_retries | ||
| self._server = None | ||
| def get(self, key): | ||
| try: | ||
| res = self.server.get( | ||
| index=self.index, | ||
| doc_type=self.doc_type, | ||
| id=key, | ||
| ) | ||
| try: | ||
| if res['found']: | ||
| return res['_source']['result'] | ||
| except (TypeError, KeyError): | ||
| pass | ||
| except elasticsearch.exceptions.NotFoundError: | ||
| pass | ||
| def set(self, key, value): | ||
| try: | ||
| self._index( | ||
| id=key, | ||
| body={ | ||
| 'result': value, | ||
| '@timestamp': '{0}Z'.format( | ||
| datetime.utcnow().isoformat()[:-3] | ||
| ), | ||
| }, | ||
| ) | ||
| except elasticsearch.exceptions.ConflictError: | ||
| # document already exists, update it | ||
| data = self.get(key) | ||
| data[key] = value | ||
| self._index(key, data, refresh=True) | ||
| def _index(self, id, body, **kwargs): | ||
| body = {bytes_to_str(k): v for k, v in items(body)} | ||
| return self.server.index( | ||
| id=bytes_to_str(id), | ||
| index=self.index, | ||
| doc_type=self.doc_type, | ||
| body=body, | ||
| **kwargs | ||
| ) | ||
| def mget(self, keys): | ||
| return [self.get(key) for key in keys] | ||
| def delete(self, key): | ||
| self.server.delete(index=self.index, doc_type=self.doc_type, id=key) | ||
| def _get_server(self): | ||
| """Connect to the Elasticsearch server.""" | ||
| return elasticsearch.Elasticsearch( | ||
| '%s:%s' % (self.host, self.port), | ||
| retry_on_timeout=self.es_retry_on_timeout, | ||
| max_retries=self.es_max_retries, | ||
| timeout=self.es_timeout | ||
| ) | ||
| @property | ||
| def server(self): | ||
| if self._server is None: | ||
| self._server = self._get_server() | ||
| return self._server |
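
# Hedged configuration sketch: the URL path selects index and doc type
# (defaults 'celery' and 'backend'); host, port and tuning values are
# illustrative only.

from celery import Celery

app = Celery(backend='elasticsearch://localhost:9200/celery/backend')
app.conf.elasticsearch_retry_on_timeout = True
app.conf.elasticsearch_max_retries = 5
app.conf.elasticsearch_timeout = 30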
| @@ -0,0 +1,91 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """File-system result store backend.""" | ||
| from __future__ import absolute_import, unicode_literals | ||
| import locale | ||
| import os | ||
| from kombu.utils.encoding import ensure_bytes | ||
| from celery import uuid | ||
| from celery.backends.base import KeyValueStoreBackend | ||
| from celery.exceptions import ImproperlyConfigured | ||
| # Python 2 does not have FileNotFoundError and IsADirectoryError | ||
| try: | ||
| FileNotFoundError | ||
| except NameError: | ||
| FileNotFoundError = IOError | ||
| IsADirectoryError = IOError | ||
| default_encoding = locale.getpreferredencoding(False) | ||
| E_PATH_INVALID = """\ | ||
| The configured path for the file-system backend does not | ||
| work correctly, please make sure that it exists and has | ||
| the correct permissions.\ | ||
| """ | ||
| class FilesystemBackend(KeyValueStoreBackend): | ||
| """File-system result backend. | ||
| Arguments: | ||
| url (str): URL to the directory we should use | ||
| open (Callable): open function to use when opening files | ||
| unlink (Callable): unlink function to use when deleting files | ||
| sep (str): directory separator (to join the directory with the key) | ||
| encoding (str): encoding used on the file-system | ||
| """ | ||
| def __init__(self, url=None, open=open, unlink=os.unlink, sep=os.sep, | ||
| encoding=default_encoding, *args, **kwargs): | ||
| super(FilesystemBackend, self).__init__(*args, **kwargs) | ||
| self.url = url | ||
| path = self._find_path(url) | ||
| # We need the path and separator as bytes objects | ||
| self.path = path.encode(encoding) | ||
| self.sep = sep.encode(encoding) | ||
| self.open = open | ||
| self.unlink = unlink | ||
| # Lets verify that we've everything setup right | ||
| self._do_directory_test(b'.fs-backend-' + uuid().encode(encoding)) | ||
| def _find_path(self, url): | ||
| if not url: | ||
| raise ImproperlyConfigured( | ||
| 'You need to configure a path for the File-system backend') | ||
| if url is not None and url.startswith('file:///'): | ||
| return url[7:] | ||
| def _do_directory_test(self, key): | ||
| try: | ||
| self.set(key, b'test value') | ||
| assert self.get(key) == b'test value' | ||
| self.delete(key) | ||
| except IOError: | ||
| raise ImproperlyConfigured(E_PATH_INVALID) | ||
| def _filename(self, key): | ||
| return self.sep.join((self.path, key)) | ||
| def get(self, key): | ||
| try: | ||
| with self.open(self._filename(key), 'rb') as infile: | ||
| return infile.read() | ||
| except FileNotFoundError: | ||
| pass | ||
| def set(self, key, value): | ||
| with self.open(self._filename(key), 'wb') as outfile: | ||
| outfile.write(ensure_bytes(value)) | ||
| def mget(self, keys): | ||
| for key in keys: | ||
| yield self.get(key) | ||
| def delete(self, key): | ||
| self.unlink(self._filename(key)) |
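
# Hedged configuration sketch: 'file:///var/celery/results' maps to the
# directory /var/celery/results, which must already exist and be readable
# and writable by every worker and client (path illustrative only).

from celery import Celery

app = Celery(backend='file:///var/celery/results')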
| @@ -1,64 +1,345 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """The ``RPC`` result backend for AMQP brokers. | ||
| RPC-style result backend, using reply-to and one queue per client. | ||
| """ | ||
| celery.backends.rpc | ||
| ~~~~~~~~~~~~~~~~~~~ | ||
| from __future__ import absolute_import, unicode_literals | ||
| import time | ||
| import kombu | ||
| from kombu.common import maybe_declare | ||
| from kombu.utils.compat import register_after_fork | ||
| from kombu.utils.objects import cached_property | ||
| from celery import states | ||
| from celery._state import current_task, task_join_will_block | ||
| from celery.five import items, range | ||
| from . import base | ||
| from .async import AsyncBackendMixin, BaseResultConsumer | ||
| __all__ = ('BacklogLimitExceeded', 'RPCBackend') | ||
| E_NO_CHORD_SUPPORT = """ | ||
| The "rpc" result backend does not support chords! | ||
| RPC-style result backend, using reply-to and one queue per client. | ||
| Note that a group chained with a task is also upgraded to be a chord, | ||
| as this pattern requires synchronization. | ||
| Result backends that supports chords: Redis, Database, Memcached, and more. | ||
| """ | ||
| from __future__ import absolute_import | ||
| from kombu import Consumer, Exchange | ||
| from kombu.common import maybe_declare | ||
| from kombu.utils import cached_property | ||
| from celery import current_task | ||
| from celery.backends import amqp | ||
| class BacklogLimitExceeded(Exception): | ||
| """Too much state history to fast-forward.""" | ||
| def _on_after_fork_cleanup_backend(backend): | ||
| backend._after_fork() | ||
| class ResultConsumer(BaseResultConsumer): | ||
| Consumer = kombu.Consumer | ||
| _connection = None | ||
| _consumer = None | ||
| def __init__(self, *args, **kwargs): | ||
| super(ResultConsumer, self).__init__(*args, **kwargs) | ||
| self._create_binding = self.backend._create_binding | ||
| __all__ = ['RPCBackend'] | ||
| def start(self, initial_task_id, no_ack=True, **kwargs): | ||
| self._connection = self.app.connection() | ||
| initial_queue = self._create_binding(initial_task_id) | ||
| self._consumer = self.Consumer( | ||
| self._connection.default_channel, [initial_queue], | ||
| callbacks=[self.on_state_change], no_ack=no_ack, | ||
| accept=self.accept) | ||
| self._consumer.consume() | ||
| def drain_events(self, timeout=None): | ||
| if self._connection: | ||
| return self._connection.drain_events(timeout=timeout) | ||
| elif timeout: | ||
| time.sleep(timeout) | ||
| def stop(self): | ||
| try: | ||
| self._consumer.cancel() | ||
| finally: | ||
| self._connection.close() | ||
| def on_after_fork(self): | ||
| self._consumer = None | ||
| if self._connection is not None: | ||
| self._connection.collect() | ||
| self._connection = None | ||
| def consume_from(self, task_id): | ||
| if self._consumer is None: | ||
| return self.start(task_id) | ||
| queue = self._create_binding(task_id) | ||
| if not self._consumer.consuming_from(queue): | ||
| self._consumer.add_queue(queue) | ||
| self._consumer.consume() | ||
| def cancel_for(self, task_id): | ||
| if self._consumer: | ||
| self._consumer.cancel_by_queue(self._create_binding(task_id).name) | ||
| class RPCBackend(base.Backend, AsyncBackendMixin): | ||
| """Base class for the RPC result backend.""" | ||
| Exchange = kombu.Exchange | ||
| Producer = kombu.Producer | ||
| ResultConsumer = ResultConsumer | ||
| #: Exception raised when there are too many messages for a task id. | ||
| BacklogLimitExceeded = BacklogLimitExceeded | ||
| class RPCBackend(amqp.AMQPBackend): | ||
| persistent = False | ||
| supports_autoexpire = True | ||
| supports_native_join = True | ||
| retry_policy = { | ||
| 'max_retries': 20, | ||
| 'interval_start': 0, | ||
| 'interval_step': 1, | ||
| 'interval_max': 1, | ||
| } | ||
| class Consumer(kombu.Consumer): | ||
| """Consumer that requires manual declaration of queues.""" | ||
| class Consumer(Consumer): | ||
| auto_declare = False | ||
| class Queue(kombu.Queue): | ||
| """Queue that never caches declaration.""" | ||
| can_cache_declaration = False | ||
| def __init__(self, app, connection=None, exchange=None, exchange_type=None, | ||
| persistent=None, serializer=None, auto_delete=True, **kwargs): | ||
| super(RPCBackend, self).__init__(app, **kwargs) | ||
| conf = self.app.conf | ||
| self._connection = connection | ||
| self._out_of_band = {} | ||
| self.persistent = self.prepare_persistent(persistent) | ||
| self.delivery_mode = 2 if self.persistent else 1 | ||
| exchange = exchange or conf.result_exchange | ||
| exchange_type = exchange_type or conf.result_exchange_type | ||
| self.exchange = self._create_exchange( | ||
| exchange, exchange_type, self.delivery_mode, | ||
| ) | ||
| self.serializer = serializer or conf.result_serializer | ||
| self.auto_delete = auto_delete | ||
| self.result_consumer = self.ResultConsumer( | ||
| self, self.app, self.accept, | ||
| self._pending_results, self._pending_messages, | ||
| ) | ||
| if register_after_fork is not None: | ||
| register_after_fork(self, _on_after_fork_cleanup_backend) | ||
| def _after_fork(self): | ||
| # clear state for child processes. | ||
| self._pending_results.clear() | ||
| self.result_consumer._after_fork() | ||
| def _create_exchange(self, name, type='direct', delivery_mode=2): | ||
| # uses direct to queue routing (anon exchange). | ||
| return Exchange(None) | ||
| def on_task_call(self, producer, task_id): | ||
| maybe_declare(self.binding(producer.channel), retry=True) | ||
| return self.Exchange(None) | ||
| def _create_binding(self, task_id): | ||
| """Create new binding for task with id.""" | ||
| # RPC backend caches the binding, as one queue is used for all tasks. | ||
| return self.binding | ||
| def _many_bindings(self, ids): | ||
| return [self.binding] | ||
| def ensure_chords_allowed(self): | ||
| raise NotImplementedError(E_NO_CHORD_SUPPORT.strip()) | ||
| def rkey(self, task_id): | ||
| return task_id | ||
| def on_task_call(self, producer, task_id): | ||
| # Called every time a task is sent when using this backend. | ||
| # We declare the queue we receive replies on in advance of sending | ||
| # the message, but we skip this if running in the prefork pool | ||
| # (task_join_will_block), as we know the queue is already declared. | ||
| if not task_join_will_block(): | ||
| maybe_declare(self.binding(producer.channel), retry=True) | ||
| def destination_for(self, task_id, request): | ||
| # Request is a new argument for backends, so must still support | ||
| # old code that rely on current_task | ||
| """Get the destination for result by task id. | ||
| Returns: | ||
| Tuple[str, str]: tuple of ``(reply_to, correlation_id)``. | ||
| """ | ||
| # Backends didn't always receive the `request`, so we must still | ||
| # support old code that relies on current_task. | ||
| try: | ||
| request = request or current_task.request | ||
| except AttributeError: | ||
| raise RuntimeError( | ||
| 'RPC backend missing task request for {0!r}'.format(task_id), | ||
| ) | ||
| 'RPC backend missing task request for {0!r}'.format(task_id)) | ||
| return request.reply_to, request.correlation_id or task_id | ||
| def on_reply_declare(self, task_id): | ||
| # Return value here is used as the `declare=` argument | ||
| # for Producer.publish. | ||
| # By default we don't have to declare anything when sending a result. | ||
| pass | ||
| def on_result_fulfilled(self, result): | ||
| # This usually cancels the queue after the result is received, | ||
| # but we don't have to cancel since we have one queue per process. | ||
| pass | ||
| def as_uri(self, include_password=True): | ||
| return 'rpc://' | ||
| def store_result(self, task_id, result, state, | ||
| traceback=None, request=None, **kwargs): | ||
| """Send task return value and state.""" | ||
| routing_key, correlation_id = self.destination_for(task_id, request) | ||
| if not routing_key: | ||
| return | ||
| with self.app.amqp.producer_pool.acquire(block=True) as producer: | ||
| producer.publish( | ||
| self._to_result(task_id, state, result, traceback, request), | ||
| exchange=self.exchange, | ||
| routing_key=routing_key, | ||
| correlation_id=correlation_id, | ||
| serializer=self.serializer, | ||
| retry=True, retry_policy=self.retry_policy, | ||
| declare=self.on_reply_declare(task_id), | ||
| delivery_mode=self.delivery_mode, | ||
| ) | ||
| return result | ||
| def _to_result(self, task_id, state, result, traceback, request): | ||
| return { | ||
| 'task_id': task_id, | ||
| 'status': state, | ||
| 'result': self.encode_result(result, state), | ||
| 'traceback': traceback, | ||
| 'children': self.current_task_children(request), | ||
| } | ||
| def on_out_of_band_result(self, task_id, message): | ||
| # Callback called when a reply for a task is received, | ||
| # but we have no idea what do do with it. | ||
| # Since the result is not pending, we put it in a separate | ||
| # buffer: probably it will become pending later. | ||
| if self.result_consumer: | ||
| self.result_consumer.on_out_of_band_result(message) | ||
| self._out_of_band[task_id] = message | ||
| def get_task_meta(self, task_id, backlog_limit=1000): | ||
| buffered = self._out_of_band.pop(task_id, None) | ||
| if buffered: | ||
| return self._set_cache_by_message(task_id, buffered) | ||
| # Polling and using basic_get | ||
| latest_by_id = {} | ||
| prev = None | ||
| for acc in self._slurp_from_queue(task_id, self.accept, backlog_limit): | ||
| tid = self._get_message_task_id(acc) | ||
| prev, latest_by_id[tid] = latest_by_id.get(tid), acc | ||
| if prev: | ||
| # backends aren't expected to keep history, | ||
| # so we delete everything except the most recent state. | ||
| prev.ack() | ||
| prev = None | ||
| latest = latest_by_id.pop(task_id, None) | ||
| for tid, msg in items(latest_by_id): | ||
| self.on_out_of_band_result(tid, msg) | ||
| if latest: | ||
| latest.requeue() | ||
| return self._set_cache_by_message(task_id, latest) | ||
| else: | ||
| # no new state, use previous | ||
| try: | ||
| return self._cache[task_id] | ||
| except KeyError: | ||
| # result probably pending. | ||
| return {'status': states.PENDING, 'result': None} | ||
| poll = get_task_meta # XXX compat | ||
| def _set_cache_by_message(self, task_id, message): | ||
| payload = self._cache[task_id] = self.meta_from_decoded( | ||
| message.payload) | ||
| return payload | ||
| def _slurp_from_queue(self, task_id, accept, | ||
| limit=1000, no_ack=False): | ||
| with self.app.pool.acquire_channel(block=True) as (_, channel): | ||
| binding = self._create_binding(task_id)(channel) | ||
| binding.declare() | ||
| for _ in range(limit): | ||
| msg = binding.get(accept=accept, no_ack=no_ack) | ||
| if not msg: | ||
| break | ||
| yield msg | ||
| else: | ||
| raise self.BacklogLimitExceeded(task_id) | ||
| def _get_message_task_id(self, message): | ||
| try: | ||
| # try property first so we don't have to deserialize | ||
| # the payload. | ||
| return message.properties['correlation_id'] | ||
| except (AttributeError, KeyError): | ||
| # message sent by old Celery version, need to deserialize. | ||
| return message.payload['task_id'] | ||
| def revive(self, channel): | ||
| pass | ||
| def reload_task_result(self, task_id): | ||
| raise NotImplementedError( | ||
| 'reload_task_result is not supported by this backend.') | ||
| def reload_group_result(self, task_id): | ||
| """Reload group result, even if it has been previously fetched.""" | ||
| raise NotImplementedError( | ||
| 'reload_group_result is not supported by this backend.') | ||
| def save_group(self, group_id, result): | ||
| raise NotImplementedError( | ||
| 'save_group is not supported by this backend.') | ||
| def restore_group(self, group_id, cache=True): | ||
| raise NotImplementedError( | ||
| 'restore_group is not supported by this backend.') | ||
| def delete_group(self, group_id): | ||
| raise NotImplementedError( | ||
| 'delete_group is not supported by this backend.') | ||
| def __reduce__(self, args=(), kwargs={}): | ||
| return super(RPCBackend, self).__reduce__(args, dict( | ||
| kwargs, | ||
| connection=self._connection, | ||
| exchange=self.exchange.name, | ||
| exchange_type=self.exchange.type, | ||
| persistent=self.persistent, | ||
| serializer=self.serializer, | ||
| auto_delete=self.auto_delete, | ||
| expires=self.expires, | ||
| )) | ||
| @property | ||
| def binding(self): | ||
| return self.Queue(self.oid, self.exchange, self.oid, | ||
| durable=False, auto_delete=False) | ||
| return self.Queue( | ||
| self.oid, self.exchange, self.oid, | ||
| durable=False, | ||
| auto_delete=True, | ||
| expires=self.expires, | ||
| ) | ||
| @cached_property | ||
| def oid(self): | ||
| # cached here is the app OID: name of queue we receive results on. | ||
| return self.app.oid |
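
# Hedged usage sketch: the backend is selected with the 'rpc://' URL, and
# results travel back over the broker to a per-client reply queue named
# after ``app.oid`` (broker URL illustrative only).

from celery import Celery

app = Celery('tasks', broker='amqp://', backend='rpc://')
app.conf.result_persistent = False  # transient replies; the default here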