Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 14 additions & 25 deletions utils/create_dummy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from check_config_docstrings import get_checkpoint_from_config_class
from datasets import load_dataset
from get_test_info import get_model_to_tester_mapping, get_tester_classes_for_model
from huggingface_hub import Repository, create_repo, hf_api, upload_folder # TODO: remove Repository
from huggingface_hub import create_repo, hf_api, upload_folder

from transformers import (
CONFIG_MAPPING,
Expand Down Expand Up @@ -805,30 +805,19 @@ def upload_model(model_dir, organization, token):
if error is not None:
raise error

with tempfile.TemporaryDirectory() as tmpdir:
repo = Repository(local_dir=tmpdir, clone_from=repo_id, token=token)
repo.git_pull()
shutil.copytree(model_dir, tmpdir, dirs_exist_ok=True)

if repo_exist:
# Open a PR on the existing Hub repo.
hub_pr_url = upload_folder(
folder_path=model_dir,
repo_id=repo_id,
repo_type="model",
commit_message=f"Update tiny models for {arch_name}",
commit_description=f"Upload tiny models for {arch_name}",
create_pr=True,
token=token,
)
logger.warning(f"PR open in {hub_pr_url}.")
# TODO: We need this information?
else:
# Push to Hub repo directly
repo.git_add(auto_lfs_track=True)
repo.git_commit(f"Upload tiny models for {arch_name}")
repo.git_push(blocking=True) # this prints a progress bar with the upload
logger.warning(f"Tiny models {arch_name} pushed to {repo_id}.")
create_pr = repo_exist # Open a PR on existing repo, otherwise push directly
commit = upload_folder(
folder_path=model_dir,
repo_id=repo_id,
repo_type="model",
commit_message=f"Update tiny models for {arch_name}",
commit_description=f"Upload tiny models for {arch_name}",
create_pr=create_pr,
token=token,
)

msg = f"PR open in {commit.pr_url}." if create_pr else f"Tiny models {arch_name} pushed to {repo_id}."
logger.warning(msg)


def build_composite_models(config_class, output_dir):
Expand Down