Skip to content

Commit

Permalink
Fix functools.partial async handlers for classmethods (#1106)
Browse files Browse the repository at this point in the history
* Showcase the bug

* Fixed functools.partial usage with classmethods

* Updated comment

* Updated docstring according to suggestion

Co-authored-by: Jamie Hewland <jhewland@gmail.com>

Co-authored-by: Jamie Hewland <jhewland@gmail.com>
  • Loading branch information
vladmunteanu and JayH5 committed Feb 2, 2021
1 parent 62e95b8 commit fe908b1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
8 changes: 3 additions & 5 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import functools
import inspect
import re
import sys
import traceback
import typing
from enum import Enum
Expand Down Expand Up @@ -33,11 +32,10 @@ class Match(Enum):
def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:
"""
Correctly determines if an object is a coroutine function,
with a fix for partials on Python < 3.8.
including those wrapped in functools.partial objects.
"""
if sys.version_info < (3, 8): # pragma: no cover
while isinstance(obj, functools.partial):
obj = obj.func
while isinstance(obj, functools.partial):
obj = obj.func
return inspect.iscoroutinefunction(obj)


Expand Down
23 changes: 21 additions & 2 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,16 +590,35 @@ def run_shutdown():
pass # pragma: nocover


class AsyncEndpointClassMethod:
@classmethod
async def async_endpoint(cls, arg, request):
return JSONResponse({"arg": arg})


async def _partial_async_endpoint(arg, request):
return JSONResponse({"arg": arg})


partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo")
partial_cls_async_endpoint = functools.partial(
AsyncEndpointClassMethod.async_endpoint, "foo"
)

partial_async_app = Router(routes=[Route("/", partial_async_endpoint)])
partial_async_app = Router(
routes=[
Route("/", partial_async_endpoint),
Route("/cls", partial_cls_async_endpoint),
]
)


def test_partial_async_endpoint():
response = TestClient(partial_async_app).get("/")
test_client = TestClient(partial_async_app)
response = test_client.get("/")
assert response.status_code == 200
assert response.json() == {"arg": "foo"}

cls_method_response = test_client.get("/cls")
assert cls_method_response.status_code == 200
assert cls_method_response.json() == {"arg": "foo"}

0 comments on commit fe908b1

Please sign in to comment.