diff --git a/celery/backends/base.py b/celery/backends/base.py
index 37d595315a1..1eef589c067 100644
--- a/celery/backends/base.py
+++ b/celery/backends/base.py
@@ -351,6 +351,63 @@ def encode_result(self, result, state):
     def is_cached(self, task_id):
         return task_id in self._cache
 
+    def _get_result_meta(self, result,
+                         state, traceback, request, format_date=True,
+                         encode=False):
+        """Build the meta dictionary stored for a task result.
+
+        ``date_done`` is only set for states in ``READY_STATES``: an ISO
+        8601 string when ``format_date`` is true, a ``datetime`` otherwise.
+        When the ``result_extended`` setting is enabled and ``request`` is
+        given, the task name, args, kwargs, worker, retries and queue are
+        included; with ``encode`` the args and kwargs are serialized using
+        ``self.encode`` and converted to bytes.
+        """
+        if state in self.READY_STATES:
+            date_done = datetime.datetime.utcnow()
+            if format_date:
+                date_done = date_done.isoformat()
+        else:
+            date_done = None
+
+        meta = {
+            'status': state,
+            'result': result,
+            'traceback': traceback,
+            'children': self.current_task_children(request),
+            'date_done': date_done,
+        }
+
+        if request and getattr(request, 'group', None):
+            meta['group_id'] = request.group
+        if request and getattr(request, 'parent_id', None):
+            meta['parent_id'] = request.parent_id
+
+        if self.app.conf.find_value_for_key('extended', 'result'):
+            if request:
+                request_meta = {
+                    'name': getattr(request, 'task', None),
+                    'args': getattr(request, 'args', None),
+                    'kwargs': getattr(request, 'kwargs', None),
+                    'worker': getattr(request, 'hostname', None),
+                    'retries': getattr(request, 'retries', None),
+                    'queue': request.delivery_info.get('routing_key')
+                    if hasattr(request, 'delivery_info') and
+                    request.delivery_info else None
+                }
+
+                if encode:
+                    # args and kwargs need to be encoded properly before saving
+                    encode_needed_fields = {"args", "kwargs"}
+                    for field in encode_needed_fields:
+                        value = request_meta[field]
+                        encoded_value = self.encode(value)
+                        request_meta[field] = ensure_bytes(encoded_value)
+
+                meta.update(request_meta)
+
+        return meta
+
     def store_result(self, task_id, result, state,
                      traceback=None, request=None, **kwargs):
         """Update task state and result."""
@@ -703,40 +760,9 @@ def _forget(self, task_id):
 
     def _store_result(self, task_id, result, state,
                       traceback=None, request=None, **kwargs):
-
-        if state in self.READY_STATES:
-            date_done = datetime.datetime.utcnow().isoformat()
-        else:
-            date_done = None
-
-        meta = {
-            'status': state,
-            'result': result,
-            'traceback': traceback,
-            'children': self.current_task_children(request),
-            'task_id': bytes_to_str(task_id),
-            'date_done': date_done,
-        }
-
-        if request and getattr(request, 'group', None):
-            meta['group_id'] = request.group
-        if request and getattr(request, 'parent_id', None):
-            meta['parent_id'] = request.parent_id
-
-        if self.app.conf.find_value_for_key('extended', 'result'):
-            if request:
-                request_meta = {
-                    'name': getattr(request, 'task', None),
-                    'args': getattr(request, 'args', None),
-                    'kwargs': getattr(request, 'kwargs', None),
-                    'worker': getattr(request, 'hostname', None),
-                    'retries': getattr(request, 'retries', None),
-                    'queue': request.delivery_info.get('routing_key')
-                    if hasattr(request, 'delivery_info') and
-                    request.delivery_info else None
-                }
-
-                meta.update(request_meta)
+        meta = self._get_result_meta(result=result, state=state,
+                                     traceback=traceback, request=request)
+        meta['task_id'] = bytes_to_str(task_id)
 
         self.set(self.get_key_for_task(task_id), self.encode(meta))
         return result
