Skip to content

Commit

Permalink
Merge pull request #79 from ikalnytskyi/bug/pass-with-coroutines
Browse files Browse the repository at this point in the history
Keep an async function marker when `@pass_` is used
  • Loading branch information
ikalnytskyi committed Nov 27, 2023
2 parents fbc02ea + 13a7711 commit c1c6725
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 2 deletions.
9 changes: 9 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,15 @@ Release Notes
backward incompatible changes will be released along with bumping major
version component.

4.1.0
`````

(unreleased)

* Fix a bug when a coroutine function wrapped with ``@picobox.pass_()``
lost its coroutine function marker, i.e. ``inspect.iscoroutinefunction()``
returned ``False``.

4.0.0
`````

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Bugs = "https://github.com/ikalnytskyi/picobox/issues"
source = "vcs"

[tool.hatch.envs.test]
dependencies = ["pytest", "flask"]
dependencies = ["pytest", "pytest-asyncio", "flask"]
scripts.run = "python -m pytest --strict-markers {args:-vv}"

[tool.hatch.envs.lint]
Expand Down
10 changes: 9 additions & 1 deletion src/picobox/_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def decorator(fn):
return fn

@functools.wraps(fn)
def wrapper(*args, **kwargs):
def fn_with_dependencies(*args, **kwargs):
signature = inspect.signature(fn)
arguments = signature.bind_partial(*args, **kwargs)

Expand All @@ -203,6 +203,14 @@ def wrapper(*args, **kwargs):
kwargs[as_] = self.get(key)
return fn(*args, **kwargs)

if inspect.iscoroutinefunction(fn):

@functools.wraps(fn)
async def wrapper(*args, **kwargs):
return await fn_with_dependencies(*args, **kwargs)
else:
wrapper = fn_with_dependencies

wrapper.__dependencies__ = [(key, as_)]
return wrapper

Expand Down
43 changes: 43 additions & 0 deletions tests/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,27 @@ def __init__(self, x):
assert Foo(*args, **kwargs).x == rv


@pytest.mark.asyncio()
@pytest.mark.parametrize(
("args", "kwargs", "rv"),
[
((1,), {}, 1),
((), {"x": 1}, 1),
((), {}, 42),
],
)
async def test_box_pass_coroutine(args, kwargs, rv, boxclass):
testbox = boxclass()
testbox.put("x", 42)

@testbox.pass_("x")
async def co(x):
return x

assert inspect.iscoroutinefunction(co)
assert await co(*args, **kwargs) == rv


@pytest.mark.parametrize(
("args", "kwargs", "rv"),
[
Expand Down Expand Up @@ -490,6 +511,28 @@ def fn(a, b, c, d):
assert len(fn()) == 3


@pytest.mark.asyncio()
async def test_box_pass_optimization_async(boxclass, request):
testbox = boxclass()
testbox.put("a", 1)
testbox.put("b", 1)
testbox.put("d", 1)

@testbox.pass_("a")
@testbox.pass_("b")
@testbox.pass_("d", as_="c")
async def fn(a, b, c):
backtrace = list(
itertools.dropwhile(
lambda frame: frame[2] != request.function.__name__,
traceback.extract_stack(),
)
)
return backtrace[1:-1]

assert len(await fn()) == 1


def test_chainbox_put_changes_box():
testbox = picobox.Box()
testchainbox = picobox.ChainBox(testbox)
Expand Down
46 changes: 46 additions & 0 deletions tests/test_stack.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test picobox's stack interface."""

import inspect
import itertools
import sys
import traceback
Expand Down Expand Up @@ -449,6 +450,28 @@ def __init__(self, x):
assert Foo(*args, **kwargs).x == rv


@pytest.mark.asyncio()
@pytest.mark.parametrize(
("args", "kwargs", "rv"),
[
((1,), {}, 1),
((), {"x": 1}, 1),
((), {}, 42),
],
)
async def test_box_pass_coroutine(boxclass, teststack, args, kwargs, rv):
testbox = boxclass()
testbox.put("x", 42)

@teststack.pass_("x")
async def co(x):
return x

with teststack.push(testbox):
assert inspect.iscoroutinefunction(co)
assert await co(*args, **kwargs) == rv


@pytest.mark.parametrize(
("args", "kwargs", "rv"),
[
Expand Down Expand Up @@ -567,6 +590,29 @@ def fn(a, b, c, d):
assert len(fn()) == 3


@pytest.mark.asyncio()
async def test_box_pass_optimization_async(boxclass, teststack, request):
testbox = boxclass()
testbox.put("a", 1)
testbox.put("b", 1)
testbox.put("d", 1)

@teststack.pass_("a")
@teststack.pass_("b")
@teststack.pass_("d", as_="c")
async def fn(a, b, c):
backtrace = list(
itertools.dropwhile(
lambda frame: frame[2] != request.function.__name__,
traceback.extract_stack(),
)
)
return backtrace[1:-1]

with teststack.push(testbox):
assert len(await fn()) == 1


def test_chainbox_put_changes_box(teststack):
testbox = picobox.Box()
testchainbox = picobox.ChainBox(testbox)
Expand Down

0 comments on commit c1c6725

Please sign in to comment.