Skip to content

Commit

Permalink
Allow usage of functools.partial async handlers (#984)
Browse files Browse the repository at this point in the history
* Allow usage of async partial methods

* Added test for partial async endpoint

* Double quotes vs single quotes

* Support multiple levels of partials, check Python < 3.8

* Skip coverage for py3.8 branch

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
  • Loading branch information
vladmunteanu and florimondmanca committed Nov 8, 2020
1 parent 7a783d3 commit fe961dd
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
20 changes: 18 additions & 2 deletions starlette/routing.py
@@ -1,6 +1,8 @@
import asyncio
import functools
import inspect
import re
import sys
import traceback
import typing
from enum import Enum
Expand Down Expand Up @@ -28,12 +30,23 @@ class Match(Enum):
FULL = 2


def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:
"""
Correctly determines if an object is a coroutine function,
with a fix for partials on Python < 3.8.
"""
if sys.version_info < (3, 8): # pragma: no cover
while isinstance(obj, functools.partial):
obj = obj.func
return inspect.iscoroutinefunction(obj)


def request_response(func: typing.Callable) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
is_coroutine = asyncio.iscoroutinefunction(func)
is_coroutine = iscoroutinefunction_or_partial(func)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive=receive, send=send)
Expand Down Expand Up @@ -169,7 +182,10 @@ def __init__(
self.name = get_name(endpoint) if name is None else name
self.include_in_schema = include_in_schema

if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
endpoint_handler = endpoint
while isinstance(endpoint_handler, functools.partial):
endpoint_handler = endpoint_handler.func
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
# Endpoint is function or method. Treat it as `func(request) -> response`.
self.app = request_response(endpoint)
if methods is None:
Expand Down
16 changes: 16 additions & 0 deletions tests/test_routing.py
@@ -1,3 +1,4 @@
import functools
import uuid

import pytest
Expand Down Expand Up @@ -587,3 +588,18 @@ def run_shutdown():
with pytest.raises(RuntimeError):
with TestClient(app):
pass # pragma: nocover


async def _partial_async_endpoint(arg, request):
return JSONResponse({"arg": arg})


partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo")

partial_async_app = Router(routes=[Route("/", partial_async_endpoint)])


def test_partial_async_endpoint():
response = TestClient(partial_async_app).get("/")
assert response.status_code == 200
assert response.json() == {"arg": "foo"}

0 comments on commit fe961dd

Please sign in to comment.