From fd724e0bbf024e553884daaf97763c072fa980b8 Mon Sep 17 00:00:00 2001 From: Jean Hominal Date: Fri, 29 Sep 2023 02:33:07 +0200 Subject: [PATCH] Updated pytest plugin to run all async/sync fixture/test code on a single Context Fixes #614 --- src/anyio/pytest_plugin.py | 65 ++++++++++++++++++++++++++++++++++--- tests/test_pytest_plugin.py | 59 +++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 4 deletions(-) diff --git a/src/anyio/pytest_plugin.py b/src/anyio/pytest_plugin.py index 762e9e83..63873a75 100644 --- a/src/anyio/pytest_plugin.py +++ b/src/anyio/pytest_plugin.py @@ -2,16 +2,52 @@ from collections.abc import Iterator from contextlib import contextmanager +from contextvars import Context, copy_context from inspect import isasyncgenfunction, iscoroutinefunction from typing import Any, Dict, Tuple, cast import pytest import sniffio +from _pytest.stash import StashKey from ._core._eventloop import get_all_backends, get_async_backend +from ._run_in_context import ContextLike, context_wrap, context_wrap_async from .abc import TestRunner _current_runner: TestRunner | None = None +contextvars_context_key = StashKey[Context]() +_test_context_like_key = StashKey[ContextLike]() + + +class _TestContext(ContextLike): + """This class manages transmission of sniffio.current_async_library_cvar""" + + def __init__(self, context: Context): + self._context = context + + def run(self, func: Any, /, *args: Any, **kwargs: Any) -> Any: + return self._context.run( + self._set_context_and_run, + sniffio.current_async_library_cvar.get(None), + func, + *args, + **kwargs, + ) + + def _set_context_and_run( + self, current_async_library: str | None, func: Any, /, *args: Any, **kwargs: Any + ) -> Any: + reset_sniffio = None + if current_async_library is not None: + reset_sniffio = sniffio.current_async_library_cvar.set( + current_async_library + ) + + try: + return func(*args, **kwargs) + finally: + if reset_sniffio is not None: + sniffio.current_async_library_cvar.reset(reset_sniffio) def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]: @@ -59,17 +95,26 @@ def pytest_configure(config: Any) -> None: ) +def pytest_sessionstart(session: pytest.Session) -> None: + context = copy_context() + session.stash[contextvars_context_key] = context + session.stash[_test_context_like_key] = _TestContext(context) + + def pytest_fixture_setup(fixturedef: Any, request: Any) -> None: + context_like: ContextLike = request.session.stash[_test_context_like_key] + def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def] backend_name, backend_options = extract_backend_and_options(anyio_backend) if has_backend_arg: kwargs["anyio_backend"] = anyio_backend with get_runner(backend_name, backend_options) as runner: + context_wrapped = context_wrap_async(context_like, func) if isasyncgenfunction(func): - yield from runner.run_asyncgen_fixture(func, kwargs) + yield from runner.run_asyncgen_fixture(context_wrapped, kwargs) else: - yield runner.run_fixture(func, kwargs) + yield runner.run_fixture(context_wrapped, kwargs) # Only apply this to coroutine functions and async generator functions in requests # that involve the anyio_backend fixture @@ -77,9 +122,14 @@ def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def] if isasyncgenfunction(func) or iscoroutinefunction(func): if "anyio_backend" in request.fixturenames: has_backend_arg = "anyio_backend" in fixturedef.argnames + setattr(wrapper, "_runs_in_session_context", True) fixturedef.func = wrapper if not has_backend_arg: fixturedef.argnames += ("anyio_backend",) + elif not getattr(func, "_runs_in_session_context", False): + wrapper = context_wrap(context_like, func) + setattr(wrapper, "_runs_in_session_context", True) + fixturedef.func = wrapper @pytest.hookimpl(tryfirst=True) @@ -95,9 +145,12 @@ def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None: @pytest.hookimpl(tryfirst=True) def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None: + context_like: ContextLike = pyfuncitem.session.stash[_test_context_like_key] + def run_with_hypothesis(**kwargs: Any) -> None: with get_runner(backend_name, backend_options) as runner: - runner.run_test(original_func, kwargs) + context_wrapped = context_wrap_async(context_like, original_func) + runner.run_test(context_wrapped, kwargs) backend = pyfuncitem.funcargs.get("anyio_backend") if backend: @@ -116,10 +169,14 @@ def run_with_hypothesis(**kwargs: Any) -> None: funcargs = pyfuncitem.funcargs testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames} with get_runner(backend_name, backend_options) as runner: - runner.run_test(pyfuncitem.obj, testargs) + context_wrapped = context_wrap_async(context_like, pyfuncitem.obj) + runner.run_test(context_wrapped, testargs) return True + if not iscoroutinefunction(pyfuncitem.obj): + pyfuncitem.obj = context_wrap(context_like, pyfuncitem.obj) + return None diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index e98dbfec..006c1e56 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -380,3 +380,62 @@ async def test_anyio_mark_last_fail(x): result.assert_outcomes( passed=2 * len(get_all_backends()), xfailed=2 * len(get_all_backends()) ) + + +def test_all_tests_and_fixtures_run_in_same_context(testdir: Pytester) -> None: + testdir.makepyfile( + """ + import pytest + import contextvars + + async_fixture_var = contextvars.ContextVar("async_fixture_var") + sync_fixture_var = contextvars.ContextVar("sync_fixture_var") + + + @pytest.fixture + async def async_var_setter(): + reset_token = async_fixture_var.set("set") + yield + async_fixture_var.reset(reset_token) + + + @pytest.fixture + def sync_var_setter(): + reset_token = sync_fixture_var.set("set") + yield + sync_fixture_var.reset(reset_token) + + + @pytest.mark.anyio + async def test_async_func_async_then_sync_fixture( + async_var_setter, sync_var_setter + ): + assert "set" == sync_fixture_var.get() + assert "set" == async_fixture_var.get() + + + @pytest.mark.anyio + async def test_async_func_sync_then_async_fixture( + sync_var_setter, async_var_setter + ): + assert "set" == sync_fixture_var.get() + assert "set" == async_fixture_var.get() + + + def test_sync_func_async_then_sync_fixture( + anyio_backend_name, async_var_setter, sync_var_setter + ): + assert "set" == sync_fixture_var.get() + assert "set" == async_fixture_var.get() + + + def test_sync_func_sync_then_async_fixture( + anyio_backend_name, sync_var_setter, async_var_setter + ): + assert "set" == sync_fixture_var.get() + assert "set" == async_fixture_var.get() + """ + ) + + result = testdir.runpytest(*pytest_args) + result.assert_outcomes(passed=4 * len(get_all_backends()))