Skip to content

Commit

Permalink
Rewrote how context is propagated in AsyncToSync (#175)
Browse files Browse the repository at this point in the history
This avoids exception-related errors in downstream users (e.g. Django)
  • Loading branch information
untitaker committed Jun 16, 2020
1 parent fd58c47 commit 256a1fd
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ def __call__(self, *args, **kwargs):
)

if contextvars is not None:
context = contextvars.copy_context()
# Wrapping context in list so it can be reassigned from within
# `main_wrap`.
context = [contextvars.copy_context()]
else:
context = None

# Make a future for the return information
call_result = Future()
Expand All @@ -99,12 +103,9 @@ def __call__(self, *args, **kwargs):
# in this thread.
try:
awaitable = self.main_wrap(
args, kwargs, call_result, source_thread, sys.exc_info()
args, kwargs, call_result, source_thread, sys.exc_info(), context
)

if contextvars is not None:
awaitable = self._awaitable_with_context(awaitable, context)

if not (self.main_event_loop and self.main_event_loop.is_running()):
# Make our own event loop - in a new thread - and run inside that.
loop = asyncio.new_event_loop()
Expand Down Expand Up @@ -132,7 +133,7 @@ def __call__(self, *args, **kwargs):
if old_current_executor:
self.executors.current = old_current_executor
if contextvars is not None:
_restore_context(context)
_restore_context(context[0])

# Wait for results from the future.
return call_result.result()
Expand Down Expand Up @@ -179,11 +180,16 @@ def __get__(self, parent, objtype):
func = functools.partial(self.__call__, parent)
return functools.update_wrapper(func, self.awaitable)

async def main_wrap(self, args, kwargs, call_result, source_thread, exc_info):
async def main_wrap(
self, args, kwargs, call_result, source_thread, exc_info, context
):
"""
Wraps the awaitable with something that puts the result into the
result/exception future.
"""
if context is not None:
_restore_context(context[0])

current_task = SyncToAsync.get_current_task()
self.launch_map[current_task] = source_thread
try:
Expand All @@ -203,17 +209,8 @@ async def main_wrap(self, args, kwargs, call_result, source_thread, exc_info):
finally:
del self.launch_map[current_task]

@staticmethod
def _awaitable_with_context(awaitable, context):
gen = awaitable.__await__()

while True:
try:
chunk = context.run(next, gen)
except StopIteration:
break

yield chunk
if context is not None:
context[0] = contextvars.copy_context()


class SyncToAsync:
Expand Down

0 comments on commit 256a1fd

Please sign in to comment.