Skip to content
This repository was archived by the owner on Jun 4, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 54 additions & 4 deletions src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@

import dataclasses
import json
import os
import re
import sys
from argparse import ArgumentParser, ArgumentTypeError
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union

from .utils.logging import get_logger

from sparsezoo import Zoo
from sparsezoo.requests.base import ZOO_STUB_PREFIX


logger = get_logger(__name__)


DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
Expand Down Expand Up @@ -190,12 +199,17 @@ def parse_args_into_dataclasses(
# additional namespace.
outputs.append(namespace)
if return_remaining_strings:
return (*outputs, remaining_args)
return tuple(
*[_download_dataclass_zoo_stub_files(output) for output in outputs],
remaining_args,
)
else:
if remaining_args:
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

return (*outputs,)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)

def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
"""
Expand All @@ -209,7 +223,9 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
inputs = {k: v for k, v in data.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)

def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
"""
Expand All @@ -222,4 +238,38 @@ def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
inputs = {k: v for k, v in args.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)


def _download_dataclass_zoo_stub_files(data_class: DataClass):
for name, val in data_class.__dict__.items():
if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"):
continue

logger.info(f"Downloading framework files for SparseZoo stub: {val}")

zoo_model = Zoo.load_model_from_stub(val)
framework_file_paths = zoo_model.download_framework_files()
assert framework_file_paths, (
"Unable to download any framework files for SparseZoo stub {val}"
)
framework_file_names = [os.path.basename(path) for path in framework_file_paths]
if "pytorch_model.bin" not in framework_file_names or (
"config.json" not in framework_file_names
):
raise RuntimeError(
"Unable to find 'pytorch_model.bin' and 'config.json' in framework "
f"files downloaded from {val}. Found {framework_file_names}. Check "
"if the given stub is for a transformers repo model"
)
framework_dir_path = Path(framework_file_paths[0]).parent.absolute()

logger.info(
f"Overwriting argument {name} to downloaded {framework_dir_path}"
)

data_class.__dict__[name] = str(framework_dir_path)

return data_class