Skip to content

Commit

Permalink
[safetensors] sharded with file path (#594)
Browse files Browse the repository at this point in the history
Fixes #588

cc: @madgetr
  • Loading branch information
Mishig committed Mar 29, 2024
1 parent b3f7846 commit c95583f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
26 changes: 26 additions & 0 deletions packages/hub/src/lib/parse-safetensors-metadata.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ describe("parseSafetensorsMetadata", () => {
const parse = await parseSafetensorsMetadata({
repo: "bert-base-uncased",
computeParametersCount: true,
revision: "86b5e0934494bd15c9632b12f734a8a67f723594",
});

assert(!parse.sharded);
Expand All @@ -29,6 +30,7 @@ describe("parseSafetensorsMetadata", () => {
const parse = await parseSafetensorsMetadata({
repo: "bigscience/bloom",
computeParametersCount: true,
revision: "053d9cd9fbe814e091294f67fcfedb3397b954bb",
});

assert(parse.sharded);
Expand All @@ -53,6 +55,7 @@ describe("parseSafetensorsMetadata", () => {
const parse = await parseSafetensorsMetadata({
repo: "roberta-base",
computeParametersCount: true,
revision: "e2da8e2f811d1448a5b465c236feacd80ffbac7b",
});

assert(!parse.sharded);
Expand All @@ -67,6 +70,7 @@ describe("parseSafetensorsMetadata", () => {
repo: "CompVis/stable-diffusion-v1-4",
computeParametersCount: true,
path: "unet/diffusion_pytorch_model.safetensors",
revision: "133a221b8aa7292a167afc5127cb63fb5005638b",
});

assert(!parse.sharded);
Expand All @@ -83,4 +87,26 @@ describe("parseSafetensorsMetadata", () => {
assert.deepStrictEqual(parse.parameterCount, { F32: 859_520_964 });
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964);
});

it("fetch info for sharded (with the default conventional filename) with file path", async () => {
const parse = await parseSafetensorsMetadata({
repo: "Alignment-Lab-AI/ALAI-gemma-7b",
computeParametersCount: true,
path: "7b/1/model.safetensors.index.json",
revision: "37e307261fe97bbf8b2463d61dbdd1a10daa264c",
});

assert(parse.sharded);

assert.strictEqual(Object.keys(parse.headers).length, 4);

assert.deepStrictEqual(parse.headers["model-00004-of-00004.safetensors"]["model.layers.24.mlp.up_proj.weight"], {
dtype: "BF16",
shape: [24576, 3072],
data_offsets: [301996032, 452990976],
});

assert.deepStrictEqual(parse.parameterCount, { BF16: 8_537_680_896 });
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896);
});
});
3 changes: 2 additions & 1 deletion packages/hub/src/lib/parse-safetensors-metadata.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,13 @@ async function parseShardedIndex(
throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`);
}

const pathPrefix = path.substr(0, path.lastIndexOf("/") + 1);
const filenames = [...new Set(Object.values(index.weight_map))];
const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
await promisesQueue(
filenames.map(
(filename) => async () =>
[filename, await parseSingleFile(filename, params)] satisfies [string, SafetensorsFileHeader]
[filename, await parseSingleFile(pathPrefix + filename, params)] satisfies [string, SafetensorsFileHeader]
),
PARALLEL_DOWNLOADS
)
Expand Down

0 comments on commit c95583f

Please sign in to comment.