diff --git a/.changeset/sad-feet-brush.md b/.changeset/sad-feet-brush.md new file mode 100644 index 000000000000..f7a3fa7b78f2 --- /dev/null +++ b/.changeset/sad-feet-brush.md @@ -0,0 +1,6 @@ +--- +"@gradio/client": patch +"gradio": patch +--- + +fix:Improve URL handling in JS Client diff --git a/client/js/src/constants.ts b/client/js/src/constants.ts index 65451ac7b471..0cc056dcee46 100644 --- a/client/js/src/constants.ts +++ b/client/js/src/constants.ts @@ -25,6 +25,7 @@ export const CONFIG_ERROR_MSG = "Could not resolve app config. "; export const SPACE_STATUS_ERROR_MSG = "Could not get space status. "; export const API_INFO_ERROR_MSG = "Could not get API info. "; export const SPACE_METADATA_ERROR_MSG = "Space metadata could not be loaded. "; +export const INVALID_URL_MSG = "Invalid URL. A full URL path is required."; export const UNAUTHORIZED_MSG = "Not authorized to access this space. "; export const INVALID_CREDENTIALS_MSG = "Invalid credentials. Could not login. "; export const MISSING_CREDENTIALS_MSG = diff --git a/client/js/src/helpers/api_info.ts b/client/js/src/helpers/api_info.ts index ee514f478998..af8b3f159ed0 100644 --- a/client/js/src/helpers/api_info.ts +++ b/client/js/src/helpers/api_info.ts @@ -1,9 +1,14 @@ import type { Status } from "../types"; -import { QUEUE_FULL_MSG, SPACE_METADATA_ERROR_MSG } from "../constants"; +import { + HOST_URL, + INVALID_URL_MSG, + QUEUE_FULL_MSG, + SPACE_METADATA_ERROR_MSG +} from "../constants"; import type { ApiData, ApiInfo, Config, JsApiData } from "../types"; import { determine_protocol } from "./init_helpers"; -export const RE_SPACE_NAME = /^[^\/]*\/[^\/]*$/; +export const RE_SPACE_NAME = /^[a-zA-Z0-9_\-\.]+\/[a-zA-Z0-9_\-\.]+$/; export const RE_SPACE_DOMAIN = /.*hf\.space\/{0,1}$/; export async function process_endpoint( @@ -20,12 +25,13 @@ export async function process_endpoint( headers.Authorization = `Bearer ${hf_token}`; } - const _app_reference = app_reference.trim(); + const _app_reference = app_reference.trim().replace(/\/$/, ""); if (RE_SPACE_NAME.test(_app_reference)) { + // app_reference is a HF space name try { const res = await fetch( - `https://huggingface.co/api/spaces/${_app_reference}/host`, + `https://huggingface.co/api/spaces/${_app_reference}/${HOST_URL}`, { headers } ); @@ -41,6 +47,7 @@ export async function process_endpoint( } if (RE_SPACE_DOMAIN.test(_app_reference)) { + // app_reference is a direct HF space domain const { ws_protocol, http_protocol, host } = determine_protocol(_app_reference); @@ -58,6 +65,18 @@ export async function process_endpoint( }; } +export const join_urls = (...urls: string[]): string => { + try { + return urls.reduce((base_url: string, part: string) => { + base_url = base_url.replace(/\/+$/, ""); + part = part.replace(/^\/+/, ""); + return new URL(part, base_url + "/").toString(); + }); + } catch (e) { + throw new Error(INVALID_URL_MSG); + } +}; + export function transform_api_info( api_info: ApiInfo, config: Config, diff --git a/client/js/src/helpers/init_helpers.ts b/client/js/src/helpers/init_helpers.ts index 2fcb9b03a861..5518573bc160 100644 --- a/client/js/src/helpers/init_helpers.ts +++ b/client/js/src/helpers/init_helpers.ts @@ -9,7 +9,7 @@ import { UNAUTHORIZED_MSG } from "../constants"; import { Client } from ".."; -import { process_endpoint } from "./api_info"; +import { join_urls, process_endpoint } from "./api_info"; /** * This function is used to resolve the URL for making requests when the app has a root path. @@ -86,7 +86,8 @@ export async function resolve_config( config.root = config_root; return { ...config, path } as Config; } else if (endpoint) { - const response = await this.fetch(`${endpoint}/${CONFIG_URL}`, { + const config_url = join_urls(endpoint, CONFIG_URL); + const response = await this.fetch(config_url, { headers, credentials: "include" }); @@ -173,7 +174,7 @@ export function determine_protocol(endpoint: string): { host: string; } { if (endpoint.startsWith("http")) { - const { protocol, host } = new URL(endpoint); + const { protocol, host, pathname } = new URL(endpoint); if (host.endsWith("hf.space")) { return { @@ -185,7 +186,7 @@ export function determine_protocol(endpoint: string): { return { ws_protocol: protocol === "https:" ? "wss" : "ws", http_protocol: protocol as "http:" | "https:", - host + host: host + (pathname !== "/" ? pathname : "") }; } else if (endpoint.startsWith("file:")) { // This case is only expected to be used for the Wasm mode (Gradio-lite), diff --git a/client/js/src/test/api_info.test.ts b/client/js/src/test/api_info.test.ts index 13fb49ad92a2..bf6c413c5ef8 100644 --- a/client/js/src/test/api_info.test.ts +++ b/client/js/src/test/api_info.test.ts @@ -1,16 +1,22 @@ -import { QUEUE_FULL_MSG, SPACE_METADATA_ERROR_MSG } from "../constants"; +import { + INVALID_URL_MSG, + QUEUE_FULL_MSG, + SPACE_METADATA_ERROR_MSG +} from "../constants"; import { beforeAll, afterEach, afterAll, it, expect, describe } from "vitest"; import { handle_message, get_description, get_type, process_endpoint, + join_urls, map_data_to_params } from "../helpers/api_info"; import { initialise_server } from "./server"; import { transformed_api_info } from "./test_data"; const server = initialise_server(); +const IS_NODE = process.env.TEST_MODE === "node"; beforeAll(() => server.listen()); afterEach(() => server.resetHandlers()); @@ -453,6 +459,67 @@ describe("process_endpoint", () => { const result = await process_endpoint("hmb/hello_world"); expect(result).toEqual(expected); }); + + it("processes local server URLs correctly", async () => { + const local_url = "http://localhost:7860/gradio"; + const response_local_url = await process_endpoint(local_url); + expect(response_local_url.space_id).toBe(false); + expect(response_local_url.host).toBe("localhost:7860/gradio"); + + const local_url_2 = "http://localhost:7860/gradio/"; + const response_local_url_2 = await process_endpoint(local_url_2); + expect(response_local_url_2.space_id).toBe(false); + expect(response_local_url_2.host).toBe("localhost:7860/gradio"); + }); + + it("handles hugging face space references", async () => { + const space_id = "hmb/hello_world"; + + const response = await process_endpoint(space_id); + expect(response.space_id).toBe(space_id); + expect(response.host).toContain("hf.space"); + }); + + it("handles hugging face domain URLs", async () => { + const app_reference = "https://hmb-hello-world.hf.space/"; + const response = await process_endpoint(app_reference); + expect(response.space_id).toBe("hmb-hello-world"); + expect(response.host).toBe("hmb-hello-world.hf.space"); + }); +}); + +describe("join_urls", () => { + it("joins URLs correctly", () => { + expect(join_urls("http://localhost:7860", "/gradio")).toBe( + "http://localhost:7860/gradio" + ); + expect(join_urls("http://localhost:7860/", "/gradio")).toBe( + "http://localhost:7860/gradio" + ); + expect(join_urls("http://localhost:7860", "app/", "/gradio")).toBe( + "http://localhost:7860/app/gradio" + ); + expect(join_urls("http://localhost:7860/", "/app/", "/gradio/")).toBe( + "http://localhost:7860/app/gradio/" + ); + + expect(join_urls("http://127.0.0.1:8000/app", "/config")).toBe( + "http://127.0.0.1:8000/app/config" + ); + + expect(join_urls("http://127.0.0.1:8000/app/gradio", "/config")).toBe( + "http://127.0.0.1:8000/app/gradio/config" + ); + }); + it("throws an error when the URLs are not valid", () => { + expect(() => join_urls("localhost:7860", "/gradio")).toThrowError( + INVALID_URL_MSG + ); + + expect(() => join_urls("localhost:7860", "/gradio", "app")).toThrowError( + INVALID_URL_MSG + ); + }); }); describe("map_data_params", () => { diff --git a/client/js/src/utils/view_api.ts b/client/js/src/utils/view_api.ts index 05e0c60412d8..bb65bc2d428d 100644 --- a/client/js/src/utils/view_api.ts +++ b/client/js/src/utils/view_api.ts @@ -3,7 +3,7 @@ import semiver from "semiver"; import { API_INFO_URL, BROKEN_CONNECTION_MSG } from "../constants"; import { Client } from "../client"; import { SPACE_FETCHER_URL } from "../constants"; -import { transform_api_info } from "../helpers/api_info"; +import { join_urls, transform_api_info } from "../helpers/api_info"; export async function view_api(this: Client): Promise { if (this.api_info) return this.api_info; @@ -38,7 +38,8 @@ export async function view_api(this: Client): Promise { credentials: "include" }); } else { - response = await this.fetch(`${config?.root}/${API_INFO_URL}`, { + const url = join_urls(config.root, API_INFO_URL); + response = await this.fetch(url, { headers, credentials: "include" });