diff --git a/packages/hub/src/lib/file-download-info.spec.ts b/packages/hub/src/lib/file-download-info.spec.ts index bea4d6b716..bb66b6966e 100644 --- a/packages/hub/src/lib/file-download-info.spec.ts +++ b/packages/hub/src/lib/file-download-info.spec.ts @@ -13,7 +13,8 @@ describe("fileDownloadInfo", () => { }); assert.strictEqual(info?.size, 536063208); - assert.strictEqual(info?.etag, '"41a0e56472bad33498744818c8b1ef2c-64"'); + assert.strictEqual(info?.etag, '"a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2"'); + assert.strictEqual(info?.commitHash, 'dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7'); assert(info?.downloadLink); }); @@ -30,6 +31,7 @@ describe("fileDownloadInfo", () => { assert.strictEqual(info?.size, 134); assert.strictEqual(info?.etag, '"9eb98c817f04b051b3bcca591bcd4e03cec88018"'); + assert.strictEqual(info?.commitHash, 'dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7'); assert(!info?.downloadLink); }); @@ -45,5 +47,22 @@ describe("fileDownloadInfo", () => { assert.strictEqual(info?.size, 28); assert.strictEqual(info?.etag, '"a661b1a138dac6dc5590367402d100765010ffd6"'); + assert.strictEqual(info?.commitHash, '1a7dd4986e3dab699c24ca19b2afd0f5e1a80f37'); + }); + + it("should fetch LFS file info without redirect", async () => { + const info = await fileDownloadInfo({ + repo: { + name: "google-bert/bert-base-uncased", // full name no redirect needed + type: "model", + }, + path: "tf_model.h5", + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + + assert.strictEqual(info?.size, 536063208); + assert.strictEqual(info?.etag, '"a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2"'); + assert.strictEqual(info?.commitHash, 'dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7'); + assert(info?.downloadLink); }); }); diff --git a/packages/hub/src/lib/file-download-info.ts b/packages/hub/src/lib/file-download-info.ts index 210bd11e72..bde0323942 100644 --- a/packages/hub/src/lib/file-download-info.ts +++ b/packages/hub/src/lib/file-download-info.ts @@ -4,14 +4,58 @@ import type { CredentialsParams, RepoDesignation } from "../types/public"; import { checkCredentials } from "../utils/checkCredentials"; import { toRepoId } from "../utils/toRepoId"; +const HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit" +const HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag" +const HUGGINGFACE_HEADER_X_LINKED_SIZE = "X-Linked-Size" + export interface FileDownloadInfoOutput { size: number; etag: string; + commitHash: string | null; /** * In case of LFS file, link to download directly from cloud provider */ downloadLink: string | null; } + +/** + * Useful when we want to follow a redirection to a renamed repository without following redirection to a CDN. + * If a Location header is `/hello` we should follow the relative direct + * However we may have full url redirect, on the same origin, we need to properly compare the origin then. + * @param params + */ +async function followSameOriginRedirect(params: { + url: string, + method: string, + headers: Record, + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; +}): Promise { + const resp = await (params.fetch ?? fetch)(params.url, { + method: params.method, + headers: params.headers, + // prevent automatic redirect + redirect: 'manual', + }); + + const location: string | null = resp.headers.get('Location'); + if(!location) return resp; + + // new URL('http://foo/bar', 'http://example.com/hello').href == http://foo/bar + // new URL('/bar', 'http://example.com/hello').href == http://example.com/bar + const nURL = new URL(location, params.url); + // ensure origin are matching + if(new URL(params.url).origin !== nURL.origin) + return resp; + + return followSameOriginRedirect({ + ...params, + url: nURL.href, + }); +} + /** * @returns null when the file doesn't exist */ @@ -47,13 +91,16 @@ export async function fileDownloadInfo( }/${encodeURIComponent(params.revision ?? "main")}/${params.path}` + (params.noContentDisposition ? "?noContentDisposition=1" : ""); - const resp = await (params.fetch ?? fetch)(url, { - method: "GET", + // + const resp = await followSameOriginRedirect({ + url: url, + method: "HEAD", headers: { ...(params.credentials && { Authorization: `Bearer ${accessToken}`, + // prevent any compression => we want to know the real size of the file + 'Accept-Encoding': 'identity', }), - Range: "bytes=0-0", }, }); @@ -61,24 +108,25 @@ export async function fileDownloadInfo( return null; } - if (!resp.ok) { + // redirect to CDN is okay not an error + if (!resp.ok && !resp.headers.get('Location')) { throw await createApiError(resp); } - const etag = resp.headers.get("ETag"); - + // We favor a custom header indicating the etag of the linked resource, and + // we fallback to the regular etag header. + const etag = resp.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) ?? resp.headers.get("ETag"); if (!etag) { throw new InvalidApiResponseFormatError("Expected ETag"); } - const contentRangeHeader = resp.headers.get("content-range"); - - if (!contentRangeHeader) { + // size is required + const contentSize = resp.headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE) ?? resp.headers.get("Content-Length") + if (!contentSize) { throw new InvalidApiResponseFormatError("Expected size information"); } - const [, parsedSize] = contentRangeHeader.split("/"); - const size = parseInt(parsedSize); + const size = parseInt(contentSize); if (isNaN(size)) { throw new InvalidApiResponseFormatError("Invalid file size received"); @@ -87,6 +135,8 @@ export async function fileDownloadInfo( return { etag, size, - downloadLink: new URL(resp.url).hostname !== new URL(hubUrl).hostname ? resp.url : null, + // Either from response headers (if redirected) or defaults to request url + downloadLink: resp.headers.get('Location') ?? new URL(resp.url).hostname !== new URL(hubUrl).hostname ? resp.url : null, + commitHash: resp.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT), }; }