Skip to content

Commit

Permalink
Propagated and restored contextvars in async_to_sync (#171)
Browse files Browse the repository at this point in the history
sync_to_async already makes sure that contextvars are properly
propagated, but async_to_sync does not.

In Django 3.1, this fixes the behavior of got_request_exception which
would otherwise be called in a different context than before_request or
after_request under the following circumstances:

1. An ASGI request has been made
2. There is at least one sync middleware
  • Loading branch information
untitaker committed Jun 15, 2020
1 parent 4c725fa commit 3a7ba92
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 19 deletions.
57 changes: 40 additions & 17 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
contextvars = None


def _restore_context(context):
# Check for changes in contextvars, and set them to the current
# context for downstream consumers
for cvar in context:
try:
if cvar.get() != context.get(cvar):
cvar.set(context.get(cvar))
except LookupError:
cvar.set(context.get(cvar))


class AsyncToSync:
"""
Utility class which turns an awaitable that only works on the thread with
Expand Down Expand Up @@ -66,6 +77,10 @@ def __call__(self, *args, **kwargs):
"You cannot use AsyncToSync in the same thread as an async event loop - "
"just await the async function directly."
)

if contextvars is not None:
context = contextvars.copy_context()

# Make a future for the return information
call_result = Future()
# Get the source thread
Expand All @@ -83,16 +98,19 @@ def __call__(self, *args, **kwargs):
# main event loop's thread if it's there, otherwise make a new loop
# in this thread.
try:
awaitable = self.main_wrap(
args, kwargs, call_result, source_thread, sys.exc_info()
)

if contextvars is not None:
awaitable = self._awaitable_with_context(awaitable, context)

if not (self.main_event_loop and self.main_event_loop.is_running()):
# Make our own event loop - in a new thread - and run inside that.
loop = asyncio.new_event_loop()
loop_executor = ThreadPoolExecutor(max_workers=1)
loop_future = loop_executor.submit(
self._run_event_loop,
loop,
self.main_wrap(
args, kwargs, call_result, source_thread, sys.exc_info()
),
self._run_event_loop, loop, awaitable
)
if current_executor:
# Run the CurrentThreadExecutor until the future is done
Expand All @@ -102,10 +120,7 @@ def __call__(self, *args, **kwargs):
else:
# Call it inside the existing loop
self.main_event_loop.call_soon_threadsafe(
self.main_event_loop.create_task,
self.main_wrap(
args, kwargs, call_result, source_thread, sys.exc_info()
),
self.main_event_loop.create_task, awaitable
)
if current_executor:
# Run the CurrentThreadExecutor until the future is done
Expand All @@ -116,6 +131,9 @@ def __call__(self, *args, **kwargs):
del self.executors.current
if old_current_executor:
self.executors.current = old_current_executor
if contextvars is not None:
_restore_context(context)

# Wait for results from the future.
return call_result.result()

Expand Down Expand Up @@ -185,6 +203,18 @@ async def main_wrap(self, args, kwargs, call_result, source_thread, exc_info):
finally:
del self.launch_map[current_task]

@staticmethod
def _awaitable_with_context(awaitable, context):
gen = awaitable.__await__()

while True:
try:
chunk = context.run(next, gen)
except StopIteration:
break

yield chunk


class SyncToAsync:
"""
Expand Down Expand Up @@ -269,14 +299,7 @@ async def __call__(self, *args, **kwargs):
ret = await asyncio.wait_for(future, timeout=None)

if contextvars is not None:
# Check for changes in contextvars, and set them to the current
# context for downstream consumers
for cvar in context:
try:
if cvar.get() != context.get(cvar):
cvar.set(context.get(cvar))
except LookupError:
cvar.set(context.get(cvar))
_restore_context(context)

return ret

Expand Down
26 changes: 24 additions & 2 deletions tests/test_sync_contextvars.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import time

import pytest

from asgiref.sync import sync_to_async
from asgiref.sync import async_to_sync, sync_to_async

contextvars = pytest.importorskip("contextvars")

Expand All @@ -27,5 +28,26 @@ def sync_function():
# Wrap it
foo.set("bar")
async_function = sync_to_async(sync_function)
await async_function()
assert await async_function() == 42
assert foo.get() == "baz"


def test_async_to_sync_contextvars():
"""
Tests to make sure that contextvars from the calling context are
present in the called context, and that any changes in the called context
are then propagated back to the calling context.
"""
# Define sync function
async def async_function():
await asyncio.sleep(1)
assert foo.get() == "bar"
foo.set("baz")
return 42

# Ensure outermost detection works
# Wrap it
foo.set("bar")
sync_function = async_to_sync(async_function)
assert sync_function() == 42
assert foo.get() == "baz"

0 comments on commit 3a7ba92

Please sign in to comment.