From e2031688284484d5b5a57ba29cd9cae2d9a81e39 Mon Sep 17 00:00:00 2001 From: Matus Valo Date: Sun, 22 Nov 2020 16:59:14 +0100 Subject: [PATCH] Multithreaded backend (#6416) * Cache backend to thread local storage instead of global variable * Cache oid to thread local storage instead of global variable * Improve code returning thread_local data * Move thread local storage to Celery class, introduced thread_oid and added unittests --- celery/app/base.py | 24 +++++++++++++-- celery/backends/rpc.py | 4 +-- celery/canvas.py | 2 +- t/unit/app/test_app.py | 59 +++++++++++++++++++++++++++++++++++++ t/unit/backends/test_rpc.py | 17 ++++++++++- t/unit/tasks/test_chord.py | 7 ++--- t/unit/tasks/test_result.py | 47 ++++++++++++++--------------- t/unit/test_canvas.py | 33 +++++++++++++++++++++ 8 files changed, 159 insertions(+), 34 deletions(-) create mode 100644 t/unit/test_canvas.py diff --git a/celery/app/base.py b/celery/app/base.py index ed4bd748b56..27e5b610ca7 100644 --- a/celery/app/base.py +++ b/celery/app/base.py @@ -206,6 +206,8 @@ class name. task_cls = 'celery.app.task:Task' registry_cls = 'celery.app.registry:TaskRegistry' + #: Thread local storage. + _local = None _fixups = None _pool = None _conf = None @@ -229,6 +231,9 @@ def __init__(self, main=None, loader=None, backend=None, changes=None, config_source=None, fixups=None, task_cls=None, autofinalize=True, namespace=None, strict_typing=True, **kwargs): + + self._local = threading.local() + self.clock = LamportClock() self.main = main self.amqp_cls = amqp or self.amqp_cls @@ -727,7 +732,7 @@ def send_task(self, name, args=None, kwargs=None, countdown=None, task_id, name, args, kwargs, countdown, eta, group_id, group_index, expires, retries, chord, maybe_list(link), maybe_list(link_error), - reply_to or self.oid, time_limit, soft_time_limit, + reply_to or self.thread_oid, time_limit, soft_time_limit, self.conf.task_send_sent_event, root_id, parent_id, shadow, chain, argsrepr=options.get('argsrepr'), @@ -1185,15 +1190,28 @@ def oid(self): # which would not work if each thread has a separate id. return oid_from(self, threads=False) + @property + def thread_oid(self): + """Per-thread unique identifier for this app.""" + try: + return self._local.oid + except AttributeError: + self._local.oid = new_oid = oid_from(self, threads=True) + return new_oid + @cached_property def amqp(self): """AMQP related functionality: :class:`~@amqp`.""" return instantiate(self.amqp_cls, app=self) - @cached_property + @property def backend(self): """Current backend instance.""" - return self._get_backend() + try: + return self._local.backend + except AttributeError: + self._local.backend = new_backend = self._get_backend() + return new_backend @property def conf(self): diff --git a/celery/backends/rpc.py b/celery/backends/rpc.py index 9b851db4de8..399c1dc7a20 100644 --- a/celery/backends/rpc.py +++ b/celery/backends/rpc.py @@ -338,5 +338,5 @@ def binding(self): @cached_property def oid(self): - # cached here is the app OID: name of queue we receive results on. - return self.app.oid + # cached here is the app thread OID: name of queue we receive results on. + return self.app.thread_oid diff --git a/celery/canvas.py b/celery/canvas.py index 0279965d2ee..a4de76428dc 100644 --- a/celery/canvas.py +++ b/celery/canvas.py @@ -296,7 +296,7 @@ def freeze(self, _id=None, group_id=None, chord=None, if parent_id: opts['parent_id'] = parent_id if 'reply_to' not in opts: - opts['reply_to'] = self.app.oid + opts['reply_to'] = self.app.thread_oid if group_id and "group_id" not in opts: opts['group_id'] = group_id if chord: diff --git a/t/unit/app/test_app.py b/t/unit/app/test_app.py index a533d0cc4d4..2512b16cd4f 100644 --- a/t/unit/app/test_app.py +++ b/t/unit/app/test_app.py @@ -2,6 +2,7 @@ import itertools import os import ssl +import uuid from copy import deepcopy from datetime import datetime, timedelta from pickle import dumps, loads @@ -17,6 +18,7 @@ from celery.app import base as _appbase from celery.app import defaults from celery.exceptions import ImproperlyConfigured +from celery.backends.base import Backend from celery.loaders.base import unconfigured from celery.platforms import pyimplementation from celery.utils.collections import DictAttribute @@ -987,6 +989,63 @@ class CustomCelery(type(self.app)): app = CustomCelery(set_as_current=False) assert isinstance(app.tasks, TaskRegistry) + def test_oid(self): + # Test that oid is global value. + oid1 = self.app.oid + oid2 = self.app.oid + uuid.UUID(oid1) + uuid.UUID(oid2) + assert oid1 == oid2 + + def test_global_oid(self): + # Test that oid is global value also within threads + main_oid = self.app.oid + uuid.UUID(main_oid) + from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(lambda: self.app.oid) + thread_oid = future.result() + uuid.UUID(thread_oid) + assert main_oid == thread_oid + + def test_thread_oid(self): + # Test that thread_oid is global value in single thread. + oid1 = self.app.thread_oid + oid2 = self.app.thread_oid + uuid.UUID(oid1) + uuid.UUID(oid2) + assert oid1 == oid2 + + def test_backend(self): + # Test that app.bakend returns the same backend in single thread + backend1 = self.app.backend + backend2 = self.app.backend + assert isinstance(backend1, Backend) + assert isinstance(backend2, Backend) + assert backend1 is backend2 + + def test_thread_backend(self): + # Test that app.bakend returns the new backend for each thread + main_backend = self.app.backend + from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(lambda: self.app.backend) + thread_backend = future.result() + assert isinstance(main_backend, Backend) + assert isinstance(thread_backend, Backend) + assert main_backend is not thread_backend + + def test_thread_oid_is_local(self): + # Test that thread_oid is local to thread. + main_oid = self.app.thread_oid + uuid.UUID(main_oid) + from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(lambda: self.app.thread_oid) + thread_oid = future.result() + uuid.UUID(thread_oid) + assert main_oid != thread_oid + class test_defaults: diff --git a/t/unit/backends/test_rpc.py b/t/unit/backends/test_rpc.py index f8567400706..71e573da8ff 100644 --- a/t/unit/backends/test_rpc.py +++ b/t/unit/backends/test_rpc.py @@ -1,3 +1,4 @@ +import uuid from unittest.mock import Mock, patch import pytest @@ -28,8 +29,22 @@ def setup(self): def test_oid(self): oid = self.b.oid oid2 = self.b.oid + assert uuid.UUID(oid) assert oid == oid2 - assert oid == self.app.oid + assert oid == self.app.thread_oid + + def test_oid_threads(self): + # Verify that two RPC backends executed in different threads + # has different oid. + oid = self.b.oid + from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(lambda: RPCBackend(app=self.app).oid) + thread_oid = future.result() + assert uuid.UUID(oid) + assert uuid.UUID(thread_oid) + assert oid == self.app.thread_oid + assert thread_oid != oid def test_interface(self): self.b.on_reply_declare('task_id') diff --git a/t/unit/tasks/test_chord.py b/t/unit/tasks/test_chord.py index e25e2ccc229..bbec557831a 100644 --- a/t/unit/tasks/test_chord.py +++ b/t/unit/tasks/test_chord.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from unittest.mock import Mock, patch, sentinel +from unittest.mock import Mock, patch, sentinel, PropertyMock import pytest @@ -294,9 +294,8 @@ def adds(self, sig, lazy=False): return self.add_to_chord(sig, lazy) self.adds = adds + @patch('celery.Celery.backend', new=PropertyMock(name='backend')) def test_add_to_chord(self): - self.app.backend = Mock(name='backend') - sig = self.add.s(2, 2) sig.delay = Mock(name='sig.delay') self.adds.request.group = uuid() @@ -333,8 +332,8 @@ def test_add_to_chord(self): class test_Chord_task(ChordCase): + @patch('celery.Celery.backend', new=PropertyMock(name='backend')) def test_run(self): - self.app.backend = Mock() self.app.backend.cleanup = Mock() self.app.backend.cleanup.__name__ = 'cleanup' Chord = self.app.tasks['celery.chord'] diff --git a/t/unit/tasks/test_result.py b/t/unit/tasks/test_result.py index e3d06db0f30..d16dc9eae26 100644 --- a/t/unit/tasks/test_result.py +++ b/t/unit/tasks/test_result.py @@ -708,19 +708,19 @@ def test_get_nested_without_native_join(self): ]), ]), ]) - ts.app.backend = backend - vals = ts.get() - assert vals == [ - '1.1', - [ - '2.1', + with patch('celery.Celery.backend', new=backend): + vals = ts.get() + assert vals == [ + '1.1', [ - '3.1', - '3.2', - ] - ], - ] + '2.1', + [ + '3.1', + '3.2', + ] + ], + ] def test_getitem(self): subs = [MockAsyncResultSuccess(uuid(), app=self.app), @@ -771,15 +771,16 @@ def test_join_native(self): results = [self.app.AsyncResult(uuid(), backend=backend) for i in range(10)] ts = self.app.GroupResult(uuid(), results) - ts.app.backend = backend - backend.ids = [result.id for result in results] - res = ts.join_native() - assert res == list(range(10)) - callback = Mock(name='callback') - assert not ts.join_native(callback=callback) - callback.assert_has_calls([ - call(r.id, i) for i, r in enumerate(ts.results) - ]) + + with patch('celery.Celery.backend', new=backend): + backend.ids = [result.id for result in results] + res = ts.join_native() + assert res == list(range(10)) + callback = Mock(name='callback') + assert not ts.join_native(callback=callback) + callback.assert_has_calls([ + call(r.id, i) for i, r in enumerate(ts.results) + ]) def test_join_native_raises(self): ts = self.app.GroupResult(uuid(), [self.app.AsyncResult(uuid())]) @@ -813,9 +814,9 @@ def test_iter_native(self): results = [self.app.AsyncResult(uuid(), backend=backend) for i in range(10)] ts = self.app.GroupResult(uuid(), results) - ts.app.backend = backend - backend.ids = [result.id for result in results] - assert len(list(ts.iter_native())) == 10 + with patch('celery.Celery.backend', new=backend): + backend.ids = [result.id for result in results] + assert len(list(ts.iter_native())) == 10 def test_join_timeout(self): ar = MockAsyncResultSuccess(uuid(), app=self.app) diff --git a/t/unit/test_canvas.py b/t/unit/test_canvas.py new file mode 100644 index 00000000000..4ba7ba59f3e --- /dev/null +++ b/t/unit/test_canvas.py @@ -0,0 +1,33 @@ +import uuid + + +class test_Canvas: + + def test_freeze_reply_to(self): + # Tests that Canvas.freeze() correctly + # creates reply_to option + + @self.app.task + def test_task(a, b): + return + + s = test_task.s(2, 2) + s.freeze() + + from concurrent.futures import ThreadPoolExecutor + + def foo(): + s = test_task.s(2, 2) + s.freeze() + return self.app.thread_oid, s.options['reply_to'] + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(foo) + t_reply_to_app, t_reply_to_opt = future.result() + + assert uuid.UUID(s.options['reply_to']) + assert uuid.UUID(t_reply_to_opt) + # reply_to must be equal to thread_oid of Application + assert self.app.thread_oid == s.options['reply_to'] + assert t_reply_to_app == t_reply_to_opt + # reply_to must be thread-relative. + assert t_reply_to_opt != s.options['reply_to']