Skip to content

Commit

Permalink
allow 'name' in key word arguments of url_for() and url_path_for() (#608
Browse files Browse the repository at this point in the history
)
  • Loading branch information
dansan committed Jan 1, 2022
1 parent 29c1e01 commit 6caeca9
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 14 deletions.
4 changes: 2 additions & 2 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def debug(self, value: bool) -> None:
self._debug = value
self.middleware_stack = self.build_middleware_stack()

def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
return self.router.url_path_for(name, **path_params)
def url_path_for(self, *args: str, **path_params: typing.Any) -> URLPath:
return self.router.url_path_for(*args, **path_params)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
Expand Down
5 changes: 3 additions & 2 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,10 @@ def state(self) -> State:
self._state = State(self.scope["state"])
return self._state

def url_for(self, name: str, **path_params: typing.Any) -> str:
def url_for(self, *args: str, **path_params: typing.Any) -> str:
assert len(args) == 1, "url_for() takes exactly one positional argument"
router: Router = self.scope["router"]
url_path = router.url_path_for(name, **path_params)
url_path = router.url_path_for(*args, **path_params)
return url_path.make_absolute_url(base_url=self.base_url)


Expand Down
24 changes: 16 additions & 8 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

class NoMatchFound(Exception):
"""
Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
Raised by `.url_for(*args, **path_params)` and `.url_path_for(*args, **path_params)`
if no matching route exists.
"""

Expand Down Expand Up @@ -156,7 +156,7 @@ class BaseRoute:
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
raise NotImplementedError() # pragma: no cover

def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
def url_path_for(self, *args: str, **path_params: typing.Any) -> URLPath:
raise NotImplementedError() # pragma: no cover

async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand Down Expand Up @@ -235,7 +235,9 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
def url_path_for(self, *args: str, **path_params: typing.Any) -> URLPath:
assert len(args) == 1, "url_path_for() takes exactly one positional argument"
name = args[0]
seen_params = set(path_params.keys())
expected_params = set(self.param_convertors.keys())

Expand Down Expand Up @@ -301,7 +303,9 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
def url_path_for(self, *args: str, **path_params: typing.Any) -> URLPath:
assert len(args) == 1, "url_path_for() takes exactly one positional argument"
name = args[0]
seen_params = set(path_params.keys())
expected_params = set(self.param_convertors.keys())

Expand Down Expand Up @@ -374,7 +378,9 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
def url_path_for(self, *args: str, **path_params: typing.Any) -> URLPath:
assert len(args) == 1, "url_path_for() takes exactly one positional argument"
name = args[0]
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path_params["path"] = path_params["path"].lstrip("/")
Expand Down Expand Up @@ -444,7 +450,9 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
def url_path_for(self, *args: str, **path_params: typing.Any) -> URLPath:
assert len(args) == 1, "url_path_for() takes exactly one positional argument"
name = args[0]
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path = path_params.pop("path")
Expand Down Expand Up @@ -584,10 +592,10 @@ async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)

def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
def url_path_for(self, *args: str, **path_params: typing.Any) -> URLPath:
for route in self.routes:
try:
return route.url_path_for(name, **path_params)
return route.url_path_for(*args, **path_params)
except NoMatchFound:
pass
raise NoMatchFound()
Expand Down
7 changes: 6 additions & 1 deletion starlette/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ def _create_env(
self, directory: typing.Union[str, PathLike]
) -> "jinja2.Environment":
@pass_context
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
def url_for(context: dict, *args: str, **path_params: typing.Any) -> str:
if len(args) < 1:
raise TypeError("Missing route name as the second argument.")
elif len(args) > 1:
raise TypeError("Invalid positional argument passed.")
name = args[0]
request = context["request"]
return request.url_for(name, **path_params)

Expand Down
19 changes: 19 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import anyio
import pytest

from starlette.applications import Starlette
from starlette.requests import ClientDisconnect, Request, State
from starlette.responses import JSONResponse, Response

Expand Down Expand Up @@ -397,6 +398,24 @@ async def app(scope, receive, send):
assert result["cookies"] == expected


def test_request_url_for_allows_name_arg(test_client_factory):
app = Starlette()

@app.route("/users/{name}")
async def func_users(request):
raise NotImplementedError() # pragma: no cover

@app.route("/test")
async def func_url_for_test(request: Request):
url = request.url_for("func_users", name="abcde")
return Response(str(url), media_type="text/plain")

client = test_client_factory(app)
response = client.get("/test")
assert response.status_code == 200
assert response.text == "http://testserver/users/abcde"


def test_chunked_encoding(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
Expand Down
21 changes: 20 additions & 1 deletion tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def user(request):
return Response(content, media_type="text/plain")


def user2(request):
content = "User 2 " + request.path_params["name"]
return Response(content, media_type="text/plain")


def user_me(request):
content = "User fixed me"
return Response(content, media_type="text/plain")
Expand Down Expand Up @@ -63,6 +68,7 @@ async def async_ws_endpoint(cls, websocket: WebSocket):
Route("/", endpoint=users),
Route("/me", endpoint=user_me),
Route("/{username}", endpoint=user),
Route("/n/{name}", endpoint=user2),
Route("/nomatch", endpoint=user_no_match),
],
),
Expand Down Expand Up @@ -174,6 +180,10 @@ def test_router(client):
assert response.status_code == 200
assert response.text == "User tomchristie"

response = client.get("/users/n/tomchristie")
assert response.status_code == 200
assert response.text == "User 2 tomchristie"

response = client.get("/users/me")
assert response.status_code == 200
assert response.text == "User fixed me"
Expand Down Expand Up @@ -236,14 +246,23 @@ def test_route_converters(client):

def test_url_path_for():
assert app.url_path_for("homepage") == "/"
assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
assert app.url_path_for("user", username="tomchristie1") == "/users/tomchristie1"
assert app.url_path_for("user2", name="tomchristie2") == "/users/n/tomchristie2"
with pytest.raises(NoMatchFound):
assert app.url_path_for("user", name="tomchristie1")
with pytest.raises(NoMatchFound):
assert app.url_path_for("user2", username="tomchristie2")
assert app.url_path_for("websocket_endpoint") == "/ws"
with pytest.raises(NoMatchFound):
assert app.url_path_for("broken")
with pytest.raises(AssertionError):
app.url_path_for("user", username="tom/christie")
with pytest.raises(AssertionError):
app.url_path_for("user", username="")
with pytest.raises(AssertionError, match="takes exactly one positional argument"):
assert app.url_path_for("user", "args2", name="tomchristie1")
with pytest.raises(AssertionError, match="takes exactly one positional argument"):
assert app.url_path_for(name="tomchristie1")


def test_url_for():
Expand Down
10 changes: 10 additions & 0 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,13 @@ def test_template_response_requires_request(tmpdir):
templates = Jinja2Templates(str(tmpdir))
with pytest.raises(ValueError):
templates.TemplateResponse(None, {})


def test_template_env_url_for_args(tmpdir):
templates = Jinja2Templates(directory=str(tmpdir))

url_for_func = templates.env.globals["url_for"]
with pytest.raises(TypeError, match="Invalid positional argument passed."):
assert url_for_func({}, "user", "args2", name="tomchristie")
with pytest.raises(TypeError, match="Missing route name as the second argument."):
assert url_for_func({}, name="tomchristie")

0 comments on commit 6caeca9

Please sign in to comment.