Skip to content

Commit

Permalink
Internalized thread sensitive context management in asgiref/sync.py
Browse files Browse the repository at this point in the history
Rather than using an external context management function, this commit
updates the thread_sensitive mode to reference an internal contextvar.

This approach mainly allows for the thread executor to be shut down when
the context exits, rather than relying on the thread pool shutting down
when going out of scope.

This approach also has other advantages:

* Context variable management cannot be broken by omitting the
  current_context_func arg

* This approach opens the door to more complex management techniques
  such as enforcing a thread pool limit.
  • Loading branch information
a-feld committed Feb 1, 2021
1 parent 96cbb8c commit 2b6174c
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 51 deletions.
11 changes: 3 additions & 8 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,9 @@ The first compromise you get to might be that ``thread_sensitive`` code should
just run in the same thread and not spawn in a sub-thread, fulfilling the first
restriction, but that immediately runs you into the second restriction.

By default, a variant of ThreadPoolExecutor executes any ``thread_sensitive``
code on the outermost synchronous thread - either the main thread, or a single
spawned subthread.

There is an option to override the default single-threaded behavior so that
there is 1 synchronous thread per context. If ``current_context_func`` is
specified, this function will be called to retrieve the current context. There
will be exactly 1 synchronous thread per context in this case.
The only real solution is to essentially have a variant of ThreadPoolExecutor
that executes any ``thread_sensitive`` code on the outermost synchronous
thread - either the main thread, or a single spawned subthread.

This means you now have two basic states:

Expand Down
84 changes: 68 additions & 16 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,59 @@ def _restore_context(context):
cvar.set(context.get(cvar))


class ThreadSensitiveContext:
    """Async context manager to manage context for thread sensitive mode

    This context manager controls which thread pool executor is used when in
    thread sensitive mode. By default, a single thread pool executor is shared
    within a process.

    In Python 3.7+, the ThreadSensitiveContext() context manager may be used to
    specify a thread pool per context.

    In Python 3.6, usage of this context manager has no effect.

    This context manager is re-entrant, so only the outer-most call to
    ThreadSensitiveContext will set the context.

    Usage:

    >>> import time
    >>> async with ThreadSensitiveContext():
    ...     await sync_to_async(time.sleep, 1)()
    """

    def __init__(self):
        # Token returned by ContextVar.set(); only non-None while this
        # instance is the outer-most active context.
        self.token = None

    if contextvars:

        async def __aenter__(self):
            """Register this instance as the current context if none is set.

            Returns the context manager instance itself.
            """
            try:
                SyncToAsync.thread_sensitive_context.get()
            except LookupError:
                # No enclosing context: this instance becomes the outer-most.
                self.token = SyncToAsync.thread_sensitive_context.set(self)

            return self

        async def __aexit__(self, exc, value, tb):
            """Shut down this context's executor and restore the contextvar.

            Does nothing when this instance did not set the context (i.e. it
            was entered inside another ThreadSensitiveContext).
            """
            if self.token is None:
                return

            # Clear the stored token up front: a contextvars token can only be
            # used for a single reset(), so keeping it around would make a
            # later nested re-use of this instance raise RuntimeError.
            token, self.token = self.token, None
            try:
                executor = SyncToAsync.context_to_thread_executor.pop(self, None)
                if executor:
                    executor.shutdown()
            finally:
                # Always restore the previous context value, even when the
                # executor shutdown fails.
                SyncToAsync.thread_sensitive_context.reset(token)

    else:

        async def __aenter__(self):
            """No-op: contextvars is unavailable, so there is nothing to set."""
            return self

        async def __aexit__(self, exc, value, tb):
            """No-op counterpart of __aenter__ when contextvars is unavailable."""
            pass


