-
Notifications
You must be signed in to change notification settings - Fork 25.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core / FEAT] Add the possibility to push custom tags using PreTrainedModel itself
#28405
Changes from 24 commits
dd5cbe3
f78ed31
da00274
1180585
3485b10
4b82255
4c7806e
8b89796
e73dc7b
fbef2de
0e4daad
c19e751
eb93371
1fe93b3
a24ad9b
6cfd6f5
40a1d4b
dc31941
acd676b
db3197d
f14cf93
31117f4
36f2cb7
514f13b
b3d5900
22d3412
85584ae
1e3fc1e
59738c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,8 @@ | |
from huggingface_hub import ( | ||
_CACHED_NO_EXIST, | ||
CommitOperationAdd, | ||
ModelCard, | ||
ModelCardData, | ||
constants, | ||
create_branch, | ||
create_commit, | ||
|
@@ -762,6 +764,7 @@ def push_to_hub( | |
safe_serialization: bool = True, | ||
revision: str = None, | ||
commit_description: str = None, | ||
tags: Optional[List[str]] = None, | ||
**deprecated_kwargs, | ||
) -> str: | ||
""" | ||
|
@@ -795,6 +798,8 @@ def push_to_hub( | |
Branch to push the uploaded files to. | ||
commit_description (`str`, *optional*): | ||
The description of the commit that will be created | ||
tags (`List[str]`, *optional*): | ||
List of tags to push on the Hub. | ||
|
||
Examples: | ||
|
||
|
@@ -811,6 +816,7 @@ def push_to_hub( | |
``` | ||
""" | ||
use_auth_token = deprecated_kwargs.pop("use_auth_token", None) | ||
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False) | ||
if use_auth_token is not None: | ||
warnings.warn( | ||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", | ||
|
@@ -855,6 +861,11 @@ def push_to_hub( | |
repo_id, private=private, token=token, repo_url=repo_url, organization=organization | ||
) | ||
|
||
# Create a new empty model card and eventually tag it | ||
model_card = create_and_tag_model_card( | ||
repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors | ||
) | ||
|
||
if use_temp_dir is None: | ||
use_temp_dir = not os.path.isdir(working_dir) | ||
|
||
|
@@ -864,6 +875,9 @@ def push_to_hub( | |
# Save all files. | ||
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) | ||
|
||
# Update model card if needed: | ||
model_card.save(os.path.join(work_dir, "README.md")) | ||
|
||
return self._upload_modified_files( | ||
work_dir, | ||
repo_id, | ||
|
@@ -1081,6 +1095,43 @@ def extract_info_from_url(url): | |
return {"repo": cache_repo, "revision": revision, "filename": filename} | ||
|
||
|
||
def create_and_tag_model_card( | ||
repo_id: str, | ||
tags: Optional[List[str]] = None, | ||
token: Optional[str] = None, | ||
ignore_metadata_errors: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Wauplin @amyeroberts fine to assume that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes definitely, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok perfect! let's keep it like that then |
||
): | ||
""" | ||
Creates or loads an existing model card and tags it. | ||
|
||
Args: | ||
repo_id (`str`): | ||
The repo_id where to look for the model card. | ||
tags (`List[str]`, *optional*): | ||
The list of tags to add in the model card | ||
token (`str`, *optional*): | ||
Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token. | ||
ignore_metadata_errors (`bool`): | ||
If True, errors while parsing the metadata section will be ignored. Some information might be lost during | ||
the process. Use it at your own risk. | ||
Comment on lines +1112 to +1116
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I copied this from the `ModelCard.load_card` docstring. |
||
""" | ||
try: | ||
# Check if the model card is present on the remote repo | ||
model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors) | ||
except EntryNotFoundError: | ||
# Otherwise create a simple model card from template | ||
model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated." | ||
card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers") | ||
model_card = ModelCard.from_template(card_data, model_description=model_description) | ||
|
||
if tags is not None: | ||
for model_tag in tags: | ||
if model_tag not in model_card.data.tags: | ||
model_card.data.tags.append(model_tag) | ||
|
||
return model_card | ||
|
||
|
||
def clean_files_for(file): | ||
""" | ||
Remove, if they exist, file, file.json and file.lock | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1428,6 +1428,11 @@ def tearDownClass(cls): | |
except HTTPError: | ||
pass | ||
|
||
try: | ||
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags") | ||
except HTTPError: | ||
pass | ||
|
||
@unittest.skip("This test is flaky") | ||
def test_push_to_hub(self): | ||
config = BertConfig( | ||
|
@@ -1515,6 +1520,28 @@ def test_push_to_hub_dynamic_model(self): | |
new_model = AutoModel.from_config(config, trust_remote_code=True) | ||
self.assertEqual(new_model.__class__.__name__, "CustomModel") | ||
|
||
def test_push_to_hub_with_tags(self): | ||
from huggingface_hub import ModelCard | ||
|
||
new_tags = ["tag-1", "tag-2"] | ||
|
||
CustomConfig.register_for_auto_class() | ||
CustomModel.register_for_auto_class() | ||
|
||
config = CustomConfig(hidden_size=32) | ||
model = CustomModel(config) | ||
|
||
self.assertTrue(model.model_tags is None) | ||
|
||
model.add_model_tags(new_tags) | ||
|
||
self.assertTrue(model.model_tags == new_tags) | ||
|
||
model.push_to_hub("test-dynamic-model-with-tags", token=self._token) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't it problematic to have a fixed name for the repo_id? This could lead to conflicts if several CIs are ran in parallel. Using a unique id (uuid) would avoid that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm good point, we do the same things for all tests above so perhaps it is fine, I will let @ydshieh comment on this if he has more insights than me |
||
|
||
loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags") | ||
self.assertEqual(loaded_model_card.data.tags, new_tags) | ||
|
||
|
||
@require_torch | ||
class AttentionMaskTester(unittest.TestCase): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current if/else logic means: if `self.model_tags` is `None` and `tags` aren't passed in, `kwargs['tags']` is set to `None`. Do we want this, or no `tags` kwarg at all? *- just realised I'm wrong, as kwargs are passed otherwise 🙃 — `tags` are only used if `self.model_tags` is not `None`.
Another Q: is it possible `kwargs["tags"]` is a string here too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think your suggestion sounds great. Also, it can be strings — let me adapt the logic a bit after accepting your suggestion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just took care of the str case in 1e3fc1e