Skip to content

Commit

Permalink
♻️ Refactor logic to handle root_path to keep compatibility with AS…
Browse files Browse the repository at this point in the history
…GI and compatibility with other non-Starlette-specific libraries like a2wsgi (#2400)

* ✨ Add util to get route path from scope

* ♻️ Refactor extracting the local route_path from the scope, and creating the child scope with its own root_path; this allows sub-apps (e.g. WSGIMiddleware) to know
where they were mounted, and at which path prefix the path that the sub-app should handle starts

* ♻️ Refactor datastructures and request to be conformant with the ASGI spec, respecting root_path

* ✅ Add and update tests for root_path with mounted apps that don't know about Starlette internals (e.g. the route_root_path extension scope key that was added in

* ✅ Update test for root_path, TestClient was not requiring paths under a root_path to pass the root_path, which is what clients would have to do if the app is
mounted.

* 🎨 Fix formatting

* 🎨 Remove type ignore

* 🔥 Remove unnecessary comment

* ✨ Update (deprecated) WSGIMiddleware to be compatible with the updated root_path, taking pieces from a2wsgi

* 🎨 Fix types

* ✅ Update test for WSGIMiddleware with root_path

* 🔥 Remove logic/features not in the original (deprecated) WSGIMiddleware

---------

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
Co-authored-by: Aber <me@abersheeran.com>
  • Loading branch information
3 people committed Jan 9, 2024
1 parent 8f2307d commit c3c6314
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 37 deletions.
9 changes: 9 additions & 0 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import functools
import re
import sys
import typing
from contextlib import contextmanager

from starlette.types import Scope

if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
else: # pragma: no cover
Expand Down Expand Up @@ -86,3 +89,9 @@ def collapse_excgroups() -> typing.Generator[None, None, None]:
exc = exc.exceptions[0] # pragma: no cover

raise exc


def get_route_path(scope: Scope) -> str:
    """
    Return the portion of ``scope["path"]`` that follows ``scope["root_path"]``.

    The root path is treated as a literal string prefix (never as a regular
    expression) and is only removed when it ends on a path-segment boundary,
    i.e. when the path equals the root path or the root path is followed by
    ``"/"``.

    Args:
        scope: Mapping that must contain ``"path"`` and may contain
            ``"root_path"`` (a missing ``"root_path"`` is treated as ``""``).

    Returns:
        The path with the root path prefix removed, or the unchanged path when
        the root path is empty or is not a segment-aligned prefix of it.
        Returns ``""`` when the path is exactly the root path.

    Raises:
        KeyError: If ``scope`` has no ``"path"`` key.
    """
    path: str = scope["path"]
    root_path: str = scope.get("root_path", "")

    if not root_path or not path.startswith(root_path):
        return path

    remainder = path[len(root_path):]
    # Only strip on a segment boundary: with root_path "/sub", "/subway" must
    # stay "/subway" rather than becoming "way".
    if remainder and not remainder.startswith("/"):
        return path
    return remainder
2 changes: 1 addition & 1 deletion starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
assert not components, 'Cannot set both "scope" and "**components".'
scheme = scope.get("scheme", "http")
server = scope.get("server", None)
path = scope.get("root_path", "") + scope["path"]
path = scope["path"]
query_string = scope.get("query_string", b"")

host_header = None
Expand Down
10 changes: 8 additions & 2 deletions starlette/middleware/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,16 @@ def build_environ(scope: Scope, body: bytes) -> typing.Dict[str, typing.Any]:
"""
Builds a scope and request body into a WSGI environ object.
"""

script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
path_info = scope["path"].encode("utf8").decode("latin1")
if path_info.startswith(script_name):
path_info = path_info[len(script_name) :]

environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
"PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
"SCRIPT_NAME": script_name,
"PATH_INFO": path_info,
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
"wsgi.version": (1, 0),
Expand Down
13 changes: 11 additions & 2 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,18 @@ def url(self) -> URL:
def base_url(self) -> URL:
if not hasattr(self, "_base_url"):
base_url_scope = dict(self.scope)
base_url_scope["path"] = "/"
# This is used by request.url_for, it might be used inside a Mount which
# would have its own child scope with its own root_path, but the base URL
# for url_for should still be the top level app root path.
app_root_path = base_url_scope.get(
"app_root_path", base_url_scope.get("root_path", "")
)
path = app_root_path
if not path.endswith("/"):
path += "/"
base_url_scope["path"] = path
base_url_scope["query_string"] = b""
base_url_scope["root_path"] = base_url_scope.get("root_path", "")
base_url_scope["root_path"] = app_root_path
self._base_url = URL(scope=base_url_scope)
return self._base_url

Expand Down
41 changes: 22 additions & 19 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from enum import Enum

from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
from starlette._utils import get_route_path, is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
Expand Down Expand Up @@ -255,9 +255,8 @@ def __init__(
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] == "http":
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
match = self.path_regex.match(path)
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
Expand Down Expand Up @@ -345,9 +344,8 @@ def __init__(
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] == "websocket":
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
match = self.path_regex.match(path)
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
Expand Down Expand Up @@ -420,9 +418,8 @@ def routes(self) -> typing.List[BaseRoute]:
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] in ("http", "websocket"):
path = scope["path"]
root_path = scope.get("route_root_path", scope.get("root_path", ""))
route_path = scope.get("route_path", re.sub(r"^" + root_path, "", path))
root_path = scope.get("root_path", "")
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
Expand All @@ -432,11 +429,20 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
matched_path = route_path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
root_path = scope.get("root_path", "")
child_scope = {
"path_params": path_params,
"route_root_path": root_path + matched_path,
"route_path": remaining_path,
# app_root_path will only be set at the top level scope,
# initialized with the (optional) value of a root_path
# set above/before Starlette. And even though any
# mount will have its own child scope with its own respective
# root_path, the app_root_path will always be available in all
# the child scopes with the same top level value because it's
# set only once here with a default, any other child scope will
# just inherit that app_root_path default value stored in the
# scope. All this is needed to support Request.url_for(), as it
# uses the app_root_path to build the URL path.
"app_root_path": scope.get("app_root_path", root_path),
"root_path": root_path + matched_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
Expand Down Expand Up @@ -787,15 +793,12 @@ async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
await partial.handle(scope, receive, send)
return

root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
if scope["type"] == "http" and self.redirect_slashes and path != "/":
route_path = get_route_path(scope)
if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
redirect_scope = dict(scope)
if path.endswith("/"):
redirect_scope["route_path"] = path.rstrip("/")
if route_path.endswith("/"):
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
else:
redirect_scope["route_path"] = path + "/"
redirect_scope["path"] = redirect_scope["path"] + "/"

for route in self.routes:
Expand Down
7 changes: 3 additions & 4 deletions starlette/staticfiles.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import importlib.util
import os
import re
import stat
import typing
from email.utils import parsedate

import anyio
import anyio.to_thread

from starlette._utils import get_route_path
from starlette.datastructures import URL, Headers
from starlette.exceptions import HTTPException
from starlette.responses import FileResponse, RedirectResponse, Response
Expand Down Expand Up @@ -110,9 +110,8 @@ def get_path(self, scope: Scope) -> str:
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
return os.path.normpath(os.path.join(*path.split("/"))) # type: ignore[no-any-return] # noqa: E501
route_path = get_route_path(scope)
return os.path.normpath(os.path.join(*route_path.split("/"))) # noqa: E501

async def get_response(self, path: str, scope: Scope) -> Response:
"""
Expand Down
5 changes: 3 additions & 2 deletions tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def test_build_environ():
"http_version": "1.1",
"method": "GET",
"scheme": "https",
"path": "/",
"path": "/sub/",
"root_path": "/sub",
"query_string": b"a=123&b=456",
"headers": [
(b"host", b"www.example.org"),
Expand All @@ -117,7 +118,7 @@ def test_build_environ():
"QUERY_STRING": "a=123&b=456",
"REMOTE_ADDR": "134.56.78.4",
"REQUEST_METHOD": "GET",
"SCRIPT_NAME": "",
"SCRIPT_NAME": "/sub",
"SERVER_NAME": "www.example.org",
"SERVER_PORT": 443,
"SERVER_PROTOCOL": "HTTP/1.1",
Expand Down
40 changes: 33 additions & 7 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import functools
import json
import typing
import uuid

Expand Down Expand Up @@ -563,12 +564,12 @@ def test_url_for_with_root_path(test_client_factory):
client = test_client_factory(
app, base_url="https://www.example.org/", root_path="/sub_path"
)
response = client.get("/")
response = client.get("/sub_path/")
assert response.json() == {
"index": "https://www.example.org/sub_path/",
"submount": "https://www.example.org/sub_path/submount/",
}
response = client.get("/submount/")
response = client.get("/sub_path/submount/")
assert response.json() == {
"index": "https://www.example.org/sub_path/",
"submount": "https://www.example.org/sub_path/submount/",
Expand Down Expand Up @@ -1242,23 +1243,37 @@ async def echo_paths(request: Request, name: str):
)


async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str):
    """
    Minimal raw ASGI app that replies with the paths it was called with.

    Sends a 200 JSON response whose body holds ``name`` together with the
    ``"path"`` and ``"root_path"`` values read from ``scope``. ``receive`` is
    never used.
    """
    payload = json.dumps(
        {"name": name, "path": scope["path"], "root_path": scope["root_path"]}
    ).encode("utf-8")

    start_message = {
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"content-type", b"application/json")],
    }
    body_message = {"type": "http.response.body", "body": payload}

    # Response start must go out before the body.
    for message in (start_message, body_message):
        await send(message)


echo_paths_routes = [
Route(
"/path",
functools.partial(echo_paths, name="path"),
name="path",
methods=["GET"],
),
Mount("/asgipath", app=functools.partial(pure_asgi_echo_paths, name="asgipath")),
Mount(
"/root",
"/sub",
name="mount",
routes=[
Route(
"/path",
functools.partial(echo_paths, name="subpath"),
name="subpath",
methods=["GET"],
)
),
],
),
]
Expand All @@ -1276,11 +1291,22 @@ def test_paths_with_root_path(test_client_factory: typing.Callable[..., TestClie
"path": "/root/path",
"root_path": "/root",
}
response = client.get("/root/asgipath/")
assert response.status_code == 200
assert response.json() == {
"name": "asgipath",
"path": "/root/asgipath/",
# Things that mount other ASGI apps, like WSGIMiddleware, would not be aware
# of the prefixed path, and would have their own notion of their own paths,
# so they need to be able to rely on the root_path to know the location they
# are mounted on
"root_path": "/root/asgipath",
}

response = client.get("/root/root/path")
response = client.get("/root/sub/path")
assert response.status_code == 200
assert response.json() == {
"name": "subpath",
"path": "/root/root/path",
"root_path": "/root",
"path": "/root/sub/path",
"root_path": "/root/sub",
}

0 comments on commit c3c6314

Please sign in to comment.