From 44e65f55ff43819e04be52680b243769953fa558 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Tue, 12 Oct 2021 14:55:57 -0400
Subject: [PATCH 1/2] support to download and unwrap framework files from SparseZoo stubs in hf_argparser

After the dataclasses are built (from the command line, a JSON file or a
dict), every string-valued field whose value starts with "zoo:" is treated
as a SparseZoo stub; fields with "recipe" in their name are skipped. The
framework files for the stub are downloaded and the field is overwritten
with the absolute path of the directory that contains the first downloaded
file.

When remaining strings are requested, the result is built with a tuple
display rather than tuple(*outputs, remaining_args): tuple() accepts at
most one argument, so the call form raises TypeError.

---
 src/transformers/hf_argparser.py | 48 +++++++++++++++++++++++++++++---
 1 file changed, 44 insertions(+), 4 deletions(-)

diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py
index 4326a589d65f..b72cef133ac9 100644
--- a/src/transformers/hf_argparser.py
+++ b/src/transformers/hf_argparser.py
@@ -21,6 +21,14 @@
 from pathlib import Path
 from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
 
+from .utils.logging import get_logger
+
+from sparsezoo import Zoo
+from sparsezoo.requests.base import ZOO_STUB_PREFIX
+
+
+logger = get_logger(__name__)
+
 
 DataClass = NewType("DataClass", Any)
 DataClassType = NewType("DataClassType", Any)
@@ -190,12 +198,17 @@ def parse_args_into_dataclasses(
             # additional namespace.
             outputs.append(namespace)
         if return_remaining_strings:
-            return (*outputs, remaining_args)
+            return (
+                *[_download_dataclass_zoo_stub_files(output) for output in outputs],
+                remaining_args,
+            )
         else:
             if remaining_args:
                 raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
 
-            return (*outputs,)
+            return tuple(
+                [_download_dataclass_zoo_stub_files(output) for output in outputs]
+            )
 
     def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
         """
@@ -209,7 +222,9 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
             inputs = {k: v for k, v in data.items() if k in keys}
             obj = dtype(**inputs)
             outputs.append(obj)
-        return (*outputs,)
+        return tuple(
+            [_download_dataclass_zoo_stub_files(output) for output in outputs]
+        )
 
     def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
         """
@@ -222,4 +237,29 @@ def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
             inputs = {k: v for k, v in args.items() if k in keys}
             obj = dtype(**inputs)
             outputs.append(obj)
-        return (*outputs,)
+        return tuple(
+            [_download_dataclass_zoo_stub_files(output) for output in outputs]
+        )
+
+
+def _download_dataclass_zoo_stub_files(data_class: DataClass):
+    for name, val in data_class.__dict__.items():
+        if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"):
+            continue
+
+        logger.info(f"Downloading framework files for SparseZoo stub: {val}")
+
+        zoo_model = Zoo.load_model_from_stub(val)
+        framework_file_paths = zoo_model.download_framework_files()
+        assert framework_file_paths, (
+            f"Unable to download any framework files for SparseZoo stub {val}"
+        )
+        framework_dir_path = Path(framework_file_paths[0]).parent.absolute()
+
+        logger.info(
+            f"Overwriting argument {name} to downloaded {framework_dir_path}"
+        )
+
+        data_class.__dict__[name] = str(framework_dir_path)
+
+    return data_class

From 8a7d105cb0d858f39b2441b0e7d0a5361be0ebed Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Fri, 15 Oct 2021 12:53:57 -0400
Subject: [PATCH 2/2] added check that model config and checkpoint are downloaded from zoo

Raise a RuntimeError, listing the file names that were found, when the
files downloaded for a stub do not include both 'pytorch_model.bin' and
'config.json'.

---
 src/transformers/hf_argparser.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py
index b72cef133ac9..566c2a3976dc 100644
--- a/src/transformers/hf_argparser.py
+++ b/src/transformers/hf_argparser.py
@@ -14,6 +15,7 @@
 
 import dataclasses
 import json
+import os
 import re
 import sys
 from argparse import ArgumentParser, ArgumentTypeError
@@ -254,6 +255,15 @@ def _download_dataclass_zoo_stub_files(data_class: DataClass):
         assert framework_file_paths, (
             f"Unable to download any framework files for SparseZoo stub {val}"
         )
+        framework_file_names = [os.path.basename(path) for path in framework_file_paths]
+        if "pytorch_model.bin" not in framework_file_names or (
+            "config.json" not in framework_file_names
+        ):
+            raise RuntimeError(
+                "Unable to find 'pytorch_model.bin' and 'config.json' in framework "
+                f"files downloaded from {val}. Found {framework_file_names}. Check "
+                "if the given stub is for a transformers repo model"
+            )
         framework_dir_path = Path(framework_file_paths[0]).parent.absolute()
 
         logger.info(