Skip to content

Commit

Permalink
Revert replica proxy logic and instead implement using the root var…
Browse files Browse the repository at this point in the history
…iable (#5776)

* Revert "Fix for deepcopy errors when running the replica-related logic on Spaces (#5722)"

This reverts commit dba6519.

* Revert "Fully resolve generated filepaths when running on Hugging Face Spaces with multiple replicas (#5668)"

This reverts commit d626c21.

* add changeset

* Trigger local

* add changeset

* add to root

* add changeset

* strip url

* resolve root

* changes

* fix

* format

* logs

* format

* add changeset

* reverse order

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
3 people committed Oct 5, 2023
1 parent 59c3591 commit c0fef44
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 96 deletions.
6 changes: 6 additions & 0 deletions .changeset/odd-parrots-open.md
@@ -0,0 +1,6 @@
---
"@gradio/client": patch
"gradio": patch
---

fix:Revert replica proxy logic and instead implement using the `root` variable
22 changes: 15 additions & 7 deletions client/js/src/client.ts
Expand Up @@ -8,7 +8,8 @@ import {
get_space_hardware,
set_space_hardware,
set_space_timeout,
hardware_types
hardware_types,
resolve_root
} from "./utils.js";

import type {
Expand Down Expand Up @@ -436,7 +437,7 @@ export function api_factory(
}

handle_blob(
`${http_protocol}//${host + config.path}`,
`${http_protocol}//${resolve_root(host, config.path, true)}`,
data,
api_info,
hf_token
Expand All @@ -453,7 +454,7 @@ export function api_factory(
});

post_data(
`${http_protocol}//${host + config.path}/run${
`${http_protocol}//${resolve_root(host, config.path, true)}/run${
_endpoint.startsWith("/") ? _endpoint : `/${_endpoint}`
}${url_params ? "?" + url_params : ""}`,
{
Expand Down Expand Up @@ -521,8 +522,11 @@ export function api_factory(
fn_index,
time: new Date()
});

let url = new URL(`${ws_protocol}://${host}${config.path}
let url = new URL(`${ws_protocol}://${resolve_root(
host,
config.path,
true
)}
/queue/join${url_params ? "?" + url_params : ""}`);

if (jwt) {
Expand Down Expand Up @@ -686,7 +690,11 @@ export function api_factory(

try {
await fetch_implementation(
`${http_protocol}//${host + config.path}/reset`,
`${http_protocol}//${resolve_root(
host,
config.path,
true
)}/reset`,
{
headers: { "Content-Type": "application/json" },
method: "POST",
Expand Down Expand Up @@ -1206,7 +1214,7 @@ async function resolve_config(
) {
const path = window.gradio_config.root;
const config = window.gradio_config;
config.root = endpoint + config.root;
config.root = resolve_root(endpoint, config.root, false);
return { ...config, path: path };
} else if (endpoint) {
let response = await fetch_implementation(`${endpoint}/config`, {
Expand Down
21 changes: 21 additions & 0 deletions client/js/src/utils.ts
@@ -1,5 +1,26 @@
import type { Config } from "./types.js";

/**
* This function is used to resolve the URL for making requests when the app has a root path.
* The root path could be a path suffix like "/app" which is appended to the end of the base URL. Or
* it could be a full URL like "https://abidlabs-test-client-replica--gqf2x.hf.space" which is used when hosting
* Gradio apps on Hugging Face Spaces.
* @param {string} base_url The base URL at which the Gradio server is hosted
* @param {string} root_path The root path, which could be a path suffix (e.g. mounted in FastAPI app) or a full URL (e.g. hosted on Hugging Face Spaces)
* @param {boolean} prioritize_base Whether to prioritize the base URL over the root path. This is used when both the base path and root paths are full URLs. For example, for fetching files the root path should be prioritized, but for making request, the base URL should be prioritized.
* @returns {string} the resolved URL
*/
export function resolve_root(
base_url: string,
root_path: string,
prioritize_base: boolean
): string {
if (root_path.startsWith("http://") || root_path.startsWith("https://")) {
return prioritize_base ? base_url : root_path;
}
return base_url + root_path;
}

export function determine_protocol(endpoint: string): {
ws_protocol: "ws" | "wss";
http_protocol: "http:" | "https:";
Expand Down
25 changes: 4 additions & 21 deletions gradio/route_utils.py
Expand Up @@ -247,28 +247,11 @@ async def call_process_api(
return output


def set_replica_url_in_config(
config: dict, replica_url: str, all_replica_urls: set[str]
) -> None:
def strip_url(orig_url: str) -> str:
"""
If the Gradio app is running on Hugging Face Spaces and the machine has multiple replicas,
we pass in the direct URL to the replica so that we have the fully resolved path to any files
on that machine. This direct URL can be shared with other users and the path will still work.
Parameters:
config: The config dictionary to modify.
replica_url: The direct URL to the replica.
all_replica_urls: The direct URLs to the other replicas. These should be replaced with the replica_url.
Strips the query parameters and trailing slash from a URL.
"""
parsed_url = httpx.URL(replica_url)
parsed_url = httpx.URL(orig_url)
stripped_url = parsed_url.copy_with(query=None)
stripped_url = str(stripped_url)
if not stripped_url.endswith("/"):
stripped_url += "/"

for component in config["components"]:
if component.get("props") is not None:
root_url = component["props"].get("root_url")
# Don't replace the root_url if it's loaded from a different Space
if root_url is None or root_url in all_replica_urls:
component["props"]["root_url"] = stripped_url
return stripped_url.rstrip("/")
50 changes: 18 additions & 32 deletions gradio/routes.py
Expand Up @@ -53,7 +53,7 @@
from gradio.exceptions import Error
from gradio.oauth import attach_oauth
from gradio.queueing import Estimation, Event
from gradio.route_utils import Request, set_replica_url_in_config # noqa: F401
from gradio.route_utils import Request # noqa: F401
from gradio.state_holder import StateHolder
from gradio.utils import (
cancel_tasks,
Expand Down Expand Up @@ -128,9 +128,6 @@ def __init__(self, **kwargs):
self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
Path(tempfile.gettempdir()) / "gradio"
)
self.replica_urls = (
set()
) # these are the full paths to the replicas if running on a Hugging Face Space with multiple replicas
self.change_event: None | threading.Event = None
self._asyncio_tasks: list[asyncio.Task] = []
# Allow user to manually set `docs_url` and `redoc_url`
Expand Down Expand Up @@ -166,10 +163,9 @@ def build_proxy_request(self, url_path):
assert self.blocks
# Don't proxy a URL unless it's a URL specifically loaded by the user using
# gr.load() to prevent SSRF or harvesting of HF tokens by malicious Spaces.
safe_urls = {httpx.URL(root).host for root in self.blocks.root_urls} | {
httpx.URL(root).host for root in self.replica_urls
}
is_safe_url = url.host in safe_urls
is_safe_url = any(
url.host == httpx.URL(root).host for root in self.blocks.root_urls
)
if not is_safe_url:
raise PermissionError("This URL cannot be proxied.")
is_hf_url = url.host.endswith(".hf.space")
Expand Down Expand Up @@ -313,22 +309,17 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):

@app.head("/", response_class=HTMLResponse)
@app.get("/", response_class=HTMLResponse)
async def main(request: fastapi.Request, user: str = Depends(get_current_user)):
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
mimetypes.add_type("application/javascript", ".js")
blocks = app.get_blocks()
root_path = request.scope.get("root_path", "")

root_path = (
request.scope.get("root_path")
or request.headers.get("X-Direct-Url")
or ""
)
if app.auth is None or user is not None:
config = app.get_blocks().config
config["root"] = root_path

# Handles the case where the app is running on Hugging Face Spaces with
# multiple replicas. See `set_replica_url_in_config` for more details.
replica_url = request.headers.get("X-Direct-Url")
if utils.get_space() and replica_url:
app.replica_urls.add(replica_url)
async with app.lock:
set_replica_url_in_config(config, replica_url, app.replica_urls)
config["root"] = route_utils.strip_url(root_path)
else:
config = {
"auth_required": True,
Expand Down Expand Up @@ -365,19 +356,14 @@ def api_info(serialize: bool = True):

@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
async def get_config(request: fastapi.Request):
def get_config(request: fastapi.Request):
root_path = (
request.scope.get("root_path")
or request.headers.get("X-Direct-Url")
or ""
)
config = app.get_blocks().config

# Handles the case where the app is running on Hugging Face Spaces with
# multiple replicas. See `set_replica_url_in_config` for more details.
replica_url = request.headers.get("X-Direct-Url")
if utils.get_space() and replica_url:
app.replica_urls.add(replica_url)
async with app.lock:
set_replica_url_in_config(config, replica_url, app.replica_urls)

root_path = request.scope.get("root_path", "")
config["root"] = root_path
config["root"] = route_utils.strip_url(root_path)
return config

@app.get("/static/{path:path}")
Expand Down
36 changes: 0 additions & 36 deletions test/test_route_utils.py

This file was deleted.

0 comments on commit c0fef44

Please sign in to comment.