Skip to content

Commit

Permalink
Set root correctly for Gradio apps that are deployed behind reverse…
Browse files Browse the repository at this point in the history
… proxies (#7411)

* testing

* add changeset

* test

* backend

* fix

* add unit tests

* testing

* remove check

* add changeset

* trying something

* add changeset

* override

* add changeset

* fix

* fix

* clean

* lint

* route utils

* add changeset

* changes

* add changeset

* test

* revert testing

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot committed Feb 14, 2024
1 parent 065c5b1 commit 32b317f
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 11 deletions.
6 changes: 6 additions & 0 deletions .changeset/tricky-coins-sniff.md
@@ -0,0 +1,6 @@
---
"@gradio/client": patch
"gradio": patch
---

fix:Set `root` correctly for Gradio apps that are deployed behind reverse proxies
3 changes: 3 additions & 0 deletions client/js/src/client.ts
Expand Up @@ -301,6 +301,9 @@ export function api_factory(

async function config_success(_config: Config): Promise<client_return> {
config = _config;
if (window.location.protocol === "https:") {
config.root = config.root.replace("http://", "https://");
}
api_map = map_names_to_ids(_config?.dependencies || []);
if (config.auth_required) {
return {
Expand Down
16 changes: 11 additions & 5 deletions gradio/route_utils.py
Expand Up @@ -261,18 +261,24 @@ async def call_process_api(
return output


def get_root_url(request: fastapi.Request) -> str:
def get_root_url(
request: fastapi.Request, route_path: str, root_path: str | None
) -> str:
"""
Gets the root url of the request, stripping off any query parameters and trailing slashes.
Also ensures that the root url is https if the request is https.
Gets the root url of the request, stripping off any query parameters, the route_path, and trailing slashes.
Also ensures that the root url is https if the request is https. If root_path is provided, it is appended to the root url.
The final root url will not have a trailing slash.
"""
root_url = str(request.url)
root_url = httpx.URL(root_url)
root_url = root_url.copy_with(query=None)
root_url = str(root_url)
root_url = str(root_url).rstrip("/")
if request.headers.get("x-forwarded-proto") == "https":
root_url = root_url.replace("http://", "https://")
return root_url.rstrip("/")
route_path = route_path.rstrip("/")
if len(route_path) > 0:
root_url = root_url[: -len(route_path)]
return (root_url.rstrip("/") + (root_path or "")).rstrip("/")


def _user_safe_decode(src: bytes, codec: str) -> str:
Expand Down
16 changes: 12 additions & 4 deletions gradio/routes.py
Expand Up @@ -311,7 +311,9 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
mimetypes.add_type("application/javascript", ".js")
blocks = app.get_blocks()
root_path = route_utils.get_root_url(request)
root_path = route_utils.get_root_url(
request=request, route_path="/", root_path=app.root_path
)
if app.auth is None or user is not None:
config = copy.deepcopy(app.get_blocks().config)
config["root"] = root_path
Expand Down Expand Up @@ -353,7 +355,9 @@ def api_info():
@app.get("/config", dependencies=[Depends(login_check)])
def get_config(request: fastapi.Request):
config = copy.deepcopy(app.get_blocks().config)
root_path = route_utils.get_root_url(request)[: -len("/config")]
root_path = route_utils.get_root_url(
request=request, route_path="/config", root_path=app.root_path
)
config["root"] = root_path
config = add_root_url(config, root_path)
return config
Expand Down Expand Up @@ -570,7 +574,9 @@ async def predict(
content={"error": str(error) if show_error else None},
status_code=500,
)
root_path = route_utils.get_root_url(request)[: -len(f"/api/{api_name}")]
root_path = route_utils.get_root_url(
request=request, route_path=f"/api/{api_name}", root_path=app.root_path
)
output = add_root_url(output, root_path)
return output

Expand All @@ -580,7 +586,9 @@ async def queue_data(
session_hash: str,
):
blocks = app.get_blocks()
root_path = route_utils.get_root_url(request)[: -len("/queue/data")]
root_path = route_utils.get_root_url(
request=request, route_path="/queue/data", root_path=app.root_path
)

async def sse_stream(request: fastapi.Request):
try:
Expand Down
51 changes: 49 additions & 2 deletions test/test_routes.py
Expand Up @@ -10,7 +10,7 @@
import pandas as pd
import pytest
import starlette.routing
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from gradio_client import media_data

Expand All @@ -25,7 +25,7 @@
routes,
wasm_utils,
)
from gradio.route_utils import FnIndexInferError
from gradio.route_utils import FnIndexInferError, get_root_url


@pytest.fixture()
Expand Down Expand Up @@ -862,3 +862,50 @@ def test_component_server_endpoints(connect):
},
)
assert fail_req.status_code == 404


@pytest.mark.parametrize(
"request_url, route_path, root_path, expected_root_url",
[
("http://localhost:7860/", "/", None, "http://localhost:7860"),
(
"http://localhost:7860/demo/test",
"/demo/test",
None,
"http://localhost:7860",
),
(
"http://localhost:7860/demo/test/",
"/demo/test",
None,
"http://localhost:7860",
),
(
"http://localhost:7860/demo/test?query=1",
"/demo/test",
None,
"http://localhost:7860",
),
(
"http://localhost:7860/demo/test?query=1",
"/demo/test/",
"/gradio/",
"http://localhost:7860/gradio",
),
(
"http://localhost:7860/demo/test?query=1",
"/demo/test",
"/gradio/",
"http://localhost:7860/gradio",
),
(
"https://localhost:7860/demo/test?query=1",
"/demo/test",
"/gradio/",
"https://localhost:7860/gradio",
),
],
)
def test_get_root_url(request_url, route_path, root_path, expected_root_url):
request = Request({"path": request_url, "type": "http", "headers": {}})
assert get_root_url(request, route_path, root_path) == expected_root_url

0 comments on commit 32b317f

Please sign in to comment.