Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
torch>=1.7
torchvision
pyyaml
huggingface_hub
huggingface_hub>=0.17.0
safetensors>=0.2
numpy
21 changes: 7 additions & 14 deletions timm/models/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@
from timm.models._pretrained import filter_pretrained_cfg

try:
from huggingface_hub import (
create_repo, get_hf_file_metadata,
hf_hub_download, hf_hub_url,
repo_type_and_id_from_hf_id, upload_folder)
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
_has_hf_hub = True
Expand Down Expand Up @@ -414,20 +411,16 @@ def push_to_hf_hub(
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
Can be set to `"both"` in order to push both safe and unsafe weights.
"""
api = HfApi(token=token, library_name="timm", library_version=__version__)

# Create repo if it doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
repo_url = api.create_repo(repo_id, private=private, exist_ok=True)

# Infer complete repo_id from repo_url
# Can be different from the input `repo_id` if repo_owner was implicit
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
repo_id = f"{repo_owner}/{repo_name}"
repo_id = repo_url.repo_id

# Check if README file already exist in repo
try:
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
has_readme = True
except EntryNotFoundError:
has_readme = False
has_readme = api.file_exists(repo_id=repo_id, filename="README.md", revision=revision)

# Dump model and push to Hub
with TemporaryDirectory() as tmpdir:
Expand All @@ -449,7 +442,7 @@ def push_to_hf_hub(
readme_path.write_text(readme_text)

# Upload model and return
return upload_folder(
return api.upload_folder(
repo_id=repo_id,
folder_path=tmpdir,
revision=revision,
Expand Down