class AsyncToSync:
"""
Utility class which turns an awaitable that only works on the thread with
Expand Down Expand Up @@ -232,12 +285,6 @@ class SyncToAsync:
outer code. This is needed for underlying Python code that is not
threadsafe (for example, code which handles SQLite database connections).
If current_context_func is passed, the code will run 1 thread per context.
As an example, this may be used to create a per-request synchronous thread
by specifying the request object as the context. Thread scheduling will
occur by request in this scenario - each request will execute synchronous
work within the same thread.
If the outermost program is async (i.e. SyncToAsync is outermost), then
this will be a dedicated single sub-thread that all sync code runs in,
one after the other. If the outermost program is sync (i.e. AsyncToSync is
Expand All @@ -262,16 +309,21 @@ class SyncToAsync:
# Single-thread executor for thread-sensitive code
single_thread_executor = ThreadPoolExecutor(max_workers=1)

# Maintain a contextvar for the current execution context. Optionally used
# for thread sensitive mode.
thread_sensitive_context = (
contextvars.ContextVar("thread_sensitive_context") if contextvars else None
)

# Maintaining a weak reference to the context ensures that thread pools are
# erased once the context goes out of scope. This terminates the thread pool.
context_to_thread_executor = weakref.WeakKeyDictionary()

def __init__(self, func, thread_sensitive=True, current_context_func=None):
def __init__(self, func, thread_sensitive=True):
self.func = func
functools.update_wrapper(self, func)
self._thread_sensitive = thread_sensitive
self._is_coroutine = asyncio.coroutines._is_coroutine
self._current_context_func = current_context_func
try:
self.__self__ = func.__self__
except AttributeError:
Expand All @@ -285,18 +337,20 @@ async def __call__(self, *args, **kwargs):
if hasattr(AsyncToSync.executors, "current"):
# If we have a parent sync thread above somewhere, use that
executor = AsyncToSync.executors.current
elif self._current_context_func:
elif self.thread_sensitive_context and self.thread_sensitive_context.get(
None
):
# If we have a way of retrieving the current context, attempt
# to use a per-context thread pool executor
current_context = self._current_context_func()
thread_sensitive_context = self.thread_sensitive_context.get()

if current_context in self.context_to_thread_executor:
if thread_sensitive_context in self.context_to_thread_executor:
# Re-use thread executor in current context
executor = self.context_to_thread_executor[current_context]
executor = self.context_to_thread_executor[thread_sensitive_context]
else:
# Create new thread executor in current context
executor = ThreadPoolExecutor(max_workers=1)
self.context_to_thread_executor[current_context] = executor
self.context_to_thread_executor[thread_sensitive_context] = executor
else:
# Otherwise, we run it in a fixed single thread
executor = self.single_thread_executor
Expand Down Expand Up @@ -393,15 +447,13 @@ def get_current_task():
async_to_sync = AsyncToSync


def sync_to_async(func=None, thread_sensitive=True, current_context_func=None):
def sync_to_async(func=None, thread_sensitive=True):
if func is None:
return lambda f: SyncToAsync(
f,
thread_sensitive=thread_sensitive,
current_context_func=current_context_func,
)
return SyncToAsync(
func,
thread_sensitive=thread_sensitive,
current_context_func=current_context_func,
)
52 changes: 26 additions & 26 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pytest

from asgiref.sync import async_to_sync, sync_to_async
from asgiref.sync import ThreadSensitiveContext, async_to_sync, sync_to_async


@pytest.mark.asyncio
Expand Down Expand Up @@ -323,54 +323,54 @@ def inner(result):


@pytest.mark.asyncio
async def test_thread_sensitive_with_context_different():
async def test_thread_sensitive_with_context_matches():
result_1 = {}
result_2 = {}

class Context:
pass

context_1 = Context()
context_2 = Context()

def store_thread(result):
result["thread"] = threading.current_thread()

fn_1 = sync_to_async(store_thread, current_context_func=lambda: context_1)
fn_2 = sync_to_async(store_thread, current_context_func=lambda: context_2)
store_thread_async = sync_to_async(store_thread)

async def fn():
async with ThreadSensitiveContext():
# Run it (in supposed parallel!)
await asyncio.wait(
[store_thread_async(result_1), store_thread_async(result_2)]
)

# Run it (in true parallel!)
await asyncio.wait([fn_1(result_1), fn_2(result_2)])
await fn()

# They should not have run in the main thread, and on different threads
# They should not have run in the main thread, and both should have run on the same thread
assert result_1["thread"] != threading.current_thread()
assert result_1["thread"] != result_2["thread"]
assert result_1["thread"] == result_2["thread"]


@pytest.mark.asyncio
async def test_thread_sensitive_with_context_matches():
async def test_thread_sensitive_nested_context():
result_1 = {}
result_2 = {}

class Context:
pass

context = Context()

@sync_to_async
def store_thread(result):
result["thread"] = threading.current_thread()

fn_1 = sync_to_async(store_thread, current_context_func=lambda: context)
fn_2 = sync_to_async(store_thread, current_context_func=lambda: context)

# Run it (in supposed parallel!)
await asyncio.wait([fn_1(result_1), fn_2(result_2)])
async with ThreadSensitiveContext():
await store_thread(result_1)
async with ThreadSensitiveContext():
await store_thread(result_2)

# They should not have run in the main thread, and on different threads
# They should not have run in the main thread, and both should have run on the same thread
assert result_1["thread"] != threading.current_thread()
assert result_1["thread"] == result_2["thread"]


# Smoke test: entering and leaving ThreadSensitiveContext without awaiting any
# sync_to_async call must complete without raising. No sync work is scheduled
# inside the context body, so this exercises the enter/exit path on its own.
@pytest.mark.asyncio
async def test_thread_sensitive_context_without_sync_work():
async with ThreadSensitiveContext():
pass


def test_thread_sensitive_double_nested_sync():
"""
Tests that thread_sensitive SyncToAsync nests inside itself where the
Expand Down
24 changes: 23 additions & 1 deletion tests/test_sync_contextvars.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
import asyncio
import threading
import time

import pytest

from asgiref.sync import async_to_sync, sync_to_async
from asgiref.sync import ThreadSensitiveContext, async_to_sync, sync_to_async

contextvars = pytest.importorskip("contextvars")

foo = contextvars.ContextVar("foo")


@pytest.mark.asyncio
async def test_thread_sensitive_with_context_different():
    """
    Tests that thread-sensitive calls made from two separate
    ThreadSensitiveContext blocks run on two different threads, neither of
    which is the main thread.
    """
    result_1 = {}
    result_2 = {}

    @sync_to_async
    def store_thread(result):
        result["thread"] = threading.current_thread()

    async def fn(result):
        async with ThreadSensitiveContext():
            await store_thread(result)

    # Run it (in true parallel!)
    # gather() wraps each coroutine in its own task and re-raises any failure;
    # passing bare coroutines to asyncio.wait() is rejected on Python 3.11+.
    await asyncio.gather(fn(result_1), fn(result_2))

    # They should not have run in the main thread, and on different threads
    assert result_1["thread"] != threading.current_thread()
    assert result_1["thread"] != result_2["thread"]


@pytest.mark.asyncio
async def test_sync_to_async_contextvars():
"""
Expand Down

0 comments on commit 2b6174c

Please sign in to comment.