diff --git a/asgiref/sync.py b/asgiref/sync.py index 5406b7d3..692cbc75 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -203,6 +203,10 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: # `main_wrap`. context = [contextvars.copy_context()] + # Get task context so that parent task knows which task to propagate + # an asyncio.CancelledError to. + task_context = getattr(SyncToAsync.threadlocal, "task_context", None) + loop = None # Use call_soon_threadsafe to schedule a synchronous callback on the # main event loop's thread if it's there, otherwise make a new loop @@ -211,6 +215,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: awaitable = self.main_wrap( call_result, sys.exc_info(), + task_context, context, *args, **kwargs, @@ -295,6 +300,7 @@ async def main_wrap( self, call_result: "Future[_R]", exc_info: "OptExcInfo", + task_context: "Optional[List[asyncio.Task[Any]]]", context: List[contextvars.Context], *args: _P.args, **kwargs: _P.kwargs, @@ -309,6 +315,10 @@ async def main_wrap( if context is not None: _restore_context(context[0]) + current_task = asyncio.current_task() + if current_task is not None and task_context is not None: + task_context.append(current_task) + try: # If we have an exception, run the function inside the except block # after raising it so exc_info is correctly populated. @@ -324,6 +334,8 @@ async def main_wrap( else: call_result.set_result(result) finally: + if current_task is not None and task_context is not None: + task_context.remove(current_task) context[0] = contextvars.copy_context() @@ -437,20 +449,38 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: context = contextvars.copy_context() child = functools.partial(self.func, *args, **kwargs) func = context.run - + task_context: List[asyncio.Task[Any]] = [] + + # Run the code in the right thread + exec_coro = loop.run_in_executor( + executor, + functools.partial( + self.thread_handler, + loop, + sys.exc_info(), + task_context, + func, + child, + ), + ) + ret: _R try: - # Run the code in the right thread - ret: _R = await loop.run_in_executor( - executor, - functools.partial( - self.thread_handler, - loop, - sys.exc_info(), - func, - child, - ), - ) - + ret = await asyncio.shield(exec_coro) + except asyncio.CancelledError: + cancel_parent = True + try: + task = task_context[0] + task.cancel() + try: + await task + cancel_parent = False + except asyncio.CancelledError: + pass + except IndexError: + pass + if cancel_parent: + exec_coro.cancel() + ret = await exec_coro finally: _restore_context(context) self.deadlock_context.set(False) @@ -466,7 +496,7 @@ def __get__( func = functools.partial(self.__call__, parent) return functools.update_wrapper(func, self.func) - def thread_handler(self, loop, exc_info, func, *args, **kwargs): + def thread_handler(self, loop, exc_info, task_context, func, *args, **kwargs): """ Wraps the sync application with exception handling. """ @@ -476,6 +506,7 @@ def thread_handler(self, loop, exc_info, func, *args, **kwargs): # Set the threadlocal for AsyncToSync self.threadlocal.main_event_loop = loop self.threadlocal.main_event_loop_pid = os.getpid() + self.threadlocal.task_context = task_context # Run the function # If we have an exception, run the function inside the except block diff --git a/tests/test_sync.py b/tests/test_sync.py index daed8c4d..3e83c91b 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -852,13 +852,10 @@ def sync_task(): @pytest.mark.asyncio -@pytest.mark.skip(reason="deadlocks") async def test_inner_shield_sync_middleware(): """ Tests that asyncio.shield is capable of preventing http.disconnect from cancelling a django request task when using sync middleware. - - Currently this tests is skipped as it causes a deadlock. """ # Hypothetical Django scenario - middleware function is sync @@ -968,3 +965,159 @@ async def async_task(): assert task_complete assert task_executed + + +@pytest.mark.asyncio +async def test_inner_shield_sync_and_async_middleware(): + """ + Tests that asyncio.shield is capable of preventing http.disconnect from + cancelling a django request task when using sync and middleware chained + together. + """ + + # Hypothetical Django scenario - middleware function is sync + def sync_middleware_1(): + async_to_sync(async_middleware_2)() + + # Hypothetical Django scenario - middleware function is async + async def async_middleware_2(): + await sync_to_async(sync_middleware_3)() + + # Hypothetical Django scenario - middleware function is sync + def sync_middleware_3(): + async_to_sync(async_middleware_4)() + + # Hypothetical Django scenario - middleware function is async + async def async_middleware_4(): + await sync_to_async(sync_middleware_5)() + + # Hypothetical Django scenario - middleware function is sync + def sync_middleware_5(): + async_to_sync(async_view)() + + task_complete = False + task_cancel_caught = False + + # Future that completes when subtask cancellation attempt is caught + task_blocker = asyncio.Future() + + async def async_view(): + """Async view with a task that is shielded from cancellation.""" + nonlocal task_complete, task_cancel_caught, task_blocker + task = asyncio.create_task(async_task()) + try: + await asyncio.shield(task) + except asyncio.CancelledError: + task_cancel_caught = True + task_blocker.set_result(True) + await task + task_complete = True + + task_executed = False + + # Future that completes after subtask is created + task_started_future = asyncio.Future() + + async def async_task(): + """Async subtask that should not be canceled when parent is canceled.""" + nonlocal task_started_future, task_executed, task_blocker + task_started_future.set_result(True) + await task_blocker + task_executed = True + + task_cancel_propagated = False + + async with ThreadSensitiveContext(): + task = asyncio.create_task(sync_to_async(sync_middleware_1)()) + await task_started_future + task.cancel() + try: + await task + except asyncio.CancelledError: + task_cancel_propagated = True + assert not task_cancel_propagated + assert task_cancel_caught + assert task_complete + + assert task_executed + + +@pytest.mark.asyncio +async def test_inner_shield_sync_and_async_middleware_sync_task(): + """ + Tests that asyncio.shield is capable of preventing http.disconnect from + cancelling a django request task when using sync and middleware chained + together with an async view calling a sync function calling an async task. + + This test ensures that a parent initiated task cancellation will not + propagate to a shielded subtask. + """ + + # Hypothetical Django scenario - middleware function is sync + def sync_middleware_1(): + async_to_sync(async_middleware_2)() + + # Hypothetical Django scenario - middleware function is async + async def async_middleware_2(): + await sync_to_async(sync_middleware_3)() + + # Hypothetical Django scenario - middleware function is sync + def sync_middleware_3(): + async_to_sync(async_middleware_4)() + + # Hypothetical Django scenario - middleware function is async + async def async_middleware_4(): + await sync_to_async(sync_middleware_5)() + + # Hypothetical Django scenario - middleware function is sync + def sync_middleware_5(): + async_to_sync(async_view)() + + task_complete = False + task_cancel_caught = False + + # Future that completes when subtask cancellation attempt is caught + task_blocker = asyncio.Future() + + async def async_view(): + """Async view with a task that is shielded from cancellation.""" + nonlocal task_complete, task_cancel_caught, task_blocker + task = asyncio.create_task(sync_to_async(sync_parent)()) + try: + await asyncio.shield(task) + except asyncio.CancelledError: + task_cancel_caught = True + task_blocker.set_result(True) + await task + task_complete = True + + task_executed = False + + # Future that completes after subtask is created + task_started_future = asyncio.Future() + + def sync_parent(): + async_to_sync(async_task)() + + async def async_task(): + """Async subtask that should not be canceled when parent is canceled.""" + nonlocal task_started_future, task_executed, task_blocker + task_started_future.set_result(True) + await task_blocker + task_executed = True + + task_cancel_propagated = False + + async with ThreadSensitiveContext(): + task = asyncio.create_task(sync_to_async(sync_middleware_1)()) + await task_started_future + task.cancel() + try: + await task + except asyncio.CancelledError: + task_cancel_propagated = True + assert not task_cancel_propagated + assert task_cancel_caught + assert task_complete + + assert task_executed