Multithreaded backend (#6416)
* Cache backend to thread local storage instead of global variable

* Cache oid to thread local storage instead of global variable

* Improve code returning thread_local data

* Move thread local storage to Celery class, introduce thread_oid and add unit tests
matusvalo committed Nov 22, 2020
1 parent 60ba379 commit e203168
Showing 8 changed files with 159 additions and 34 deletions.
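The core of the change is a small caching pattern: instead of building the result backend once per application object (the old cached_property), each Celery app now keeps a threading.local() and lazily caches one backend, and one reply-queue OID, per thread. A minimal standalone sketch of that pattern (illustrative names, not the exact Celery code):

import threading


class App:
    """Sketch of the per-thread caching this commit introduces."""

    def __init__(self):
        self._local = threading.local()   # independent attribute namespace per thread

    def _get_backend(self):
        return object()                    # stand-in for the real backend construction

    @property
    def backend(self):
        try:
            return self._local.backend     # already built by the calling thread
        except AttributeError:
            self._local.backend = new_backend = self._get_backend()
            return new_backend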
24 changes: 21 additions & 3 deletions celery/app/base.py
@@ -206,6 +206,8 @@ class name.
task_cls = 'celery.app.task:Task'
registry_cls = 'celery.app.registry:TaskRegistry'

#: Thread local storage.
_local = None
_fixups = None
_pool = None
_conf = None
@@ -229,6 +231,9 @@ def __init__(self, main=None, loader=None, backend=None,
changes=None, config_source=None, fixups=None, task_cls=None,
autofinalize=True, namespace=None, strict_typing=True,
**kwargs):

self._local = threading.local()

self.clock = LamportClock()
self.main = main
self.amqp_cls = amqp or self.amqp_cls
@@ -727,7 +732,7 @@ def send_task(self, name, args=None, kwargs=None, countdown=None,
task_id, name, args, kwargs, countdown, eta, group_id, group_index,
expires, retries, chord,
maybe_list(link), maybe_list(link_error),
-reply_to or self.oid, time_limit, soft_time_limit,
+reply_to or self.thread_oid, time_limit, soft_time_limit,
self.conf.task_send_sent_event,
root_id, parent_id, shadow, chain,
argsrepr=options.get('argsrepr'),
@@ -1185,15 +1190,28 @@ def oid(self):
# which would not work if each thread has a separate id.
return oid_from(self, threads=False)

@property
def thread_oid(self):
"""Per-thread unique identifier for this app."""
try:
return self._local.oid
except AttributeError:
self._local.oid = new_oid = oid_from(self, threads=True)
return new_oid

@cached_property
def amqp(self):
"""AMQP related functionality: :class:`~@amqp`."""
return instantiate(self.amqp_cls, app=self)

-@cached_property
+@property
def backend(self):
"""Current backend instance."""
-return self._get_backend()
+try:
+    return self._local.backend
+except AttributeError:
+    self._local.backend = new_backend = self._get_backend()
+    return new_backend

@property
def conf(self):
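With the hunks above, app.oid keeps its old process-wide meaning while app.thread_oid and app.backend become per-thread values. A rough usage sketch of the new behaviour (assumes this commit is applied; the concrete backend class depends on configuration):

from concurrent.futures import ThreadPoolExecutor

from celery import Celery

app = Celery('demo')

with ThreadPoolExecutor(max_workers=1) as pool:
    other_oid, other_thread_oid, other_backend = pool.submit(
        lambda: (app.oid, app.thread_oid, app.backend)).result()

assert other_oid == app.oid                  # oid stays global to the process
assert other_thread_oid != app.thread_oid    # thread_oid differs per thread
assert other_backend is not app.backend      # each thread caches its own backend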
4 changes: 2 additions & 2 deletions celery/backends/rpc.py
@@ -338,5 +338,5 @@ def binding(self):

@cached_property
def oid(self):
-# cached here is the app OID: name of queue we receive results on.
-return self.app.oid
+# cached here is the app thread OID: name of queue we receive results on.
+return self.app.thread_oid
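The RPC backend's oid is the name of the queue it consumes replies from, so deriving it from thread_oid gives every thread its own reply queue. The relationship, as also asserted by the updated tests further down (a sketch assuming the rpc:// result backend):

from celery import Celery

app = Celery('demo', backend='rpc://')

# The reply-queue name now tracks the calling thread rather than the whole app.
assert app.backend.oid == app.thread_oid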
2 changes: 1 addition & 1 deletion celery/canvas.py
@@ -296,7 +296,7 @@ def freeze(self, _id=None, group_id=None, chord=None,
if parent_id:
opts['parent_id'] = parent_id
if 'reply_to' not in opts:
-opts['reply_to'] = self.app.oid
+opts['reply_to'] = self.app.thread_oid
if group_id and "group_id" not in opts:
opts['group_id'] = group_id
if chord:
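Signature.freeze() stamps reply_to with the same per-thread value, so a signature frozen in one thread directs its reply to that thread's queue. A sketch (assumes this commit; add is just an example task):

from celery import Celery

app = Celery('demo')


@app.task
def add(x, y):
    return x + y


sig = add.s(2, 2)
sig.freeze()

# freeze() fills in reply_to from the calling thread's OID.
assert sig.options['reply_to'] == app.thread_oid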
59 changes: 59 additions & 0 deletions t/unit/app/test_app.py
@@ -2,6 +2,7 @@
import itertools
import os
import ssl
import uuid
from copy import deepcopy
from datetime import datetime, timedelta
from pickle import dumps, loads
@@ -17,6 +18,7 @@
from celery.app import base as _appbase
from celery.app import defaults
from celery.exceptions import ImproperlyConfigured
from celery.backends.base import Backend
from celery.loaders.base import unconfigured
from celery.platforms import pyimplementation
from celery.utils.collections import DictAttribute
@@ -987,6 +989,63 @@ class CustomCelery(type(self.app)):
app = CustomCelery(set_as_current=False)
assert isinstance(app.tasks, TaskRegistry)

def test_oid(self):
# Test that oid is a global value.
oid1 = self.app.oid
oid2 = self.app.oid
uuid.UUID(oid1)
uuid.UUID(oid2)
assert oid1 == oid2

def test_global_oid(self):
# Test that oid is a global value also across threads
main_oid = self.app.oid
uuid.UUID(main_oid)
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(lambda: self.app.oid)
thread_oid = future.result()
uuid.UUID(thread_oid)
assert main_oid == thread_oid

def test_thread_oid(self):
# Test that thread_oid is a stable value within a single thread.
oid1 = self.app.thread_oid
oid2 = self.app.thread_oid
uuid.UUID(oid1)
uuid.UUID(oid2)
assert oid1 == oid2

def test_backend(self):
# Test that app.backend returns the same backend in a single thread
backend1 = self.app.backend
backend2 = self.app.backend
assert isinstance(backend1, Backend)
assert isinstance(backend2, Backend)
assert backend1 is backend2

def test_thread_backend(self):
# Test that app.backend returns a new backend for each thread
main_backend = self.app.backend
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(lambda: self.app.backend)
thread_backend = future.result()
assert isinstance(main_backend, Backend)
assert isinstance(thread_backend, Backend)
assert main_backend is not thread_backend

def test_thread_oid_is_local(self):
# Test that thread_oid is local to thread.
main_oid = self.app.thread_oid
uuid.UUID(main_oid)
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(lambda: self.app.thread_oid)
thread_oid = future.result()
uuid.UUID(thread_oid)
assert main_oid != thread_oid


class test_defaults:

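All of the new tests use the same idiom for evaluating an attribute on another thread: submit a lambda to a single-worker ThreadPoolExecutor and compare the result with the main thread's value. Factored out, the helper would look like this (a convenience sketch, not part of the commit):

from concurrent.futures import ThreadPoolExecutor


def in_new_thread(fn):
    """Run fn() on a fresh worker thread and return its result."""
    with ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(fn).result()


# e.g. in_new_thread(lambda: app.thread_oid) != app.thread_oid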
17 changes: 16 additions & 1 deletion t/unit/backends/test_rpc.py
@@ -1,3 +1,4 @@
import uuid
from unittest.mock import Mock, patch

import pytest
@@ -28,8 +29,22 @@ def setup(self):
def test_oid(self):
oid = self.b.oid
oid2 = self.b.oid
assert uuid.UUID(oid)
assert oid == oid2
-assert oid == self.app.oid
+assert oid == self.app.thread_oid

def test_oid_threads(self):
# Verify that two RPC backends executed in different threads
# have different oids.
oid = self.b.oid
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(lambda: RPCBackend(app=self.app).oid)
thread_oid = future.result()
assert uuid.UUID(oid)
assert uuid.UUID(thread_oid)
assert oid == self.app.thread_oid
assert thread_oid != oid

def test_interface(self):
self.b.on_reply_declare('task_id')
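A per-thread reply queue is what makes it reasonable to publish tasks and wait for their results from several threads at once, for example from a threaded web server. A motivating sketch (not part of the commit; assumes a running broker and worker, and tasks.add stands in for any registered task):

from concurrent.futures import ThreadPoolExecutor

from tasks import add  # hypothetical module defining an `add` task


def call_and_wait(n):
    # Each calling thread now waits on its own reply queue instead of
    # sharing a single application-wide queue with every other thread.
    return add.delay(n, n).get(timeout=10)


with ThreadPoolExecutor(max_workers=4) as pool:
    print(list(pool.map(call_and_wait, range(4))))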
7 changes: 3 additions & 4 deletions t/unit/tasks/test_chord.py
@@ -1,5 +1,5 @@
from contextlib import contextmanager
-from unittest.mock import Mock, patch, sentinel
+from unittest.mock import Mock, patch, sentinel, PropertyMock

import pytest

@@ -294,9 +294,8 @@ def adds(self, sig, lazy=False):
return self.add_to_chord(sig, lazy)
self.adds = adds

+@patch('celery.Celery.backend', new=PropertyMock(name='backend'))
def test_add_to_chord(self):
-self.app.backend = Mock(name='backend')
-
sig = self.add.s(2, 2)
sig.delay = Mock(name='sig.delay')
self.adds.request.group = uuid()
@@ -333,8 +332,8 @@ def test_add_to_chord(self):

class test_Chord_task(ChordCase):

+@patch('celery.Celery.backend', new=PropertyMock(name='backend'))
def test_run(self):
-self.app.backend = Mock()
self.app.backend.cleanup = Mock()
self.app.backend.cleanup.__name__ = 'cleanup'
Chord = self.app.tasks['celery.chord']
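A side effect of turning backend into a plain property is that tests can no longer assign self.app.backend = Mock(); a property without a setter rejects the assignment, which is why the tests above switch to patching the class attribute with a PropertyMock. The same pattern in isolation (a standalone sketch, not Celery code):

from unittest.mock import PropertyMock, patch


class App:
    @property
    def backend(self):
        return 'real backend'


app = App()

with patch.object(App, 'backend', new=PropertyMock(return_value='fake backend')):
    assert app.backend == 'fake backend'   # attribute access goes through the mock
# Outside the patch, `app.backend = ...` would raise AttributeError (no setter).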
47 changes: 24 additions & 23 deletions t/unit/tasks/test_result.py
@@ -708,19 +708,19 @@ def test_get_nested_without_native_join(self):
]),
]),
])
-ts.app.backend = backend
-vals = ts.get()
-assert vals == [
-    '1.1',
-    [
-        '2.1',
-        [
-            '3.1',
-            '3.2',
-        ]
-    ],
-]

+with patch('celery.Celery.backend', new=backend):
+    vals = ts.get()
+    assert vals == [
+        '1.1',
+        [
+            '2.1',
+            [
+                '3.1',
+                '3.2',
+            ]
+        ],
+    ]

def test_getitem(self):
subs = [MockAsyncResultSuccess(uuid(), app=self.app),
@@ -771,15 +771,16 @@ def test_join_native(self):
results = [self.app.AsyncResult(uuid(), backend=backend)
for i in range(10)]
ts = self.app.GroupResult(uuid(), results)
-ts.app.backend = backend
-backend.ids = [result.id for result in results]
-res = ts.join_native()
-assert res == list(range(10))
-callback = Mock(name='callback')
-assert not ts.join_native(callback=callback)
-callback.assert_has_calls([
-    call(r.id, i) for i, r in enumerate(ts.results)
-])

+with patch('celery.Celery.backend', new=backend):
+    backend.ids = [result.id for result in results]
+    res = ts.join_native()
+    assert res == list(range(10))
+    callback = Mock(name='callback')
+    assert not ts.join_native(callback=callback)
+    callback.assert_has_calls([
+        call(r.id, i) for i, r in enumerate(ts.results)
+    ])

def test_join_native_raises(self):
ts = self.app.GroupResult(uuid(), [self.app.AsyncResult(uuid())])
@@ -813,9 +814,9 @@ def test_iter_native(self):
results = [self.app.AsyncResult(uuid(), backend=backend)
for i in range(10)]
ts = self.app.GroupResult(uuid(), results)
-ts.app.backend = backend
-backend.ids = [result.id for result in results]
-assert len(list(ts.iter_native())) == 10
+with patch('celery.Celery.backend', new=backend):
+    backend.ids = [result.id for result in results]
+    assert len(list(ts.iter_native())) == 10

def test_join_timeout(self):
ar = MockAsyncResultSuccess(uuid(), app=self.app)
33 changes: 33 additions & 0 deletions t/unit/test_canvas.py
@@ -0,0 +1,33 @@
import uuid


class test_Canvas:

def test_freeze_reply_to(self):
# Tests that Canvas.freeze() correctly
# creates the reply_to option

@self.app.task
def test_task(a, b):
return

s = test_task.s(2, 2)
s.freeze()

from concurrent.futures import ThreadPoolExecutor

def foo():
s = test_task.s(2, 2)
s.freeze()
return self.app.thread_oid, s.options['reply_to']
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(foo)
t_reply_to_app, t_reply_to_opt = future.result()

assert uuid.UUID(s.options['reply_to'])
assert uuid.UUID(t_reply_to_opt)
# reply_to must be equal to the thread_oid of the application
assert self.app.thread_oid == s.options['reply_to']
assert t_reply_to_app == t_reply_to_opt
# reply_to must be thread-relative.
assert t_reply_to_opt != s.options['reply_to']
