Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop caching root url #7374

Merged
merged 11 commits into from Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/soft-lies-carry.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Stop caching root url
16 changes: 10 additions & 6 deletions gradio/route_utils.py
Expand Up @@ -261,14 +261,18 @@ async def call_process_api(
return output


def strip_url(orig_url: str) -> str:
def get_root_url(request: fastapi.Request) -> str:
"""
Strips the query parameters and trailing slash from a URL.
Gets the root url of the request, stripping off any query parameters and trailing slashes.
Also ensures that the root url is https if the request is https.
"""
parsed_url = httpx.URL(orig_url)
stripped_url = parsed_url.copy_with(query=None)
stripped_url = str(stripped_url)
return stripped_url.rstrip("/")
root_url = str(request.url)
root_url = httpx.URL(root_url)
root_url = root_url.copy_with(query=None)
root_url = str(root_url)
if request.headers.get("x-forwarded-proto") == "https":
root_url = root_url.replace("http://", "https://")
return root_url.rstrip("/")


def _user_safe_decode(src: bytes, codec: str) -> str:
Expand Down
29 changes: 14 additions & 15 deletions gradio/routes.py
Expand Up @@ -5,6 +5,7 @@

import asyncio
import contextlib
import copy
import sys

if sys.version_info >= (3, 9):
Expand Down Expand Up @@ -310,18 +311,17 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
mimetypes.add_type("application/javascript", ".js")
blocks = app.get_blocks()
root_path = route_utils.strip_url(str(request.url))
root_path = route_utils.get_root_url(request)
if app.auth is None or user is not None:
config = app.get_blocks().config
if "root" not in config:
config["root"] = root_path
config = add_root_url(config, root_path)
config = copy.deepcopy(app.get_blocks().config)
config["root"] = root_path
config = add_root_url(config, root_path)
else:
config = {
"auth_required": True,
"auth_message": blocks.auth_message,
"space_id": app.get_blocks().space_id,
"root": route_utils.strip_url(root_path),
"root": root_path,
}

try:
Expand Down Expand Up @@ -352,11 +352,10 @@ def api_info():
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config(request: fastapi.Request):
root_path = route_utils.strip_url(str(request.url))[:-7]
config = app.get_blocks().config
if "root" not in config:
config["root"] = route_utils.strip_url(root_path)
config = add_root_url(config, root_path)
config = copy.deepcopy(app.get_blocks().config)
root_path = route_utils.get_root_url(request)[: -len("/config")]
config["root"] = root_path
config = add_root_url(config, root_path)
return config

@app.get("/static/{path:path}")
Expand Down Expand Up @@ -571,8 +570,8 @@ async def predict(
content={"error": str(error) if show_error else None},
status_code=500,
)
root_path = app.get_blocks().config.get("root", "")
output = add_root_url(output, route_utils.strip_url(root_path))
root_path = route_utils.get_root_url(request)[: -len(f"/api/{api_name}")]
Copy link
Member Author

@abidlabs abidlabs Feb 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Annoying to do this manually, but if you use e.g. request.url.path, it also includes the root url if gradio is mounted within another fastapi app, which we don't want. So I think the most reliable solution is to just remove the suffix manually.

output = add_root_url(output, root_path)
return output

@app.get("/queue/data", dependencies=[Depends(login_check)])
Expand All @@ -581,7 +580,7 @@ async def queue_data(
session_hash: str,
):
blocks = app.get_blocks()
root_path = app.get_blocks().config.get("root", "")
root_path = route_utils.get_root_url(request)[: -len("/queue/data")]

async def sse_stream(request: fastapi.Request):
try:
Expand Down Expand Up @@ -627,7 +626,7 @@ async def sse_stream(request: fastapi.Request):
"success": False,
}
if message:
add_root_url(message, route_utils.strip_url(root_path))
add_root_url(message, root_path)
yield f"data: {json.dumps(message)}\n\n"
if message["msg"] == ServerMessage.process_completed:
blocks._queue.pending_event_ids_session[
Expand Down