diff --git a/.changeset/odd-parrots-open.md b/.changeset/odd-parrots-open.md new file mode 100644 index 000000000000..997f8050366f --- /dev/null +++ b/.changeset/odd-parrots-open.md @@ -0,0 +1,6 @@ +--- +"@gradio/client": patch +"gradio": patch +--- + +fix:Revert replica proxy logic and instead implement using the `root` variable diff --git a/client/js/src/client.ts b/client/js/src/client.ts index d32be1ab2325..b8a986c421da 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -8,7 +8,8 @@ import { get_space_hardware, set_space_hardware, set_space_timeout, - hardware_types + hardware_types, + resolve_root } from "./utils.js"; import type { @@ -436,7 +437,7 @@ export function api_factory( } handle_blob( - `${http_protocol}//${host + config.path}`, + `${http_protocol}//${resolve_root(host, config.path, true)}`, data, api_info, hf_token @@ -453,7 +454,7 @@ export function api_factory( }); post_data( - `${http_protocol}//${host + config.path}/run${ + `${http_protocol}//${resolve_root(host, config.path, true)}/run${ _endpoint.startsWith("/") ? _endpoint : `/${_endpoint}` }${url_params ? "?" + url_params : ""}`, { @@ -521,8 +522,11 @@ export function api_factory( fn_index, time: new Date() }); - - let url = new URL(`${ws_protocol}://${host}${config.path} + let url = new URL(`${ws_protocol}://${resolve_root( + host, + config.path, + true + )} /queue/join${url_params ? "?" + url_params : ""}`); if (jwt) { @@ -686,7 +690,11 @@ export function api_factory( try { await fetch_implementation( - `${http_protocol}//${host + config.path}/reset`, + `${http_protocol}//${resolve_root( + host, + config.path, + true + )}/reset`, { headers: { "Content-Type": "application/json" }, method: "POST", @@ -1206,7 +1214,7 @@ async function resolve_config( ) { const path = window.gradio_config.root; const config = window.gradio_config; - config.root = endpoint + config.root; + config.root = resolve_root(endpoint, config.root, false); return { ...config, path: path }; } else if (endpoint) { let response = await fetch_implementation(`${endpoint}/config`, { diff --git a/client/js/src/utils.ts b/client/js/src/utils.ts index 605dc291461f..c7190433e25d 100644 --- a/client/js/src/utils.ts +++ b/client/js/src/utils.ts @@ -1,5 +1,26 @@ import type { Config } from "./types.js"; +/** + * This function is used to resolve the URL for making requests when the app has a root path. + * The root path could be a path suffix like "/app" which is appended to the end of the base URL. Or + * it could be a full URL like "https://abidlabs-test-client-replica--gqf2x.hf.space" which is used when hosting + * Gradio apps on Hugging Face Spaces. + * @param {string} base_url The base URL at which the Gradio server is hosted + * @param {string} root_path The root path, which could be a path suffix (e.g. mounted in FastAPI app) or a full URL (e.g. hosted on Hugging Face Spaces) + * @param {boolean} prioritize_base Whether to prioritize the base URL over the root path. This is used when both the base path and root paths are full URLs. For example, for fetching files the root path should be prioritized, but for making request, the base URL should be prioritized. + * @returns {string} the resolved URL + */ +export function resolve_root( + base_url: string, + root_path: string, + prioritize_base: boolean +): string { + if (root_path.startsWith("http://") || root_path.startsWith("https://")) { + return prioritize_base ? base_url : root_path; + } + return base_url + root_path; +} + export function determine_protocol(endpoint: string): { ws_protocol: "ws" | "wss"; http_protocol: "http:" | "https:"; diff --git a/gradio/route_utils.py b/gradio/route_utils.py index 0d357cafc3fc..ce223eb30645 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -247,28 +247,11 @@ async def call_process_api( return output -def set_replica_url_in_config( - config: dict, replica_url: str, all_replica_urls: set[str] -) -> None: +def strip_url(orig_url: str) -> str: """ - If the Gradio app is running on Hugging Face Spaces and the machine has multiple replicas, - we pass in the direct URL to the replica so that we have the fully resolved path to any files - on that machine. This direct URL can be shared with other users and the path will still work. - - Parameters: - config: The config dictionary to modify. - replica_url: The direct URL to the replica. - all_replica_urls: The direct URLs to the other replicas. These should be replaced with the replica_url. + Strips the query parameters and trailing slash from a URL. """ - parsed_url = httpx.URL(replica_url) + parsed_url = httpx.URL(orig_url) stripped_url = parsed_url.copy_with(query=None) stripped_url = str(stripped_url) - if not stripped_url.endswith("/"): - stripped_url += "/" - - for component in config["components"]: - if component.get("props") is not None: - root_url = component["props"].get("root_url") - # Don't replace the root_url if it's loaded from a different Space - if root_url is None or root_url in all_replica_urls: - component["props"]["root_url"] = stripped_url + return stripped_url.rstrip("/") diff --git a/gradio/routes.py b/gradio/routes.py index 71418dd25554..9f110b957886 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -53,7 +53,7 @@ from gradio.exceptions import Error from gradio.oauth import attach_oauth from gradio.queueing import Estimation, Event -from gradio.route_utils import Request, set_replica_url_in_config # noqa: F401 +from gradio.route_utils import Request # noqa: F401 from gradio.state_holder import StateHolder from gradio.utils import ( cancel_tasks, @@ -128,9 +128,6 @@ def __init__(self, **kwargs): self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str( Path(tempfile.gettempdir()) / "gradio" ) - self.replica_urls = ( - set() - ) # these are the full paths to the replicas if running on a Hugging Face Space with multiple replicas self.change_event: None | threading.Event = None self._asyncio_tasks: list[asyncio.Task] = [] # Allow user to manually set `docs_url` and `redoc_url` @@ -166,10 +163,9 @@ def build_proxy_request(self, url_path): assert self.blocks # Don't proxy a URL unless it's a URL specifically loaded by the user using # gr.load() to prevent SSRF or harvesting of HF tokens by malicious Spaces. - safe_urls = {httpx.URL(root).host for root in self.blocks.root_urls} | { - httpx.URL(root).host for root in self.replica_urls - } - is_safe_url = url.host in safe_urls + is_safe_url = any( + url.host == httpx.URL(root).host for root in self.blocks.root_urls + ) if not is_safe_url: raise PermissionError("This URL cannot be proxied.") is_hf_url = url.host.endswith(".hf.space") @@ -313,22 +309,17 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()): @app.head("/", response_class=HTMLResponse) @app.get("/", response_class=HTMLResponse) - async def main(request: fastapi.Request, user: str = Depends(get_current_user)): + def main(request: fastapi.Request, user: str = Depends(get_current_user)): mimetypes.add_type("application/javascript", ".js") blocks = app.get_blocks() - root_path = request.scope.get("root_path", "") - + root_path = ( + request.scope.get("root_path") + or request.headers.get("X-Direct-Url") + or "" + ) if app.auth is None or user is not None: config = app.get_blocks().config - config["root"] = root_path - - # Handles the case where the app is running on Hugging Face Spaces with - # multiple replicas. See `set_replica_url_in_config` for more details. - replica_url = request.headers.get("X-Direct-Url") - if utils.get_space() and replica_url: - app.replica_urls.add(replica_url) - async with app.lock: - set_replica_url_in_config(config, replica_url, app.replica_urls) + config["root"] = route_utils.strip_url(root_path) else: config = { "auth_required": True, @@ -365,19 +356,14 @@ def api_info(serialize: bool = True): @app.get("/config/", dependencies=[Depends(login_check)]) @app.get("/config", dependencies=[Depends(login_check)]) - async def get_config(request: fastapi.Request): + def get_config(request: fastapi.Request): + root_path = ( + request.scope.get("root_path") + or request.headers.get("X-Direct-Url") + or "" + ) config = app.get_blocks().config - - # Handles the case where the app is running on Hugging Face Spaces with - # multiple replicas. See `set_replica_url_in_config` for more details. - replica_url = request.headers.get("X-Direct-Url") - if utils.get_space() and replica_url: - app.replica_urls.add(replica_url) - async with app.lock: - set_replica_url_in_config(config, replica_url, app.replica_urls) - - root_path = request.scope.get("root_path", "") - config["root"] = root_path + config["root"] = route_utils.strip_url(root_path) return config @app.get("/static/{path:path}") diff --git a/test/test_route_utils.py b/test/test_route_utils.py deleted file mode 100644 index a273c7d063d6..000000000000 --- a/test/test_route_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -from gradio.route_utils import set_replica_url_in_config - - -def test_set_replica_url(): - config = { - "components": [ - {"props": {}}, - {"props": {"root_url": "existing_url/"}}, - {"props": {"root_url": "different_url/"}}, - {}, - ] - } - replica_url = "https://abidlabs-test-client-replica--fttzk.hf.space?__theme=light" - - set_replica_url_in_config(config, replica_url, {"existing_url/"}) - assert ( - config["components"][0]["props"]["root_url"] - == "https://abidlabs-test-client-replica--fttzk.hf.space/" - ) - assert ( - config["components"][1]["props"]["root_url"] - == "https://abidlabs-test-client-replica--fttzk.hf.space/" - ) - assert config["components"][2]["props"]["root_url"] == "different_url/" - assert "props" not in config["components"][3] - - -def test_url_without_trailing_slash(): - config = {"components": [{"props": {}}]} - replica_url = "https://abidlabs-test-client-replica--fttzk.hf.space" - - set_replica_url_in_config(config, replica_url, set()) - assert ( - config["components"][0]["props"]["root_url"] - == "https://abidlabs-test-client-replica--fttzk.hf.space/" - )