-
Notifications
You must be signed in to change notification settings - Fork 25.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core / FEAT] Add the possibility to push custom tags using PreTrainedModel itself
#28405
Changes from 28 commits
dd5cbe3
f78ed31
da00274
1180585
3485b10
4b82255
4c7806e
8b89796
e73dc7b
fbef2de
0e4daad
c19e751
eb93371
1fe93b3
a24ad9b
6cfd6f5
40a1d4b
dc31941
acd676b
db3197d
f14cf93
31117f4
36f2cb7
514f13b
b3d5900
22d3412
85584ae
1e3fc1e
59738c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,8 @@ | |
from huggingface_hub import ( | ||
_CACHED_NO_EXIST, | ||
CommitOperationAdd, | ||
ModelCard, | ||
ModelCardData, | ||
constants, | ||
create_branch, | ||
create_commit, | ||
|
@@ -762,6 +764,7 @@ def push_to_hub( | |
safe_serialization: bool = True, | ||
revision: str = None, | ||
commit_description: str = None, | ||
tags: Optional[List[str]] = None, | ||
**deprecated_kwargs, | ||
) -> str: | ||
""" | ||
|
@@ -795,6 +798,8 @@ def push_to_hub( | |
Branch to push the uploaded files to. | ||
commit_description (`str`, *optional*): | ||
The description of the commit that will be created | ||
tags (`List[str]`, *optional*): | ||
List of tags to push on the Hub. | ||
|
||
Examples: | ||
|
||
|
@@ -811,6 +816,7 @@ def push_to_hub( | |
``` | ||
""" | ||
use_auth_token = deprecated_kwargs.pop("use_auth_token", None) | ||
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False) | ||
if use_auth_token is not None: | ||
warnings.warn( | ||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", | ||
|
@@ -855,6 +861,11 @@ def push_to_hub( | |
repo_id, private=private, token=token, repo_url=repo_url, organization=organization | ||
) | ||
|
||
# Create a new empty model card and eventually tag it | ||
model_card = create_and_tag_model_card( | ||
repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors | ||
) | ||
|
||
if use_temp_dir is None: | ||
use_temp_dir = not os.path.isdir(working_dir) | ||
|
||
|
@@ -864,6 +875,9 @@ def push_to_hub( | |
# Save all files. | ||
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) | ||
|
||
# Update model card if needed: | ||
model_card.save(os.path.join(work_dir, "README.md")) | ||
|
||
return self._upload_modified_files( | ||
work_dir, | ||
repo_id, | ||
|
@@ -1081,6 +1095,43 @@ def extract_info_from_url(url): | |
return {"repo": cache_repo, "revision": revision, "filename": filename} | ||
|
||
|
||
def create_and_tag_model_card(
    repo_id: str,
    tags: Optional[List[str]] = None,
    token: Optional[str] = None,
    ignore_metadata_errors: bool = False,
):
    """
    Creates or loads an existing model card and tags it.

    Args:
        repo_id (`str`):
            The repo_id where to look for the model card.
        tags (`List[str]`, *optional*):
            The list of tags to add in the model card
        token (`str`, *optional*):
            Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.
        ignore_metadata_errors (`bool`, *optional*, defaults to `False`):
            If True, errors while parsing the metadata section will be ignored. Some information might be lost during
            the process. Use it at your own risk.

    Returns:
        The model card loaded from `repo_id`, or a freshly created one when the repo has none, with every entry of
        `tags` present in its `data.tags`.
    """
    # NOTE(review): the `ignore_metadata_errors` wording above was copied from the `ModelCard.load` docstring, and
    # reviewers agreed in the PR discussion to keep this signature as it is.
    try:
        # Check if the model card is present on the remote repo
        model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors)
    except EntryNotFoundError:
        # Otherwise create a simple model card from template
        model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
        card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers")
        model_card = ModelCard.from_template(card_data, model_description=model_description)

    if tags is not None:
        # A card loaded from the Hub may carry no tags at all, in which case `data.tags` is None and the
        # membership test below would raise a TypeError. Start from an empty list instead.
        if model_card.data.tags is None:
            model_card.data.tags = []
        # Append rather than overwrite so that tags already present on the card are preserved.
        for model_tag in tags:
            if model_tag not in model_card.data.tags:
                model_card.data.tags.append(model_tag)

    return model_card
|
||
|
||
def clean_files_for(file): | ||
""" | ||
Remove, if they exist, file, file.json and file.lock | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1435,6 +1435,11 @@ def tearDownClass(cls): | |
except HTTPError: | ||
pass | ||
|
||
try: | ||
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags") | ||
except HTTPError: | ||
pass | ||
|
||
@unittest.skip("This test is flaky") | ||
def test_push_to_hub(self): | ||
config = BertConfig( | ||
|
@@ -1522,6 +1527,28 @@ def test_push_to_hub_dynamic_model(self): | |
new_model = AutoModel.from_config(config, trust_remote_code=True) | ||
self.assertEqual(new_model.__class__.__name__, "CustomModel") | ||
|
||
def test_push_to_hub_with_tags(self):
    """Tags added with `add_model_tags` must end up in the model card pushed to the Hub."""
    from huggingface_hub import ModelCard

    expected_tags = ["tag-1", "tag-2"]

    CustomConfig.register_for_auto_class()
    CustomModel.register_for_auto_class()

    model = CustomModel(CustomConfig(hidden_size=32))

    # A freshly built model carries no tags until some are added explicitly.
    self.assertIsNone(model.model_tags)
    model.add_model_tags(expected_tags)
    self.assertEqual(model.model_tags, expected_tags)

    # NOTE(review): the repo id is fixed, which could lead to conflicts if several CI runs execute in parallel;
    # a unique id (uuid) would avoid that. The other tests above do the same - TODO confirm with maintainers.
    model.push_to_hub("test-dynamic-model-with-tags", token=self._token)

    # Read the card back from the Hub and check the tags survived the round trip.
    pushed_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags")
    self.assertEqual(pushed_card.data.tags, expected_tags)
|
||
|
||
@require_torch | ||
class AttentionMaskTester(unittest.TestCase): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This handles a corner case I just found for Trainers: if you push to a repo that already has some tags, the existing tags are overwritten by the new ones. This block circumvents that: instead of overwriting the existing tags, it appends the new tags to them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me know if this is ok
can give `None` if the model card is empty
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks OK to me - my only comment is that `existing_tags` might be `None`, so you'll need to convert it to a list before adding `tags` if `ModelCard.load(model_card_filepath).data.tags` doesn't have anything set.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I pushed a simple workaround in 59738c6. I think that way we don't need to convert it to a list, as we care about that corner case only if `tags` are explicitly passed and there are already some existing tags.