Skip to content

Commit

Permalink
Little cleanup: let huggingface_hub manage token retrieval (#21333)
Browse files Browse the repository at this point in the history
* Let huggingface_hub manage token retrieval

* flake8

* code quality

* adapt in every PushToHubMixin children

* add explicit return type
  • Loading branch information
Wauplin committed Jan 27, 2023
1 parent 0dff407 commit 8f3b4a1
Show file tree
Hide file tree
Showing 11 changed files with 76 additions and 55 deletions.
8 changes: 6 additions & 2 deletions src/transformers/configuration_utils.py
Expand Up @@ -438,7 +438,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
Expand All @@ -454,7 +454,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)

@classmethod
Expand Down
11 changes: 2 additions & 9 deletions src/transformers/dynamic_module_utils.py
Expand Up @@ -22,7 +22,7 @@
from pathlib import Path
from typing import Dict, Optional, Union

from huggingface_hub import HfFolder, model_info
from huggingface_hub import model_info

from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging

Expand Down Expand Up @@ -251,14 +251,7 @@ def get_cached_module_file(
else:
# Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
if isinstance(use_auth_token, str):
token = use_auth_token
elif use_auth_token is True:
token = HfFolder.get_token()
else:
token = None

commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=use_auth_token).sha

# The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
# benefit of versioning.
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/feature_extraction_utils.py
Expand Up @@ -353,7 +353,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
Expand All @@ -369,7 +369,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)

return [output_feature_extractor_file]
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/generation/configuration_utils.py
Expand Up @@ -337,7 +337,7 @@ def save_pretrained(
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

output_config_file = os.path.join(save_directory, config_file_name)
Expand All @@ -347,7 +347,11 @@ def save_pretrained(

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)

@classmethod
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/image_processing_utils.py
Expand Up @@ -185,7 +185,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
Expand All @@ -201,7 +201,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)

return [output_image_processor_file]
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/modeling_flax_utils.py
Expand Up @@ -1018,7 +1018,7 @@ def save_pretrained(
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

# get abs dir
Expand Down Expand Up @@ -1077,7 +1077,11 @@ def save_pretrained(

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)

@classmethod
Expand Down
12 changes: 8 additions & 4 deletions src/transformers/modeling_tf_utils.py
Expand Up @@ -2277,7 +2277,7 @@ def save_pretrained(
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

if saved_model:
Expand Down Expand Up @@ -2363,7 +2363,11 @@ def save_pretrained(

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)

@classmethod
Expand Down Expand Up @@ -2946,7 +2950,7 @@ def push_to_hub(
else:
working_dir = repo_id.split("/")[-1]

repo_id, token = self._create_repo(
repo_id = self._create_repo(
repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
)

Expand All @@ -2968,7 +2972,7 @@ def push_to_hub(
self.create_model_card(**base_model_card_args)

self._upload_modified_files(
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token
)

@classmethod
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/modeling_utils.py
Expand Up @@ -1633,7 +1633,7 @@ def save_pretrained(
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

# Only save the model itself if we are using distributed training
Expand Down Expand Up @@ -1717,7 +1717,11 @@ def save_pretrained(

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)

def get_memory_footprint(self, return_buffers=True):
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/processing_utils.py
Expand Up @@ -121,7 +121,7 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
Expand All @@ -147,7 +147,11 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)

@classmethod
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/tokenization_utils_base.py
Expand Up @@ -2098,7 +2098,7 @@ def save_pretrained(
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

special_tokens_map_file = os.path.join(
Expand Down Expand Up @@ -2177,7 +2177,11 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)

return save_files
Expand Down
44 changes: 18 additions & 26 deletions src/transformers/utils/hub.py
Expand Up @@ -31,7 +31,6 @@
import requests
from huggingface_hub import (
CommitOperationAdd,
HfFolder,
create_commit,
create_repo,
get_hf_file_metadata,
Expand All @@ -45,6 +44,7 @@
LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
build_hf_headers,
hf_raise_for_status,
)
from requests.exceptions import HTTPError
Expand Down Expand Up @@ -583,7 +583,7 @@ def has_file(
use_auth_token: Optional[Union[bool, str]] = None,
):
"""
Checks if a repo contains a given file wihtout downloading it. Works for remote repos and local folders.
Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.
<Tip warning={false}>
Expand All @@ -596,15 +596,7 @@ def has_file(
return os.path.isfile(os.path.join(path_or_repo, filename))

url = hf_hub_url(path_or_repo, filename=filename, revision=revision)

headers = {"user-agent": http_user_agent()}
if isinstance(use_auth_token, str):
headers["authorization"] = f"Bearer {use_auth_token}"
elif use_auth_token:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
headers["authorization"] = f"Bearer {token}"
headers = build_hf_headers(use_auth_token=use_auth_token, user_agent=http_user_agent())

r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
try:
Expand Down Expand Up @@ -636,10 +628,10 @@ def _create_repo(
use_auth_token: Optional[Union[bool, str]] = None,
repo_url: Optional[str] = None,
organization: Optional[str] = None,
):
) -> str:
"""
Create the repo if needed, cleans up repo_id with deprecated kwards `repo_url` and `organization`, retrives the
token.
Create the repo if needed, cleans up repo_id with deprecated kwargs `repo_url` and `organization`, retrieves
the token.
"""
if repo_url is not None:
warnings.warn(
Expand All @@ -657,13 +649,12 @@ def _create_repo(
repo_id = repo_id.split("/")[-1]
repo_id = f"{organization}/{repo_id}"

token = HfFolder.get_token() if use_auth_token is True else use_auth_token
url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
url = create_repo(repo_id=repo_id, token=use_auth_token, private=private, exist_ok=True)

# If the namespace is not there, add it or `upload_file` will complain
if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}":
repo_id = get_full_repo_name(repo_id, token=token)
return repo_id, token
repo_id = get_full_repo_name(repo_id, token=use_auth_token)
return repo_id

def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
"""
Expand All @@ -677,7 +668,7 @@ def _upload_modified_files(
repo_id: str,
files_timestamps: Dict[str, float],
commit_message: Optional[str] = None,
token: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
create_pr: bool = False,
):
"""
Expand Down Expand Up @@ -776,7 +767,7 @@ def push_to_hub(
else:
working_dir = repo_id.split("/")[-1]

repo_id, token = self._create_repo(
repo_id = self._create_repo(
repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
)

Expand All @@ -790,13 +781,16 @@ def push_to_hub(
self.save_pretrained(work_dir, max_shard_size=max_shard_size)

return self._upload_modified_files(
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token, create_pr=create_pr
work_dir,
repo_id,
files_timestamps,
commit_message=commit_message,
token=use_auth_token,
create_pr=create_pr,
)


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
Expand Down Expand Up @@ -1040,8 +1034,6 @@ def move_cache(cache_dir=None, new_cache_dir=None, token=None):
cache_dir = str(old_cache)
else:
cache_dir = new_cache_dir
if token is None:
token = HfFolder.get_token()
cached_files = get_all_cached_files(cache_dir=cache_dir)
logger.info(f"Moving {len(cached_files)} files to the new cache system")

Expand All @@ -1050,7 +1042,7 @@ def move_cache(cache_dir=None, new_cache_dir=None, token=None):
url = file_info.pop("url")
if url not in hub_metadata:
try:
hub_metadata[url] = get_hf_file_metadata(url, use_auth_token=token)
hub_metadata[url] = get_hf_file_metadata(url, token=token)
except requests.HTTPError:
continue

Expand Down

0 comments on commit 8f3b4a1

Please sign in to comment.