Skip to content

Commit

Permalink
Updated pytest plugin to run all async/sync fixture/test code on a si…
Browse files Browse the repository at this point in the history
…ngle Context

Fixes agronholm#614
  • Loading branch information
jhominal committed Nov 6, 2023
1 parent 9b4e8d9 commit fd724e0
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 4 deletions.
65 changes: 61 additions & 4 deletions src/anyio/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,52 @@

from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import Context, copy_context
from inspect import isasyncgenfunction, iscoroutinefunction
from typing import Any, Dict, Tuple, cast

import pytest
import sniffio
from _pytest.stash import StashKey

from ._core._eventloop import get_all_backends, get_async_backend
from ._run_in_context import ContextLike, context_wrap, context_wrap_async
from .abc import TestRunner

_current_runner: TestRunner | None = None
contextvars_context_key = StashKey[Context]()
_test_context_like_key = StashKey[ContextLike]()


class _TestContext(ContextLike):
"""This class manages transmission of sniffio.current_async_library_cvar"""

def __init__(self, context: Context):
self._context = context

def run(self, func: Any, /, *args: Any, **kwargs: Any) -> Any:
return self._context.run(
self._set_context_and_run,
sniffio.current_async_library_cvar.get(None),
func,
*args,
**kwargs,
)

def _set_context_and_run(
self, current_async_library: str | None, func: Any, /, *args: Any, **kwargs: Any
) -> Any:
reset_sniffio = None
if current_async_library is not None:
reset_sniffio = sniffio.current_async_library_cvar.set(
current_async_library
)

try:
return func(*args, **kwargs)
finally:
if reset_sniffio is not None:
sniffio.current_async_library_cvar.reset(reset_sniffio)


def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]:
Expand Down Expand Up @@ -59,27 +95,41 @@ def pytest_configure(config: Any) -> None:
)


def pytest_sessionstart(session: pytest.Session) -> None:
context = copy_context()
session.stash[contextvars_context_key] = context
session.stash[_test_context_like_key] = _TestContext(context)


def pytest_fixture_setup(fixturedef: Any, request: Any) -> None:
context_like: ContextLike = request.session.stash[_test_context_like_key]

def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def]
backend_name, backend_options = extract_backend_and_options(anyio_backend)
if has_backend_arg:
kwargs["anyio_backend"] = anyio_backend

with get_runner(backend_name, backend_options) as runner:
context_wrapped = context_wrap_async(context_like, func)
if isasyncgenfunction(func):
yield from runner.run_asyncgen_fixture(func, kwargs)
yield from runner.run_asyncgen_fixture(context_wrapped, kwargs)
else:
yield runner.run_fixture(func, kwargs)
yield runner.run_fixture(context_wrapped, kwargs)

# Only apply this to coroutine functions and async generator functions in requests
# that involve the anyio_backend fixture
func = fixturedef.func
if isasyncgenfunction(func) or iscoroutinefunction(func):
if "anyio_backend" in request.fixturenames:
has_backend_arg = "anyio_backend" in fixturedef.argnames
setattr(wrapper, "_runs_in_session_context", True)
fixturedef.func = wrapper
if not has_backend_arg:
fixturedef.argnames += ("anyio_backend",)
elif not getattr(func, "_runs_in_session_context", False):
wrapper = context_wrap(context_like, func)
setattr(wrapper, "_runs_in_session_context", True)
fixturedef.func = wrapper


@pytest.hookimpl(tryfirst=True)
Expand All @@ -95,9 +145,12 @@ def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None:

@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
context_like: ContextLike = pyfuncitem.session.stash[_test_context_like_key]

def run_with_hypothesis(**kwargs: Any) -> None:
with get_runner(backend_name, backend_options) as runner:
runner.run_test(original_func, kwargs)
context_wrapped = context_wrap_async(context_like, original_func)
runner.run_test(context_wrapped, kwargs)

backend = pyfuncitem.funcargs.get("anyio_backend")
if backend:
Expand All @@ -116,10 +169,14 @@ def run_with_hypothesis(**kwargs: Any) -> None:
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
with get_runner(backend_name, backend_options) as runner:
runner.run_test(pyfuncitem.obj, testargs)
context_wrapped = context_wrap_async(context_like, pyfuncitem.obj)
runner.run_test(context_wrapped, testargs)

return True

if not iscoroutinefunction(pyfuncitem.obj):
pyfuncitem.obj = context_wrap(context_like, pyfuncitem.obj)

return None


Expand Down
59 changes: 59 additions & 0 deletions tests/test_pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,62 @@ async def test_anyio_mark_last_fail(x):
result.assert_outcomes(
passed=2 * len(get_all_backends()), xfailed=2 * len(get_all_backends())
)


def test_all_tests_and_fixtures_run_in_same_context(testdir: Pytester) -> None:
testdir.makepyfile(
"""
import pytest
import contextvars
async_fixture_var = contextvars.ContextVar("async_fixture_var")
sync_fixture_var = contextvars.ContextVar("sync_fixture_var")
@pytest.fixture
async def async_var_setter():
reset_token = async_fixture_var.set("set")
yield
async_fixture_var.reset(reset_token)
@pytest.fixture
def sync_var_setter():
reset_token = sync_fixture_var.set("set")
yield
sync_fixture_var.reset(reset_token)
@pytest.mark.anyio
async def test_async_func_async_then_sync_fixture(
async_var_setter, sync_var_setter
):
assert "set" == sync_fixture_var.get()
assert "set" == async_fixture_var.get()
@pytest.mark.anyio
async def test_async_func_sync_then_async_fixture(
sync_var_setter, async_var_setter
):
assert "set" == sync_fixture_var.get()
assert "set" == async_fixture_var.get()
def test_sync_func_async_then_sync_fixture(
anyio_backend_name, async_var_setter, sync_var_setter
):
assert "set" == sync_fixture_var.get()
assert "set" == async_fixture_var.get()
def test_sync_func_sync_then_async_fixture(
anyio_backend_name, sync_var_setter, async_var_setter
):
assert "set" == sync_fixture_var.get()
assert "set" == async_fixture_var.get()
"""
)

result = testdir.runpytest(*pytest_args)
result.assert_outcomes(passed=4 * len(get_all_backends()))

0 comments on commit fd724e0

Please sign in to comment.