Skip to content

Commit

Permalink
use asserts and type check instead of decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
dansan committed Nov 28, 2019
1 parent eb87307 commit f4f4cd1
Showing 1 changed file with 6 additions and 20 deletions.
26 changes: 6 additions & 20 deletions starlette/routing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import functools
import inspect
import re
import traceback
Expand Down Expand Up @@ -29,20 +28,6 @@ class Match(Enum):
FULL = 2


def verify_args_is_only_route(func: typing.Callable) -> typing.Callable:
"""Raise TypeError if number of positional arguments is not exactly 1."""

@functools.wraps(func)
def wrapper(self: BaseRoute, *args: str, **kwargs: str) -> URLPath:
if len(args) < 1:
raise TypeError("Missing route name as the first argument.")
if len(args) > 1:
raise TypeError("Invalid positional argument passed.")
return func(self, *args, **kwargs)

return wrapper


def request_response(func: typing.Callable) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
Expand Down Expand Up @@ -218,8 +203,8 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

@verify_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
assert len(args) == 1, "Sole positional argument must be the route name."
seen_params = set(kwargs.keys())
expected_params = set(self.param_convertors.keys())

Expand Down Expand Up @@ -282,8 +267,8 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

@verify_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
assert len(args) == 1, "Sole positional argument must be the route name."
seen_params = set(kwargs.keys())
expected_params = set(self.param_convertors.keys())

Expand Down Expand Up @@ -356,8 +341,8 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

@verify_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
assert len(args) == 1, "Sole positional argument must be the route name."
name = args[0]
if self.name is not None and name == self.name and "path" in kwargs:
# 'name' matches "<mount_name>".
Expand Down Expand Up @@ -428,8 +413,8 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

@verify_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
assert len(args) == 1, "Sole positional argument must be the route name."
name = args[0]
if self.name is not None and name == self.name and "path" in kwargs:
# 'name' matches "<mount_name>".
Expand Down Expand Up @@ -498,8 +483,9 @@ async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)

@verify_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
if len(args) != 1 or len(args) == 1 and not isinstance(args[0], str):
raise TypeError("Sole positional argument must be the route name.")
for route in self.routes:
try:
return route.url_path_for(args[0], **kwargs)
Expand Down

0 comments on commit f4f4cd1

Please sign in to comment.