From 8f3b4a1d5bd97045541c43179efe8cd9c58adb76 Mon Sep 17 00:00:00 2001 From: Lucain Date: Fri, 27 Jan 2023 18:09:49 +0100 Subject: [PATCH] Little cleanup: let huggingface_hub manage token retrieval (#21333) * Let huggingface_hub manage token retrieval * flake8 * code quality * adapt in every PushToHubMixin children * add explicit return type --- src/transformers/configuration_utils.py | 8 +++- src/transformers/dynamic_module_utils.py | 11 +---- src/transformers/feature_extraction_utils.py | 8 +++- .../generation/configuration_utils.py | 8 +++- src/transformers/image_processing_utils.py | 8 +++- src/transformers/modeling_flax_utils.py | 8 +++- src/transformers/modeling_tf_utils.py | 12 +++-- src/transformers/modeling_utils.py | 8 +++- src/transformers/processing_utils.py | 8 +++- src/transformers/tokenization_utils_base.py | 8 +++- src/transformers/utils/hub.py | 44 ++++++++----------- 11 files changed, 76 insertions(+), 55 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index ec7c1890699c1..8d7a8dc559ce4 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -438,7 +438,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: if push_to_hub: commit_message = kwargs.pop("commit_message", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id, token = self._create_repo(repo_id, **kwargs) + repo_id = self._create_repo(repo_id, **kwargs) files_timestamps = self._get_files_timestamps(save_directory) # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be @@ -454,7 +454,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: if push_to_hub: self._upload_modified_files( - save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("use_auth_token"), ) @classmethod diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index 0c2067cf2e53d..f3fc14838275b 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -22,7 +22,7 @@ from pathlib import Path from typing import Dict, Optional, Union -from huggingface_hub import HfFolder, model_info +from huggingface_hub import model_info from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging @@ -251,14 +251,7 @@ def get_cached_module_file( else: # Get the commit hash # TODO: we will get this info in the etag soon, so retrieve it from there and not here. - if isinstance(use_auth_token, str): - token = use_auth_token - elif use_auth_token is True: - token = HfFolder.get_token() - else: - token = None - - commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha + commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=use_auth_token).sha # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the # benefit of versioning. diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index ff8fa009935f6..e41ff8af54a61 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -353,7 +353,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: if push_to_hub: commit_message = kwargs.pop("commit_message", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id, token = self._create_repo(repo_id, **kwargs) + repo_id = self._create_repo(repo_id, **kwargs) files_timestamps = self._get_files_timestamps(save_directory) # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be @@ -369,7 +369,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: if push_to_hub: self._upload_modified_files( - save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("use_auth_token"), ) return [output_feature_extractor_file] diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 5b5a7d1794dd7..a42718c5c7fda 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -337,7 +337,7 @@ def save_pretrained( if push_to_hub: commit_message = kwargs.pop("commit_message", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id, token = self._create_repo(repo_id, **kwargs) + repo_id = self._create_repo(repo_id, **kwargs) files_timestamps = self._get_files_timestamps(save_directory) output_config_file = os.path.join(save_directory, config_file_name) @@ -347,7 +347,11 @@ def save_pretrained( if push_to_hub: self._upload_modified_files( - save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("use_auth_token"), ) @classmethod diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py index 0be7719782877..feff54a3ff587 100644 --- a/src/transformers/image_processing_utils.py +++ b/src/transformers/image_processing_utils.py @@ -185,7 +185,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: if push_to_hub: commit_message = kwargs.pop("commit_message", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id, token = self._create_repo(repo_id, **kwargs) + repo_id = self._create_repo(repo_id, **kwargs) files_timestamps = self._get_files_timestamps(save_directory) # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be @@ -201,7 +201,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: if push_to_hub: self._upload_modified_files( - save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("use_auth_token"), ) return [output_image_processor_file] diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 7dc8f4ae48915..c501764350c29 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -1018,7 +1018,7 @@ def save_pretrained( if push_to_hub: commit_message = kwargs.pop("commit_message", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id, token = self._create_repo(repo_id, **kwargs) + repo_id = self._create_repo(repo_id, **kwargs) files_timestamps = self._get_files_timestamps(save_directory) # get abs dir @@ -1077,7 +1077,11 @@ def save_pretrained( if push_to_hub: self._upload_modified_files( - save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("use_auth_token"), ) @classmethod diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index f27373bd78c5b..42a6adea12d1f 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -2277,7 +2277,7 @@ def save_pretrained( if push_to_hub: commit_message = kwargs.pop("commit_message", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id, token = self._create_repo(repo_id, **kwargs) + repo_id = self._create_repo(repo_id, **kwargs) files_timestamps = self._get_files_timestamps(save_directory) if saved_model: @@ -2363,7 +2363,11 @@ def save_pretrained( if push_to_hub: self._upload_modified_files( - save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("use_auth_token"), ) @classmethod @@ -2946,7 +2950,7 @@ def push_to_hub( else: working_dir = repo_id.split("/")[-1] - repo_id, token = self._create_repo( + repo_id = self._create_repo( repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization ) @@ -2968,7 +2972,7 @@ def push_to_hub( self.create_model_card(**base_model_card_args) self._upload_modified_files( - work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token + work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token ) @classmethod diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4017305842d5d..88903cbe4e580 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1633,7 +1633,7 @@ def save_pretrained( if push_to_hub: commit_message = kwargs.pop("commit_message", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id, token = self._create_repo(repo_id, **kwargs) + repo_id = self._create_repo(repo_id, **kwargs) files_timestamps = self._get_files_timestamps(save_directory) # Only save the model itself if we are using distributed training @@ -1717,7 +1717,11 @@ def save_pretrained( if push_to_hub: self._upload_modified_files( - save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("use_auth_token"), ) def get_memory_footprint(self, return_buffers=True): diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index d13f6d8458157..4ad814a4c1bc0 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -121,7 +121,7 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs): if push_to_hub: commit_message = kwargs.pop("commit_message", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id, token = self._create_repo(repo_id, **kwargs) + repo_id = self._create_repo(repo_id, **kwargs) files_timestamps = self._get_files_timestamps(save_directory) # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be # loaded from the Hub. @@ -147,7 +147,11 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs): if push_to_hub: self._upload_modified_files( - save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("use_auth_token"), ) @classmethod diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 865a7a56549f8..8f3c392d38222 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2098,7 +2098,7 @@ def save_pretrained( if push_to_hub: commit_message = kwargs.pop("commit_message", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id, token = self._create_repo(repo_id, **kwargs) + repo_id = self._create_repo(repo_id, **kwargs) files_timestamps = self._get_files_timestamps(save_directory) special_tokens_map_file = os.path.join( @@ -2177,7 +2177,11 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True): if push_to_hub: self._upload_modified_files( - save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("use_auth_token"), ) return save_files diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 73c09ae135c80..4f8ea58d9dcaf 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -31,7 +31,6 @@ import requests from huggingface_hub import ( CommitOperationAdd, - HfFolder, create_commit, create_repo, get_hf_file_metadata, @@ -45,6 +44,7 @@ LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError, + build_hf_headers, hf_raise_for_status, ) from requests.exceptions import HTTPError @@ -583,7 +583,7 @@ def has_file( use_auth_token: Optional[Union[bool, str]] = None, ): """ - Checks if a repo contains a given file wihtout downloading it. Works for remote repos and local folders. + Checks if a repo contains a given file without downloading it. Works for remote repos and local folders. @@ -596,15 +596,7 @@ def has_file( return os.path.isfile(os.path.join(path_or_repo, filename)) url = hf_hub_url(path_or_repo, filename=filename, revision=revision) - - headers = {"user-agent": http_user_agent()} - if isinstance(use_auth_token, str): - headers["authorization"] = f"Bearer {use_auth_token}" - elif use_auth_token: - token = HfFolder.get_token() - if token is None: - raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.") - headers["authorization"] = f"Bearer {token}" + headers = build_hf_headers(use_auth_token=use_auth_token, user_agent=http_user_agent()) r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10) try: @@ -636,10 +628,10 @@ def _create_repo( use_auth_token: Optional[Union[bool, str]] = None, repo_url: Optional[str] = None, organization: Optional[str] = None, - ): + ) -> str: """ - Create the repo if needed, cleans up repo_id with deprecated kwards `repo_url` and `organization`, retrives the - token. + Create the repo if needed, cleans up repo_id with deprecated kwargs `repo_url` and `organization`, retrieves + the token. """ if repo_url is not None: warnings.warn( @@ -657,13 +649,12 @@ def _create_repo( repo_id = repo_id.split("/")[-1] repo_id = f"{organization}/{repo_id}" - token = HfFolder.get_token() if use_auth_token is True else use_auth_token - url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True) + url = create_repo(repo_id=repo_id, token=use_auth_token, private=private, exist_ok=True) # If the namespace is not there, add it or `upload_file` will complain if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}": - repo_id = get_full_repo_name(repo_id, token=token) - return repo_id, token + repo_id = get_full_repo_name(repo_id, token=use_auth_token) + return repo_id def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]): """ @@ -677,7 +668,7 @@ def _upload_modified_files( repo_id: str, files_timestamps: Dict[str, float], commit_message: Optional[str] = None, - token: Optional[str] = None, + token: Optional[Union[bool, str]] = None, create_pr: bool = False, ): """ @@ -776,7 +767,7 @@ def push_to_hub( else: working_dir = repo_id.split("/")[-1] - repo_id, token = self._create_repo( + repo_id = self._create_repo( repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization ) @@ -790,13 +781,16 @@ def push_to_hub( self.save_pretrained(work_dir, max_shard_size=max_shard_size) return self._upload_modified_files( - work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token, create_pr=create_pr + work_dir, + repo_id, + files_timestamps, + commit_message=commit_message, + token=use_auth_token, + create_pr=create_pr, ) def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() if organization is None: username = whoami(token)["name"] return f"{username}/{model_id}" @@ -1040,8 +1034,6 @@ def move_cache(cache_dir=None, new_cache_dir=None, token=None): cache_dir = str(old_cache) else: cache_dir = new_cache_dir - if token is None: - token = HfFolder.get_token() cached_files = get_all_cached_files(cache_dir=cache_dir) logger.info(f"Moving {len(cached_files)} files to the new cache system") @@ -1050,7 +1042,7 @@ def move_cache(cache_dir=None, new_cache_dir=None, token=None): url = file_info.pop("url") if url not in hub_metadata: try: - hub_metadata[url] = get_hf_file_metadata(url, use_auth_token=token) + hub_metadata[url] = get_hf_file_metadata(url, token=token) except requests.HTTPError: continue