diff --git a/celery/backends/database/__init__.py b/celery/backends/database/__init__.py
index a332a8137b5..7ee6f5f870b 100644
--- a/celery/backends/database/__init__.py
+++ b/celery/backends/database/__init__.py
@@ -5,7 +5,6 @@
 import logging
 from contextlib import contextmanager
 
-from kombu.utils.encoding import ensure_bytes
 from vine.utils import wraps
 
 from celery import states
@@ -120,6 +119,7 @@ def _store_result(self, task_id, result, state, traceback=None,
             task = task and task[0]
             if not task:
                 task = self.task_cls(task_id)
+                task.task_id = task_id
                 session.add(task)
                 session.flush()
 
@@ -128,24 +128,22 @@ def _store_result(self, task_id, result, state, traceback=None,
 
     def _update_result(self, task, result, state, traceback=None,
                        request=None):
-        task.result = result
-        task.status = state
-        task.traceback = traceback
-        if self.app.conf.find_value_for_key('extended', 'result'):
-            task.name = getattr(request, 'task', None)
-            task.args = ensure_bytes(
-                self.encode(getattr(request, 'args', None))
-            )
-            task.kwargs = ensure_bytes(
-                self.encode(getattr(request, 'kwargs', None))
-            )
-            task.worker = getattr(request, 'hostname', None)
-            task.retries = getattr(request, 'retries', None)
-            task.queue = (
-                request.delivery_info.get("routing_key")
-                if hasattr(request, "delivery_info") and request.delivery_info
-                else None
-            )
+
+        meta = self._get_result_meta(result=result, state=state,
+                                     traceback=traceback, request=request,
+                                     format_date=False, encode=True)
+
+        # Exclude the primary key id and task_id columns
+        # as we should not set it None
+        columns = [column.name for column in self.task_cls.__table__.columns
+                   if column.name not in {'id', 'task_id'}]
+
+        # Iterate through the columns name of the table
+        # to set the value from meta.
+        # If the value is not present in meta, set None
+        for column in columns:
+            value = meta.get(column)
+            setattr(task, column, value)
 
     @retry
     def _get_task_meta_for(self, task_id):
diff --git a/celery/backends/mongodb.py b/celery/backends/mongodb.py
index 8d551bca802..198f7881594 100644
--- a/celery/backends/mongodb.py
+++ b/celery/backends/mongodb.py
@@ -185,18 +185,15 @@ def decode(self, data):
     def _store_result(self, task_id, result, state,
                       traceback=None, request=None, **kwargs):
         """Store return value and state of an executed task."""
-        meta = {
-            '_id': task_id,
-            'status': state,
-            'result': self.encode(result),
-            'date_done': datetime.utcnow(),
-            'traceback': self.encode(traceback),
-            'children': self.encode(
-                self.current_task_children(request),
-            ),
-        }
-        if request and getattr(request, 'parent_id', None):
-            meta['parent_id'] = request.parent_id
+        # Encode result, traceback and children and keep ``date_done`` a
+        # datetime, exactly as the replaced inline code stored them.
+        meta = self._get_result_meta(result=self.encode(result), state=state,
+                                     traceback=traceback, request=request,
+                                     format_date=False)
+        meta['traceback'] = self.encode(meta['traceback'])
+        meta['children'] = self.encode(meta['children'])
+        # Add the _id for mongodb
+        meta['_id'] = task_id
 
         try:
             self.collection.replace_one({'_id': task_id}, meta, upsert=True)
diff --git a/t/unit/backends/test_base.py b/t/unit/backends/test_base.py
index 6fbbd2d7d77..a458bc149c6 100644
--- a/t/unit/backends/test_base.py
+++ b/t/unit/backends/test_base.py
@@ -7,6 +7,7 @@
 import pytest
 from case import ANY, Mock, call, patch, skip
 from kombu.serialization import prepare_accept_content
+from kombu.utils.encoding import ensure_bytes
 
 import celery
 from celery import chord, group, signature, states, uuid
@@ -104,6 +105,45 @@ def test_accept_precedence(self):
         assert list(b4.accept)[0] == 'application/x-yaml'
         assert prepare_accept_content(['yaml']) == b4.accept
 
+    def test_get_result_meta(self):
+        b1 = BaseBackend(self.app)
+        meta = b1._get_result_meta(result={'fizz': 'buzz'},
+                                   state=states.SUCCESS, traceback=None,
+                                   request=None)
+        assert meta['status'] == states.SUCCESS
+        assert meta['result'] == {'fizz': 'buzz'}
+        assert meta['traceback'] is None
+
+        self.app.conf.result_extended = True
+        args = ['a', 'b']
+        kwargs = {'foo': 'bar'}
+        task_name = 'mytask'
+
+        b2 = BaseBackend(self.app)
+        request = Context(args=args, kwargs=kwargs,
+                          task=task_name,
+                          delivery_info={'routing_key': 'celery'})
+        meta = b2._get_result_meta(result={'fizz': 'buzz'},
+                                   state=states.SUCCESS, traceback=None,
+                                   request=request, encode=False)
+        assert meta['name'] == task_name
+        assert meta['args'] == args
+        assert meta['kwargs'] == kwargs
+        assert meta['queue'] == 'celery'
+
+    def test_get_result_meta_encoded(self):
+        self.app.conf.result_extended = True
+        b1 = BaseBackend(self.app)
+        args = ['a', 'b']
+        kwargs = {'foo': 'bar'}
+
+        request = Context(args=args, kwargs=kwargs)
+        meta = b1._get_result_meta(result={'fizz': 'buzz'},
+                                   state=states.SUCCESS, traceback=None,
+                                   request=request, encode=True)
+        assert meta['args'] == ensure_bytes(b1.encode(args))
+        assert meta['kwargs'] == ensure_bytes(b1.encode(kwargs))
+
 
 class test_BaseBackend_interface:
 
diff --git a/t/unit/backends/test_database.py b/t/unit/backends/test_database.py
index d3dcdc9173f..4a2dd1734c5 100644
--- a/t/unit/backends/test_database.py
+++ b/t/unit/backends/test_database.py
@@ -246,6 +246,37 @@ def test_store_result(self, result_serializer, args, kwargs):
         assert meta['retries'] == 2
         assert meta['worker'] == "celery@worker_1"
 
+    @pytest.mark.parametrize(
+        'result_serializer, args, kwargs',
+        [
+            ('pickle', (SomeClass(1), SomeClass(2)),
+             {'foo': SomeClass(123)}),
+            ('json', ['a', 'b'], {'foo': 'bar'}),
+        ],
+        ids=['using pickle', 'using json']
+    )
+    def test_get_result_meta(self, result_serializer, args, kwargs):
+        self.app.conf.result_serializer = result_serializer
+        tb = DatabaseBackend(self.uri, app=self.app)
+
+        request = Context(args=args, kwargs=kwargs,
+                          task='mytask', retries=2,
+                          hostname='celery@worker_1',
+                          delivery_info={'routing_key': 'celery'})
+
+        meta = tb._get_result_meta(result={'fizz': 'buzz'},
+                                   state=states.SUCCESS, traceback=None,
+                                   request=request, format_date=False,
+                                   encode=True)
+
+        assert meta['result'] == {'fizz': 'buzz'}
+        assert tb.decode(meta['args']) == args
+        assert tb.decode(meta['kwargs']) == kwargs
+        assert meta['queue'] == 'celery'
+        assert meta['name'] == 'mytask'
+        assert meta['retries'] == 2
+        assert meta['worker'] == "celery@worker_1"
+
 
 @skip.unless_module('sqlalchemy')
 class test_SessionManager: