diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 4326a589d65f..566c2a3976dc 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -14,6 +14,7 @@ import dataclasses import json +import os import re import sys from argparse import ArgumentParser, ArgumentTypeError @@ -21,6 +22,14 @@ from pathlib import Path from typing import Any, Iterable, List, NewType, Optional, Tuple, Union +from .utils.logging import get_logger + +from sparsezoo import Zoo +from sparsezoo.requests.base import ZOO_STUB_PREFIX + + +logger = get_logger(__name__) + DataClass = NewType("DataClass", Any) DataClassType = NewType("DataClassType", Any) @@ -190,12 +199,17 @@ def parse_args_into_dataclasses( # additional namespace. outputs.append(namespace) if return_remaining_strings: - return (*outputs, remaining_args) + return tuple( + *[_download_dataclass_zoo_stub_files(output) for output in outputs], + remaining_args, + ) else: if remaining_args: raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}") - return (*outputs,) + return tuple( + [_download_dataclass_zoo_stub_files(output) for output in outputs] + ) def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: """ @@ -209,7 +223,9 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: inputs = {k: v for k, v in data.items() if k in keys} obj = dtype(**inputs) outputs.append(obj) - return (*outputs,) + return tuple( + [_download_dataclass_zoo_stub_files(output) for output in outputs] + ) def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: """ @@ -222,4 +238,38 @@ def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: inputs = {k: v for k, v in args.items() if k in keys} obj = dtype(**inputs) outputs.append(obj) - return (*outputs,) + return tuple( + [_download_dataclass_zoo_stub_files(output) for output in outputs] + ) + + +def _download_dataclass_zoo_stub_files(data_class: DataClass): + for name, val in data_class.__dict__.items(): + if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"): + continue + + logger.info(f"Downloading framework files for SparseZoo stub: {val}") + + zoo_model = Zoo.load_model_from_stub(val) + framework_file_paths = zoo_model.download_framework_files() + assert framework_file_paths, ( + "Unable to download any framework files for SparseZoo stub {val}" + ) + framework_file_names = [os.path.basename(path) for path in framework_file_paths] + if "pytorch_model.bin" not in framework_file_names or ( + "config.json" not in framework_file_names + ): + raise RuntimeError( + "Unable to find 'pytorch_model.bin' and 'config.json' in framework " + f"files downloaded from {val}. Found {framework_file_names}. Check " + "if the given stub is for a transformers repo model" + ) + framework_dir_path = Path(framework_file_paths[0]).parent.absolute() + + logger.info( + f"Overwriting argument {name} to downloaded {framework_dir_path}" + ) + + data_class.__dict__[name] = str(framework_dir_path) + + return data_class